git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Revert contains_pragma_comment function changes
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 import logging
4 from concurrent.futures import ThreadPoolExecutor
5 from contextlib import contextmanager
6 from dataclasses import replace
7 from functools import partial
8 import inspect
9 from io import BytesIO, TextIOWrapper
10 import os
11 from pathlib import Path
12 import regex as re
13 import sys
14 from tempfile import TemporaryDirectory
15 import types
16 from typing import (
17     Any,
18     BinaryIO,
19     Callable,
20     Dict,
21     Generator,
22     List,
23     Tuple,
24     Iterator,
25     TypeVar,
26 )
27 import unittest
28 from unittest.mock import patch, MagicMock
29
30 import click
31 from click import unstyle
32 from click.testing import CliRunner
33
34 import black
35 from black import Feature, TargetVersion
36
37 try:
38     import blackd
39     from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop
40     from aiohttp import web
41 except ImportError:
42     has_blackd_deps = False
43 else:
44     has_blackd_deps = True
45
46 from pathspec import PathSpec
47
48 # Import other test classes
49 from .test_primer import PrimerCLITests  # noqa: F401
50
51
# Mode shared by most tests: black's defaults plus experimental string processing.
DEFAULT_MODE = black.FileMode(experimental_string_processing=True)
# Shorthands: `ff` formats a file in place (with fast=True), `fs` formats a string.
ff = partial(black.format_file_in_place, mode=DEFAULT_MODE, fast=True)
fs = partial(black.format_str, mode=DEFAULT_MODE)
# Filesystem anchors: this file, its directory (tests/), and the directory above it.
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
PROJECT_ROOT = THIS_DIR.parent
# Text substituted for the timestamped headers of diff output.
DETERMINISTIC_HEADER = "[Deterministic header]"
# Marker removed from every fixture line by read_data(). It is built from two
# literals, presumably so the full marker never appears verbatim in this file,
# which test_self also reads through read_data().
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
# One "--target-version=..." CLI flag per version in black.PY36_VERSIONS.
PY36_ARGS = [f"--target-version={tv.name.lower()}" for tv in black.PY36_VERSIONS]
# Generic type variables.
T = TypeVar("T")
R = TypeVar("R")
65
66
def dump_to_stderr(*output: str) -> str:
    """Return *output* joined by newlines, with one leading and one trailing newline.

    Tests patch this in for ``black.dump_to_file`` so that debugging text is
    returned inline rather than written to a temporary file. Despite its name,
    the function itself writes nothing.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
69
70
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    Read a fixture and split it at the first line equal to ``# output``.
    ``.py`` is appended unless *name* already ends in ``.py``, ``.pyi``, ``.out``
    or ``.diff``. With *data* true the file is looked up in ``tests/data``,
    otherwise relative to the project root. Occurrences of EMPTY_LINE are
    removed first. A file with input but no output section is treated as
    already formatted: its output is the input. Both parts are stripped and
    end in exactly one newline.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    root = THIS_DIR / "data" if data else PROJECT_ROOT
    with open(root / name, "r", encoding="utf8") as fixture:
        raw_lines = fixture.readlines()
    sections: Tuple[List[str], List[str]] = ([], [])
    current = 0  # 0 collects input lines, 1 collects expected-output lines
    for raw in raw_lines:
        cleaned = raw.replace(EMPTY_LINE, "")
        if cleaned.rstrip() == "# output":
            current = 1
        else:
            sections[current].append(cleaned)
    src_lines, out_lines = sections
    if src_lines and not out_lines:
        # No output marker: the whole file is its own expected output.
        out_lines = list(src_lines)
    return (
        "".join(src_lines).strip() + "\n",
        "".join(out_lines).strip() + "\n",
    )
92
93
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point ``black.CACHE_DIR`` at a fresh temporary directory inside the block.

    With *exists* false, the yielded path is a not-yet-created ``new``
    subdirectory of that temporary directory. The temporary directory is
    removed and the patch undone on exit.
    """
    with TemporaryDirectory() as workspace:
        base = Path(workspace)
        target = base if exists else base / "new"
        with patch("black.CACHE_DIR", target):
            yield target
102
103
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a brand-new asyncio event loop for the body of the ``with`` block.

    The loop comes from the current event-loop policy and is closed on exit,
    even if the body raises.
    """
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        fresh_loop.close()
114
115
@contextmanager
def skip_if_exception(e: str) -> Iterator[None]:
    """Skip the running test if the body raises an exception whose class is *e*.

    *e* is compared against the exception's class name only (no module path).
    Any other exception propagates unchanged.

    Raises:
        unittest.SkipTest: when an exception whose class name equals *e*
            escapes the ``with`` body.
    """
    try:
        yield
    except Exception as exc:
        if exc.__class__.__name__ != e:
            raise
        # `unittest.skip(...)` only builds a decorator, so calling it here did
        # nothing and silently swallowed the error. Raising SkipTest is what
        # actually makes unittest/pytest report the test as skipped.
        reason = f"Encountered expected exception {exc}, skipping"
        raise unittest.SkipTest(reason) from exc
125
126
class FakeContext(click.Context):
    """Stand-in click ``Context`` for calling functions that require one.

    ``click.Context.__init__`` is not called; the only state set up is an
    empty ``default_map``.
    """

    def __init__(self) -> None:
        # No super().__init__(): only default_map is provided.
        self.default_map: Dict[str, Any] = dict()
132
133
class FakeParameter(click.Parameter):
    """Stand-in click ``Parameter`` for calling functions that require one.

    Construction bypasses ``click.Parameter.__init__`` entirely, so none of
    the attributes that initializer would set are present.
    """

    def __init__(self) -> None:
        """Intentionally do nothing (``click.Parameter.__init__`` is skipped)."""
139
140
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    This is a hack that can be removed once we depend on Click 7.x"""

    def __init__(self) -> None:
        # Byte sink that receives everything written to sys.stderr in isolation().
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        # Captured output; filled in when an isolation() context exits.
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Run click's isolation while also redirecting ``sys.stderr``.

        On exit, ``sys.stderr`` is restored first. Then the bytes written to
        stdout and stderr are stored in ``stdout_bytes`` and ``stderr_bytes``.
        """
        with super().isolation(*args, **kwargs) as output:
            hold_stderr = sys.stderr
            stderr_wrapper = TextIOWrapper(self.stderrbuf, encoding=self.charset)
            sys.stderr = stderr_wrapper
            try:
                yield output
            finally:
                # Restore first, so a failure while reading buffers below can't
                # leave the process with a redirected stderr.
                sys.stderr = hold_stderr
                # TextIOWrapper keeps pending text in its own buffer; flush so
                # everything written so far reaches self.stderrbuf.
                stderr_wrapper.flush()
                self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                self.stderr_bytes = self.stderrbuf.getvalue()
164
165
166 class BlackTestCase(unittest.TestCase):
167     maxDiff = None
168     _diffThreshold = 2 ** 20
169
170     def assertFormatEqual(self, expected: str, actual: str) -> None:
171         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
172             bdv: black.DebugVisitor[Any]
173             black.out("Expected tree:", fg="green")
174             try:
175                 exp_node = black.lib2to3_parse(expected)
176                 bdv = black.DebugVisitor()
177                 list(bdv.visit(exp_node))
178             except Exception as ve:
179                 black.err(str(ve))
180             black.out("Actual tree:", fg="red")
181             try:
182                 exp_node = black.lib2to3_parse(actual)
183                 bdv = black.DebugVisitor()
184                 list(bdv.visit(exp_node))
185             except Exception as ve:
186                 black.err(str(ve))
187         self.assertMultiLineEqual(expected, actual)
188
189     def invokeBlack(
190         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
191     ) -> None:
192         runner = BlackRunner()
193         if ignore_config:
194             args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
195         result = runner.invoke(black.main, args)
196         self.assertEqual(
197             result.exit_code,
198             exit_code,
199             msg=(
200                 f"Failed with args: {args}\n"
201                 f"stdout: {runner.stdout_bytes.decode()!r}\n"
202                 f"stderr: {runner.stderr_bytes.decode()!r}\n"
203                 f"exception: {result.exception}"
204             ),
205         )
206
207     @patch("black.dump_to_file", dump_to_stderr)
208     def checkSourceFile(self, name: str, mode: black.FileMode = DEFAULT_MODE) -> None:
209         path = THIS_DIR.parent / name
210         source, expected = read_data(str(path), data=False)
211         actual = fs(source, mode=mode)
212         self.assertFormatEqual(expected, actual)
213         black.assert_equivalent(source, actual)
214         black.assert_stable(source, actual, mode)
215         self.assertFalse(ff(path))
216
217     @patch("black.dump_to_file", dump_to_stderr)
218     def test_empty(self) -> None:
219         source = expected = ""
220         actual = fs(source)
221         self.assertFormatEqual(expected, actual)
222         black.assert_equivalent(source, actual)
223         black.assert_stable(source, actual, DEFAULT_MODE)
224
225     def test_empty_ff(self) -> None:
226         expected = ""
227         tmp_file = Path(black.dump_to_file())
228         try:
229             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
230             with open(tmp_file, encoding="utf8") as f:
231                 actual = f.read()
232         finally:
233             os.unlink(tmp_file)
234         self.assertFormatEqual(expected, actual)
235
236     def test_self(self) -> None:
237         self.checkSourceFile("tests/test_black.py")
238
239     def test_black(self) -> None:
240         self.checkSourceFile("src/black/__init__.py")
241
242     def test_pygram(self) -> None:
243         self.checkSourceFile("src/blib2to3/pygram.py")
244
245     def test_pytree(self) -> None:
246         self.checkSourceFile("src/blib2to3/pytree.py")
247
248     def test_conv(self) -> None:
249         self.checkSourceFile("src/blib2to3/pgen2/conv.py")
250
251     def test_driver(self) -> None:
252         self.checkSourceFile("src/blib2to3/pgen2/driver.py")
253
254     def test_grammar(self) -> None:
255         self.checkSourceFile("src/blib2to3/pgen2/grammar.py")
256
257     def test_literals(self) -> None:
258         self.checkSourceFile("src/blib2to3/pgen2/literals.py")
259
260     def test_parse(self) -> None:
261         self.checkSourceFile("src/blib2to3/pgen2/parse.py")
262
263     def test_pgen(self) -> None:
264         self.checkSourceFile("src/blib2to3/pgen2/pgen.py")
265
266     def test_tokenize(self) -> None:
267         self.checkSourceFile("src/blib2to3/pgen2/tokenize.py")
268
269     def test_token(self) -> None:
270         self.checkSourceFile("src/blib2to3/pgen2/token.py")
271
272     def test_setup(self) -> None:
273         self.checkSourceFile("setup.py")
274
275     def test_piping(self) -> None:
276         source, expected = read_data("src/black/__init__", data=False)
277         result = BlackRunner().invoke(
278             black.main,
279             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
280             input=BytesIO(source.encode("utf8")),
281         )
282         self.assertEqual(result.exit_code, 0)
283         self.assertFormatEqual(expected, result.output)
284         black.assert_equivalent(source, result.output)
285         black.assert_stable(source, result.output, DEFAULT_MODE)
286
287     def test_piping_diff(self) -> None:
288         diff_header = re.compile(
289             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
290             r"\+\d\d\d\d"
291         )
292         source, _ = read_data("expression.py")
293         expected, _ = read_data("expression.diff")
294         config = THIS_DIR / "data" / "empty_pyproject.toml"
295         args = [
296             "-",
297             "--fast",
298             f"--line-length={black.DEFAULT_LINE_LENGTH}",
299             "--diff",
300             f"--config={config}",
301         ]
302         result = BlackRunner().invoke(
303             black.main, args, input=BytesIO(source.encode("utf8"))
304         )
305         self.assertEqual(result.exit_code, 0)
306         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
307         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
308         self.assertEqual(expected, actual)
309
310     def test_piping_diff_with_color(self) -> None:
311         source, _ = read_data("expression.py")
312         config = THIS_DIR / "data" / "empty_pyproject.toml"
313         args = [
314             "-",
315             "--fast",
316             f"--line-length={black.DEFAULT_LINE_LENGTH}",
317             "--diff",
318             "--color",
319             f"--config={config}",
320         ]
321         result = BlackRunner().invoke(
322             black.main, args, input=BytesIO(source.encode("utf8"))
323         )
324         actual = result.output
325         # Again, the contents are checked in a different test, so only look for colors.
326         self.assertIn("\033[1;37m", actual)
327         self.assertIn("\033[36m", actual)
328         self.assertIn("\033[32m", actual)
329         self.assertIn("\033[31m", actual)
330         self.assertIn("\033[0m", actual)
331
332     @patch("black.dump_to_file", dump_to_stderr)
333     def test_function(self) -> None:
334         source, expected = read_data("function")
335         actual = fs(source)
336         self.assertFormatEqual(expected, actual)
337         black.assert_equivalent(source, actual)
338         black.assert_stable(source, actual, DEFAULT_MODE)
339
340     @patch("black.dump_to_file", dump_to_stderr)
341     def test_function2(self) -> None:
342         source, expected = read_data("function2")
343         actual = fs(source)
344         self.assertFormatEqual(expected, actual)
345         black.assert_equivalent(source, actual)
346         black.assert_stable(source, actual, DEFAULT_MODE)
347
348     @patch("black.dump_to_file", dump_to_stderr)
349     def _test_wip(self) -> None:
350         source, expected = read_data("wip")
351         sys.settrace(tracefunc)
352         mode = replace(
353             DEFAULT_MODE,
354             experimental_string_processing=False,
355             target_versions={black.TargetVersion.PY38},
356         )
357         actual = fs(source, mode=mode)
358         sys.settrace(None)
359         self.assertFormatEqual(expected, actual)
360         black.assert_equivalent(source, actual)
361         black.assert_stable(source, actual, black.FileMode())
362
363     @patch("black.dump_to_file", dump_to_stderr)
364     def test_function_trailing_comma(self) -> None:
365         source, expected = read_data("function_trailing_comma")
366         actual = fs(source)
367         self.assertFormatEqual(expected, actual)
368         black.assert_equivalent(source, actual)
369         black.assert_stable(source, actual, DEFAULT_MODE)
370
371     @unittest.expectedFailure
372     @patch("black.dump_to_file", dump_to_stderr)
373     def test_trailing_comma_optional_parens_stability1(self) -> None:
374         source, _expected = read_data("trailing_comma_optional_parens1")
375         actual = fs(source)
376         black.assert_stable(source, actual, DEFAULT_MODE)
377
378     @unittest.expectedFailure
379     @patch("black.dump_to_file", dump_to_stderr)
380     def test_trailing_comma_optional_parens_stability2(self) -> None:
381         source, _expected = read_data("trailing_comma_optional_parens2")
382         actual = fs(source)
383         black.assert_stable(source, actual, DEFAULT_MODE)
384
385     @unittest.expectedFailure
386     @patch("black.dump_to_file", dump_to_stderr)
387     def test_trailing_comma_optional_parens_stability3(self) -> None:
388         source, _expected = read_data("trailing_comma_optional_parens3")
389         actual = fs(source)
390         black.assert_stable(source, actual, DEFAULT_MODE)
391
392     @patch("black.dump_to_file", dump_to_stderr)
393     def test_expression(self) -> None:
394         source, expected = read_data("expression")
395         actual = fs(source)
396         self.assertFormatEqual(expected, actual)
397         black.assert_equivalent(source, actual)
398         black.assert_stable(source, actual, DEFAULT_MODE)
399
400     @patch("black.dump_to_file", dump_to_stderr)
401     def test_pep_572(self) -> None:
402         source, expected = read_data("pep_572")
403         actual = fs(source)
404         self.assertFormatEqual(expected, actual)
405         black.assert_stable(source, actual, DEFAULT_MODE)
406         if sys.version_info >= (3, 8):
407             black.assert_equivalent(source, actual)
408
409     def test_pep_572_version_detection(self) -> None:
410         source, _ = read_data("pep_572")
411         root = black.lib2to3_parse(source)
412         features = black.get_features_used(root)
413         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
414         versions = black.detect_target_versions(root)
415         self.assertIn(black.TargetVersion.PY38, versions)
416
417     def test_expression_ff(self) -> None:
418         source, expected = read_data("expression")
419         tmp_file = Path(black.dump_to_file(source))
420         try:
421             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
422             with open(tmp_file, encoding="utf8") as f:
423                 actual = f.read()
424         finally:
425             os.unlink(tmp_file)
426         self.assertFormatEqual(expected, actual)
427         with patch("black.dump_to_file", dump_to_stderr):
428             black.assert_equivalent(source, actual)
429             black.assert_stable(source, actual, DEFAULT_MODE)
430
431     def test_expression_diff(self) -> None:
432         source, _ = read_data("expression.py")
433         expected, _ = read_data("expression.diff")
434         tmp_file = Path(black.dump_to_file(source))
435         diff_header = re.compile(
436             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
437             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
438         )
439         try:
440             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
441             self.assertEqual(result.exit_code, 0)
442         finally:
443             os.unlink(tmp_file)
444         actual = result.output
445         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
446         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
447         if expected != actual:
448             dump = black.dump_to_file(actual)
449             msg = (
450                 "Expected diff isn't equal to the actual. If you made changes to"
451                 " expression.py and this is an anticipated difference, overwrite"
452                 f" tests/data/expression.diff with {dump}"
453             )
454             self.assertEqual(expected, actual, msg)
455
456     def test_expression_diff_with_color(self) -> None:
457         source, _ = read_data("expression.py")
458         expected, _ = read_data("expression.diff")
459         tmp_file = Path(black.dump_to_file(source))
460         try:
461             result = BlackRunner().invoke(
462                 black.main, ["--diff", "--color", str(tmp_file)]
463             )
464         finally:
465             os.unlink(tmp_file)
466         actual = result.output
467         # We check the contents of the diff in `test_expression_diff`. All
468         # we need to check here is that color codes exist in the result.
469         self.assertIn("\033[1;37m", actual)
470         self.assertIn("\033[36m", actual)
471         self.assertIn("\033[32m", actual)
472         self.assertIn("\033[31m", actual)
473         self.assertIn("\033[0m", actual)
474
475     @patch("black.dump_to_file", dump_to_stderr)
476     def test_fstring(self) -> None:
477         source, expected = read_data("fstring")
478         actual = fs(source)
479         self.assertFormatEqual(expected, actual)
480         black.assert_equivalent(source, actual)
481         black.assert_stable(source, actual, DEFAULT_MODE)
482
483     @patch("black.dump_to_file", dump_to_stderr)
484     def test_pep_570(self) -> None:
485         source, expected = read_data("pep_570")
486         actual = fs(source)
487         self.assertFormatEqual(expected, actual)
488         black.assert_stable(source, actual, DEFAULT_MODE)
489         if sys.version_info >= (3, 8):
490             black.assert_equivalent(source, actual)
491
492     def test_detect_pos_only_arguments(self) -> None:
493         source, _ = read_data("pep_570")
494         root = black.lib2to3_parse(source)
495         features = black.get_features_used(root)
496         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
497         versions = black.detect_target_versions(root)
498         self.assertIn(black.TargetVersion.PY38, versions)
499
500     @patch("black.dump_to_file", dump_to_stderr)
501     def test_string_quotes(self) -> None:
502         source, expected = read_data("string_quotes")
503         actual = fs(source)
504         self.assertFormatEqual(expected, actual)
505         black.assert_equivalent(source, actual)
506         black.assert_stable(source, actual, DEFAULT_MODE)
507         mode = replace(DEFAULT_MODE, string_normalization=False)
508         not_normalized = fs(source, mode=mode)
509         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
510         black.assert_equivalent(source, not_normalized)
511         black.assert_stable(source, not_normalized, mode=mode)
512
513     @patch("black.dump_to_file", dump_to_stderr)
514     def test_docstring(self) -> None:
515         source, expected = read_data("docstring")
516         actual = fs(source)
517         self.assertFormatEqual(expected, actual)
518         black.assert_equivalent(source, actual)
519         black.assert_stable(source, actual, DEFAULT_MODE)
520         mode = replace(DEFAULT_MODE, string_normalization=False)
521         not_normalized = fs(source, mode=mode)
522         self.assertFormatEqual(expected, not_normalized)
523         black.assert_equivalent(source, not_normalized)
524         black.assert_stable(source, not_normalized, mode=mode)
525
526     def test_long_strings(self) -> None:
527         """Tests for splitting long strings."""
528         source, expected = read_data("long_strings")
529         actual = fs(source)
530         self.assertFormatEqual(expected, actual)
531         black.assert_equivalent(source, actual)
532         black.assert_stable(source, actual, DEFAULT_MODE)
533
534     def test_long_strings_flag_disabled(self) -> None:
535         """Tests for turning off the string processing logic."""
536         source, expected = read_data("long_strings_flag_disabled")
537         mode = replace(DEFAULT_MODE, experimental_string_processing=False)
538         actual = fs(source, mode=mode)
539         self.assertFormatEqual(expected, actual)
540         black.assert_stable(expected, actual, mode)
541
542     @patch("black.dump_to_file", dump_to_stderr)
543     def test_long_strings__edge_case(self) -> None:
544         """Edge-case tests for splitting long strings."""
545         source, expected = read_data("long_strings__edge_case")
546         actual = fs(source)
547         self.assertFormatEqual(expected, actual)
548         black.assert_equivalent(source, actual)
549         black.assert_stable(source, actual, DEFAULT_MODE)
550
551     @patch("black.dump_to_file", dump_to_stderr)
552     def test_long_strings__regression(self) -> None:
553         """Regression tests for splitting long strings."""
554         source, expected = read_data("long_strings__regression")
555         actual = fs(source)
556         self.assertFormatEqual(expected, actual)
557         black.assert_equivalent(source, actual)
558         black.assert_stable(source, actual, DEFAULT_MODE)
559
560     @patch("black.dump_to_file", dump_to_stderr)
561     def test_slices(self) -> None:
562         source, expected = read_data("slices")
563         actual = fs(source)
564         self.assertFormatEqual(expected, actual)
565         black.assert_equivalent(source, actual)
566         black.assert_stable(source, actual, DEFAULT_MODE)
567
568     @patch("black.dump_to_file", dump_to_stderr)
569     def test_percent_precedence(self) -> None:
570         source, expected = read_data("percent_precedence")
571         actual = fs(source)
572         self.assertFormatEqual(expected, actual)
573         black.assert_equivalent(source, actual)
574         black.assert_stable(source, actual, DEFAULT_MODE)
575
576     @patch("black.dump_to_file", dump_to_stderr)
577     def test_comments(self) -> None:
578         source, expected = read_data("comments")
579         actual = fs(source)
580         self.assertFormatEqual(expected, actual)
581         black.assert_equivalent(source, actual)
582         black.assert_stable(source, actual, DEFAULT_MODE)
583
584     @patch("black.dump_to_file", dump_to_stderr)
585     def test_comments2(self) -> None:
586         source, expected = read_data("comments2")
587         actual = fs(source)
588         self.assertFormatEqual(expected, actual)
589         black.assert_equivalent(source, actual)
590         black.assert_stable(source, actual, DEFAULT_MODE)
591
592     @patch("black.dump_to_file", dump_to_stderr)
593     def test_comments3(self) -> None:
594         source, expected = read_data("comments3")
595         actual = fs(source)
596         self.assertFormatEqual(expected, actual)
597         black.assert_equivalent(source, actual)
598         black.assert_stable(source, actual, DEFAULT_MODE)
599
600     @patch("black.dump_to_file", dump_to_stderr)
601     def test_comments4(self) -> None:
602         source, expected = read_data("comments4")
603         actual = fs(source)
604         self.assertFormatEqual(expected, actual)
605         black.assert_equivalent(source, actual)
606         black.assert_stable(source, actual, DEFAULT_MODE)
607
608     @patch("black.dump_to_file", dump_to_stderr)
609     def test_comments5(self) -> None:
610         source, expected = read_data("comments5")
611         actual = fs(source)
612         self.assertFormatEqual(expected, actual)
613         black.assert_equivalent(source, actual)
614         black.assert_stable(source, actual, DEFAULT_MODE)
615
616     @patch("black.dump_to_file", dump_to_stderr)
617     def test_comments6(self) -> None:
618         source, expected = read_data("comments6")
619         actual = fs(source)
620         self.assertFormatEqual(expected, actual)
621         black.assert_equivalent(source, actual)
622         black.assert_stable(source, actual, DEFAULT_MODE)
623
624     @patch("black.dump_to_file", dump_to_stderr)
625     def test_comments7(self) -> None:
626         source, expected = read_data("comments7")
627         mode = replace(DEFAULT_MODE, target_versions={black.TargetVersion.PY38})
628         actual = fs(source, mode=mode)
629         self.assertFormatEqual(expected, actual)
630         black.assert_equivalent(source, actual)
631         black.assert_stable(source, actual, DEFAULT_MODE)
632
633     @patch("black.dump_to_file", dump_to_stderr)
634     def test_comment_after_escaped_newline(self) -> None:
635         source, expected = read_data("comment_after_escaped_newline")
636         actual = fs(source)
637         self.assertFormatEqual(expected, actual)
638         black.assert_equivalent(source, actual)
639         black.assert_stable(source, actual, DEFAULT_MODE)
640
641     @patch("black.dump_to_file", dump_to_stderr)
642     def test_cantfit(self) -> None:
643         source, expected = read_data("cantfit")
644         actual = fs(source)
645         self.assertFormatEqual(expected, actual)
646         black.assert_equivalent(source, actual)
647         black.assert_stable(source, actual, DEFAULT_MODE)
648
649     @patch("black.dump_to_file", dump_to_stderr)
650     def test_import_spacing(self) -> None:
651         source, expected = read_data("import_spacing")
652         actual = fs(source)
653         self.assertFormatEqual(expected, actual)
654         black.assert_equivalent(source, actual)
655         black.assert_stable(source, actual, DEFAULT_MODE)
656
657     @patch("black.dump_to_file", dump_to_stderr)
658     def test_composition(self) -> None:
659         source, expected = read_data("composition")
660         actual = fs(source)
661         self.assertFormatEqual(expected, actual)
662         black.assert_equivalent(source, actual)
663         black.assert_stable(source, actual, DEFAULT_MODE)
664
665     @patch("black.dump_to_file", dump_to_stderr)
666     def test_composition_no_trailing_comma(self) -> None:
667         source, expected = read_data("composition_no_trailing_comma")
668         mode = replace(DEFAULT_MODE, target_versions={black.TargetVersion.PY38})
669         actual = fs(source, mode=mode)
670         self.assertFormatEqual(expected, actual)
671         black.assert_equivalent(source, actual)
672         black.assert_stable(source, actual, DEFAULT_MODE)
673
674     @patch("black.dump_to_file", dump_to_stderr)
675     def test_empty_lines(self) -> None:
676         source, expected = read_data("empty_lines")
677         actual = fs(source)
678         self.assertFormatEqual(expected, actual)
679         black.assert_equivalent(source, actual)
680         black.assert_stable(source, actual, DEFAULT_MODE)
681
682     @patch("black.dump_to_file", dump_to_stderr)
683     def test_remove_parens(self) -> None:
684         source, expected = read_data("remove_parens")
685         actual = fs(source)
686         self.assertFormatEqual(expected, actual)
687         black.assert_equivalent(source, actual)
688         black.assert_stable(source, actual, DEFAULT_MODE)
689
690     @patch("black.dump_to_file", dump_to_stderr)
691     def test_string_prefixes(self) -> None:
692         source, expected = read_data("string_prefixes")
693         actual = fs(source)
694         self.assertFormatEqual(expected, actual)
695         black.assert_equivalent(source, actual)
696         black.assert_stable(source, actual, DEFAULT_MODE)
697
698     @patch("black.dump_to_file", dump_to_stderr)
699     def test_numeric_literals(self) -> None:
700         source, expected = read_data("numeric_literals")
701         mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
702         actual = fs(source, mode=mode)
703         self.assertFormatEqual(expected, actual)
704         black.assert_equivalent(source, actual)
705         black.assert_stable(source, actual, mode)
706
707     @patch("black.dump_to_file", dump_to_stderr)
708     def test_numeric_literals_ignoring_underscores(self) -> None:
709         source, expected = read_data("numeric_literals_skip_underscores")
710         mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
711         actual = fs(source, mode=mode)
712         self.assertFormatEqual(expected, actual)
713         black.assert_equivalent(source, actual)
714         black.assert_stable(source, actual, mode)
715
716     @patch("black.dump_to_file", dump_to_stderr)
717     def test_numeric_literals_py2(self) -> None:
718         source, expected = read_data("numeric_literals_py2")
719         actual = fs(source)
720         self.assertFormatEqual(expected, actual)
721         black.assert_stable(source, actual, DEFAULT_MODE)
722
723     @patch("black.dump_to_file", dump_to_stderr)
724     def test_python2(self) -> None:
725         source, expected = read_data("python2")
726         actual = fs(source)
727         self.assertFormatEqual(expected, actual)
728         black.assert_equivalent(source, actual)
729         black.assert_stable(source, actual, DEFAULT_MODE)
730
731     @patch("black.dump_to_file", dump_to_stderr)
732     def test_python2_print_function(self) -> None:
733         source, expected = read_data("python2_print_function")
734         mode = replace(DEFAULT_MODE, target_versions={TargetVersion.PY27})
735         actual = fs(source, mode=mode)
736         self.assertFormatEqual(expected, actual)
737         black.assert_equivalent(source, actual)
738         black.assert_stable(source, actual, mode)
739
740     @patch("black.dump_to_file", dump_to_stderr)
741     def test_python2_unicode_literals(self) -> None:
742         source, expected = read_data("python2_unicode_literals")
743         actual = fs(source)
744         self.assertFormatEqual(expected, actual)
745         black.assert_equivalent(source, actual)
746         black.assert_stable(source, actual, DEFAULT_MODE)
747
748     @patch("black.dump_to_file", dump_to_stderr)
749     def test_stub(self) -> None:
750         mode = replace(DEFAULT_MODE, is_pyi=True)
751         source, expected = read_data("stub.pyi")
752         actual = fs(source, mode=mode)
753         self.assertFormatEqual(expected, actual)
754         black.assert_stable(source, actual, mode)
755
756     @patch("black.dump_to_file", dump_to_stderr)
757     def test_async_as_identifier(self) -> None:
758         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
759         source, expected = read_data("async_as_identifier")
760         actual = fs(source)
761         self.assertFormatEqual(expected, actual)
762         major, minor = sys.version_info[:2]
763         if major < 3 or (major <= 3 and minor < 7):
764             black.assert_equivalent(source, actual)
765         black.assert_stable(source, actual, DEFAULT_MODE)
766         # ensure black can parse this when the target is 3.6
767         self.invokeBlack([str(source_path), "--target-version", "py36"])
768         # but not on 3.7, because async/await is no longer an identifier
769         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
770
771     @patch("black.dump_to_file", dump_to_stderr)
772     def test_python37(self) -> None:
773         source_path = (THIS_DIR / "data" / "python37.py").resolve()
774         source, expected = read_data("python37")
775         actual = fs(source)
776         self.assertFormatEqual(expected, actual)
777         major, minor = sys.version_info[:2]
778         if major > 3 or (major == 3 and minor >= 7):
779             black.assert_equivalent(source, actual)
780         black.assert_stable(source, actual, DEFAULT_MODE)
781         # ensure black can parse this when the target is 3.7
782         self.invokeBlack([str(source_path), "--target-version", "py37"])
783         # but not on 3.6, because we use async as a reserved keyword
784         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
785
786     @patch("black.dump_to_file", dump_to_stderr)
787     def test_python38(self) -> None:
788         source, expected = read_data("python38")
789         actual = fs(source)
790         self.assertFormatEqual(expected, actual)
791         major, minor = sys.version_info[:2]
792         if major > 3 or (major == 3 and minor >= 8):
793             black.assert_equivalent(source, actual)
794         black.assert_stable(source, actual, DEFAULT_MODE)
795
796     @patch("black.dump_to_file", dump_to_stderr)
797     def test_fmtonoff(self) -> None:
798         source, expected = read_data("fmtonoff")
799         actual = fs(source)
800         self.assertFormatEqual(expected, actual)
801         black.assert_equivalent(source, actual)
802         black.assert_stable(source, actual, DEFAULT_MODE)
803
804     @patch("black.dump_to_file", dump_to_stderr)
805     def test_fmtonoff2(self) -> None:
806         source, expected = read_data("fmtonoff2")
807         actual = fs(source)
808         self.assertFormatEqual(expected, actual)
809         black.assert_equivalent(source, actual)
810         black.assert_stable(source, actual, DEFAULT_MODE)
811
812     @patch("black.dump_to_file", dump_to_stderr)
813     def test_fmtonoff3(self) -> None:
814         source, expected = read_data("fmtonoff3")
815         actual = fs(source)
816         self.assertFormatEqual(expected, actual)
817         black.assert_equivalent(source, actual)
818         black.assert_stable(source, actual, DEFAULT_MODE)
819
820     @patch("black.dump_to_file", dump_to_stderr)
821     def test_fmtonoff4(self) -> None:
822         source, expected = read_data("fmtonoff4")
823         actual = fs(source)
824         self.assertFormatEqual(expected, actual)
825         black.assert_equivalent(source, actual)
826         black.assert_stable(source, actual, DEFAULT_MODE)
827
828     @patch("black.dump_to_file", dump_to_stderr)
829     def test_remove_empty_parentheses_after_class(self) -> None:
830         source, expected = read_data("class_blank_parentheses")
831         actual = fs(source)
832         self.assertFormatEqual(expected, actual)
833         black.assert_equivalent(source, actual)
834         black.assert_stable(source, actual, DEFAULT_MODE)
835
836     @patch("black.dump_to_file", dump_to_stderr)
837     def test_new_line_between_class_and_code(self) -> None:
838         source, expected = read_data("class_methods_new_line")
839         actual = fs(source)
840         self.assertFormatEqual(expected, actual)
841         black.assert_equivalent(source, actual)
842         black.assert_stable(source, actual, DEFAULT_MODE)
843
844     @patch("black.dump_to_file", dump_to_stderr)
845     def test_bracket_match(self) -> None:
846         source, expected = read_data("bracketmatch")
847         actual = fs(source)
848         self.assertFormatEqual(expected, actual)
849         black.assert_equivalent(source, actual)
850         black.assert_stable(source, actual, DEFAULT_MODE)
851
852     @patch("black.dump_to_file", dump_to_stderr)
853     def test_tuple_assign(self) -> None:
854         source, expected = read_data("tupleassign")
855         actual = fs(source)
856         self.assertFormatEqual(expected, actual)
857         black.assert_equivalent(source, actual)
858         black.assert_stable(source, actual, DEFAULT_MODE)
859
860     @patch("black.dump_to_file", dump_to_stderr)
861     def test_beginning_backslash(self) -> None:
862         source, expected = read_data("beginning_backslash")
863         actual = fs(source)
864         self.assertFormatEqual(expected, actual)
865         black.assert_equivalent(source, actual)
866         black.assert_stable(source, actual, DEFAULT_MODE)
867
868     def test_tab_comment_indentation(self) -> None:
869         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
870         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
871         self.assertFormatEqual(contents_spc, fs(contents_spc))
872         self.assertFormatEqual(contents_spc, fs(contents_tab))
873
874         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
875         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
876         self.assertFormatEqual(contents_spc, fs(contents_spc))
877         self.assertFormatEqual(contents_spc, fs(contents_tab))
878
879         # mixed tabs and spaces (valid Python 2 code)
880         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
881         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
882         self.assertFormatEqual(contents_spc, fs(contents_spc))
883         self.assertFormatEqual(contents_spc, fs(contents_tab))
884
885         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
886         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
887         self.assertFormatEqual(contents_spc, fs(contents_spc))
888         self.assertFormatEqual(contents_spc, fs(contents_tab))
889
    def test_report_verbose(self) -> None:
        """Walk a verbose ``black.Report`` through every kind of outcome.

        ``black.out`` and ``black.err`` are patched so messages are captured in
        lists instead of printed. In verbose mode every ``done()`` call
        (unchanged, reformatted and cached files alike) and every
        ``path_ignored()`` call writes a line to ``out``. ``failed()`` writes to
        ``err``. After each step the message counts, the last message, the
        summary string and the return code are checked.
        """
        report = black.Report(verbose=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            # Stand-in for black.out: record the message instead of printing.
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            # Stand-in for black.err: record the message instead of printing.
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            # Unchanged file: still reported, because the report is verbose.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file counts toward "left unchanged" in the summary.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled, having a reformatted file gives return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure goes to err and switches the return code to 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Counters track calls, not distinct paths: reporting f3 again bumps
            # "reformatted" without lowering "left unchanged".
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are logged in verbose mode but leave the summary as-is.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check mode and diff mode use the conditional "would" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
991
    def test_report_quiet(self) -> None:
        """Walk a quiet ``black.Report`` through every kind of outcome.

        ``black.out`` and ``black.err`` are patched so messages are captured in
        lists. In quiet mode nothing is ever written to ``out``, while
        ``failed()`` still writes to ``err``. The summary string and return code
        are checked after each step and match the verbose and normal variants.
        """
        report = black.Report(quiet=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            # Stand-in for black.out: record the message instead of printing.
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            # Stand-in for black.err: record the message instead of printing.
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file counts toward "left unchanged" in the summary.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled, having a reformatted file gives return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Errors are still emitted in quiet mode.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Counters track calls, not distinct paths (f3 is reported twice).
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths produce no output and leave the summary as-is.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check mode and diff mode use the conditional "would" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
1085
    def test_report_normal(self) -> None:
        """Walk a default (non-verbose, non-quiet) ``black.Report`` through outcomes.

        ``black.out`` and ``black.err`` are patched so messages are captured in
        lists. In this mode only reformatted files produce an ``out`` line;
        unchanged, cached and ignored paths are silent. ``failed()`` writes to
        ``err``. The summary string and return code are checked after each step.
        """
        report = black.Report()
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            # Stand-in for black.out: record the message instead of printing.
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            # Stand-in for black.err: record the message instead of printing.
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file is silent here (last out line is still f2's) and
            # counts toward "left unchanged".
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled, having a reformatted file gives return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure goes to err and switches the return code to 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Counters track calls, not distinct paths (f3 is reported twice).
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths produce no output and leave the summary as-is.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check mode and diff mode use the conditional "would" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
1182
1183     def test_lib2to3_parse(self) -> None:
1184         with self.assertRaises(black.InvalidInput):
1185             black.lib2to3_parse("invalid syntax")
1186
1187         straddling = "x + y"
1188         black.lib2to3_parse(straddling)
1189         black.lib2to3_parse(straddling, {TargetVersion.PY27})
1190         black.lib2to3_parse(straddling, {TargetVersion.PY36})
1191         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
1192
1193         py2_only = "print x"
1194         black.lib2to3_parse(py2_only)
1195         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
1196         with self.assertRaises(black.InvalidInput):
1197             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
1198         with self.assertRaises(black.InvalidInput):
1199             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
1200
1201         py3_only = "exec(x, end=y)"
1202         black.lib2to3_parse(py3_only)
1203         with self.assertRaises(black.InvalidInput):
1204             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
1205         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
1206         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
1207
1208     def test_get_features_used(self) -> None:
1209         node = black.lib2to3_parse("def f(*, arg): ...\n")
1210         self.assertEqual(black.get_features_used(node), set())
1211         node = black.lib2to3_parse("def f(*, arg,): ...\n")
1212         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
1213         node = black.lib2to3_parse("f(*arg,)\n")
1214         self.assertEqual(
1215             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
1216         )
1217         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
1218         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
1219         node = black.lib2to3_parse("123_456\n")
1220         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
1221         node = black.lib2to3_parse("123456\n")
1222         self.assertEqual(black.get_features_used(node), set())
1223         source, expected = read_data("function")
1224         node = black.lib2to3_parse(source)
1225         expected_features = {
1226             Feature.TRAILING_COMMA_IN_CALL,
1227             Feature.TRAILING_COMMA_IN_DEF,
1228             Feature.F_STRINGS,
1229         }
1230         self.assertEqual(black.get_features_used(node), expected_features)
1231         node = black.lib2to3_parse(expected)
1232         self.assertEqual(black.get_features_used(node), expected_features)
1233         source, expected = read_data("expression")
1234         node = black.lib2to3_parse(source)
1235         self.assertEqual(black.get_features_used(node), set())
1236         node = black.lib2to3_parse(expected)
1237         self.assertEqual(black.get_features_used(node), set())
1238
1239     def test_get_future_imports(self) -> None:
1240         node = black.lib2to3_parse("\n")
1241         self.assertEqual(set(), black.get_future_imports(node))
1242         node = black.lib2to3_parse("from __future__ import black\n")
1243         self.assertEqual({"black"}, black.get_future_imports(node))
1244         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
1245         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
1246         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
1247         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
1248         node = black.lib2to3_parse(
1249             "from __future__ import multiple\nfrom __future__ import imports\n"
1250         )
1251         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
1252         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
1253         self.assertEqual({"black"}, black.get_future_imports(node))
1254         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
1255         self.assertEqual({"black"}, black.get_future_imports(node))
1256         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
1257         self.assertEqual(set(), black.get_future_imports(node))
1258         node = black.lib2to3_parse("from some.module import black\n")
1259         self.assertEqual(set(), black.get_future_imports(node))
1260         node = black.lib2to3_parse(
1261             "from __future__ import unicode_literals as _unicode_literals"
1262         )
1263         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
1264         node = black.lib2to3_parse(
1265             "from __future__ import unicode_literals as _lol, print"
1266         )
1267         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
1268
1269     def test_debug_visitor(self) -> None:
1270         source, _ = read_data("debug_visitor.py")
1271         expected, _ = read_data("debug_visitor.out")
1272         out_lines = []
1273         err_lines = []
1274
1275         def out(msg: str, **kwargs: Any) -> None:
1276             out_lines.append(msg)
1277
1278         def err(msg: str, **kwargs: Any) -> None:
1279             err_lines.append(msg)
1280
1281         with patch("black.out", out), patch("black.err", err):
1282             black.DebugVisitor.show(source)
1283         actual = "\n".join(out_lines) + "\n"
1284         log_name = ""
1285         if expected != actual:
1286             log_name = black.dump_to_file(*out_lines)
1287         self.assertEqual(
1288             expected,
1289             actual,
1290             f"AST print out is different. Actual version dumped to {log_name}",
1291         )
1292
1293     def test_format_file_contents(self) -> None:
1294         empty = ""
1295         mode = DEFAULT_MODE
1296         with self.assertRaises(black.NothingChanged):
1297             black.format_file_contents(empty, mode=mode, fast=False)
1298         just_nl = "\n"
1299         with self.assertRaises(black.NothingChanged):
1300             black.format_file_contents(just_nl, mode=mode, fast=False)
1301         same = "j = [1, 2, 3]\n"
1302         with self.assertRaises(black.NothingChanged):
1303             black.format_file_contents(same, mode=mode, fast=False)
1304         different = "j = [1,2,3]"
1305         expected = same
1306         actual = black.format_file_contents(different, mode=mode, fast=False)
1307         self.assertEqual(expected, actual)
1308         invalid = "return if you can"
1309         with self.assertRaises(black.InvalidInput) as e:
1310             black.format_file_contents(invalid, mode=mode, fast=False)
1311         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1312
1313     def test_endmarker(self) -> None:
1314         n = black.lib2to3_parse("\n")
1315         self.assertEqual(n.type, black.syms.file_input)
1316         self.assertEqual(len(n.children), 1)
1317         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1318
1319     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1320     def test_assertFormatEqual(self) -> None:
1321         out_lines = []
1322         err_lines = []
1323
1324         def out(msg: str, **kwargs: Any) -> None:
1325             out_lines.append(msg)
1326
1327         def err(msg: str, **kwargs: Any) -> None:
1328             err_lines.append(msg)
1329
1330         with patch("black.out", out), patch("black.err", err):
1331             with self.assertRaises(AssertionError):
1332                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1333
1334         out_str = "".join(out_lines)
1335         self.assertTrue("Expected tree:" in out_str)
1336         self.assertTrue("Actual tree:" in out_str)
1337         self.assertEqual("".join(err_lines), "")
1338
1339     def test_cache_broken_file(self) -> None:
1340         mode = DEFAULT_MODE
1341         with cache_dir() as workspace:
1342             cache_file = black.get_cache_file(mode)
1343             with cache_file.open("w") as fobj:
1344                 fobj.write("this is not a pickle")
1345             self.assertEqual(black.read_cache(mode), {})
1346             src = (workspace / "test.py").resolve()
1347             with src.open("w") as fobj:
1348                 fobj.write("print('hello')")
1349             self.invokeBlack([str(src)])
1350             cache = black.read_cache(mode)
1351             self.assertIn(src, cache)
1352
1353     def test_cache_single_file_already_cached(self) -> None:
1354         mode = DEFAULT_MODE
1355         with cache_dir() as workspace:
1356             src = (workspace / "test.py").resolve()
1357             with src.open("w") as fobj:
1358                 fobj.write("print('hello')")
1359             black.write_cache({}, [src], mode)
1360             self.invokeBlack([str(src)])
1361             with src.open("r") as fobj:
1362                 self.assertEqual(fobj.read(), "print('hello')")
1363
1364     @event_loop()
1365     def test_cache_multiple_files(self) -> None:
1366         mode = DEFAULT_MODE
1367         with cache_dir() as workspace, patch(
1368             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1369         ):
1370             one = (workspace / "one.py").resolve()
1371             with one.open("w") as fobj:
1372                 fobj.write("print('hello')")
1373             two = (workspace / "two.py").resolve()
1374             with two.open("w") as fobj:
1375                 fobj.write("print('hello')")
1376             black.write_cache({}, [one], mode)
1377             self.invokeBlack([str(workspace)])
1378             with one.open("r") as fobj:
1379                 self.assertEqual(fobj.read(), "print('hello')")
1380             with two.open("r") as fobj:
1381                 self.assertEqual(fobj.read(), 'print("hello")\n')
1382             cache = black.read_cache(mode)
1383             self.assertIn(one, cache)
1384             self.assertIn(two, cache)
1385
1386     def test_no_cache_when_writeback_diff(self) -> None:
1387         mode = DEFAULT_MODE
1388         with cache_dir() as workspace:
1389             src = (workspace / "test.py").resolve()
1390             with src.open("w") as fobj:
1391                 fobj.write("print('hello')")
1392             self.invokeBlack([str(src), "--diff"])
1393             cache_file = black.get_cache_file(mode)
1394             self.assertFalse(cache_file.exists())
1395
1396     def test_no_cache_when_stdin(self) -> None:
1397         mode = DEFAULT_MODE
1398         with cache_dir():
1399             result = CliRunner().invoke(
1400                 black.main, ["-"], input=BytesIO(b"print('hello')")
1401             )
1402             self.assertEqual(result.exit_code, 0)
1403             cache_file = black.get_cache_file(mode)
1404             self.assertFalse(cache_file.exists())
1405
1406     def test_read_cache_no_cachefile(self) -> None:
1407         mode = DEFAULT_MODE
1408         with cache_dir():
1409             self.assertEqual(black.read_cache(mode), {})
1410
1411     def test_write_cache_read_cache(self) -> None:
1412         mode = DEFAULT_MODE
1413         with cache_dir() as workspace:
1414             src = (workspace / "test.py").resolve()
1415             src.touch()
1416             black.write_cache({}, [src], mode)
1417             cache = black.read_cache(mode)
1418             self.assertIn(src, cache)
1419             self.assertEqual(cache[src], black.get_cache_info(src))
1420
1421     def test_filter_cached(self) -> None:
1422         with TemporaryDirectory() as workspace:
1423             path = Path(workspace)
1424             uncached = (path / "uncached").resolve()
1425             cached = (path / "cached").resolve()
1426             cached_but_changed = (path / "changed").resolve()
1427             uncached.touch()
1428             cached.touch()
1429             cached_but_changed.touch()
1430             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
1431             todo, done = black.filter_cached(
1432                 cache, {uncached, cached, cached_but_changed}
1433             )
1434             self.assertEqual(todo, {uncached, cached_but_changed})
1435             self.assertEqual(done, {cached})
1436
1437     def test_write_cache_creates_directory_if_needed(self) -> None:
1438         mode = DEFAULT_MODE
1439         with cache_dir(exists=False) as workspace:
1440             self.assertFalse(workspace.exists())
1441             black.write_cache({}, [], mode)
1442             self.assertTrue(workspace.exists())
1443
1444     @event_loop()
1445     def test_failed_formatting_does_not_get_cached(self) -> None:
1446         mode = DEFAULT_MODE
1447         with cache_dir() as workspace, patch(
1448             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1449         ):
1450             failing = (workspace / "failing.py").resolve()
1451             with failing.open("w") as fobj:
1452                 fobj.write("not actually python")
1453             clean = (workspace / "clean.py").resolve()
1454             with clean.open("w") as fobj:
1455                 fobj.write('print("hello")\n')
1456             self.invokeBlack([str(workspace)], exit_code=123)
1457             cache = black.read_cache(mode)
1458             self.assertNotIn(failing, cache)
1459             self.assertIn(clean, cache)
1460
1461     def test_write_cache_write_fail(self) -> None:
1462         mode = DEFAULT_MODE
1463         with cache_dir(), patch.object(Path, "open") as mock:
1464             mock.side_effect = OSError
1465             black.write_cache({}, [], mode)
1466
1467     @event_loop()
1468     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
1469     def test_works_in_mono_process_only_environment(self) -> None:
1470         with cache_dir() as workspace:
1471             for f in [
1472                 (workspace / "one.py").resolve(),
1473                 (workspace / "two.py").resolve(),
1474             ]:
1475                 f.write_text('print("hello")\n')
1476             self.invokeBlack([str(workspace)])
1477
1478     @event_loop()
1479     def test_check_diff_use_together(self) -> None:
1480         with cache_dir():
1481             # Files which will be reformatted.
1482             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1483             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1484             # Files which will not be reformatted.
1485             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1486             self.invokeBlack([str(src2), "--diff", "--check"])
1487             # Multi file command.
1488             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1489
1490     def test_no_files(self) -> None:
1491         with cache_dir():
1492             # Without an argument, black exits with error code 0.
1493             self.invokeBlack([])
1494
1495     def test_broken_symlink(self) -> None:
1496         with cache_dir() as workspace:
1497             symlink = workspace / "broken_link.py"
1498             try:
1499                 symlink.symlink_to("nonexistent.py")
1500             except OSError as e:
1501                 self.skipTest(f"Can't create symlinks: {e}")
1502             self.invokeBlack([str(workspace.resolve())])
1503
1504     def test_read_cache_line_lengths(self) -> None:
1505         mode = DEFAULT_MODE
1506         short_mode = replace(DEFAULT_MODE, line_length=1)
1507         with cache_dir() as workspace:
1508             path = (workspace / "file.py").resolve()
1509             path.touch()
1510             black.write_cache({}, [path], mode)
1511             one = black.read_cache(mode)
1512             self.assertIn(path, one)
1513             two = black.read_cache(short_mode)
1514             self.assertNotIn(path, two)
1515
1516     def test_tricky_unicode_symbols(self) -> None:
1517         source, expected = read_data("tricky_unicode_symbols")
1518         actual = fs(source)
1519         self.assertFormatEqual(expected, actual)
1520         black.assert_equivalent(source, actual)
1521         black.assert_stable(source, actual, DEFAULT_MODE)
1522
1523     def test_single_file_force_pyi(self) -> None:
1524         reg_mode = DEFAULT_MODE
1525         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1526         contents, expected = read_data("force_pyi")
1527         with cache_dir() as workspace:
1528             path = (workspace / "file.py").resolve()
1529             with open(path, "w") as fh:
1530                 fh.write(contents)
1531             self.invokeBlack([str(path), "--pyi"])
1532             with open(path, "r") as fh:
1533                 actual = fh.read()
1534             # verify cache with --pyi is separate
1535             pyi_cache = black.read_cache(pyi_mode)
1536             self.assertIn(path, pyi_cache)
1537             normal_cache = black.read_cache(reg_mode)
1538             self.assertNotIn(path, normal_cache)
1539         self.assertEqual(actual, expected)
1540
1541     @event_loop()
1542     def test_multi_file_force_pyi(self) -> None:
1543         reg_mode = DEFAULT_MODE
1544         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1545         contents, expected = read_data("force_pyi")
1546         with cache_dir() as workspace:
1547             paths = [
1548                 (workspace / "file1.py").resolve(),
1549                 (workspace / "file2.py").resolve(),
1550             ]
1551             for path in paths:
1552                 with open(path, "w") as fh:
1553                     fh.write(contents)
1554             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1555             for path in paths:
1556                 with open(path, "r") as fh:
1557                     actual = fh.read()
1558                 self.assertEqual(actual, expected)
1559             # verify cache with --pyi is separate
1560             pyi_cache = black.read_cache(pyi_mode)
1561             normal_cache = black.read_cache(reg_mode)
1562             for path in paths:
1563                 self.assertIn(path, pyi_cache)
1564                 self.assertNotIn(path, normal_cache)
1565
1566     def test_pipe_force_pyi(self) -> None:
1567         source, expected = read_data("force_pyi")
1568         result = CliRunner().invoke(
1569             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1570         )
1571         self.assertEqual(result.exit_code, 0)
1572         actual = result.output
1573         self.assertFormatEqual(actual, expected)
1574
1575     def test_single_file_force_py36(self) -> None:
1576         reg_mode = DEFAULT_MODE
1577         py36_mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
1578         source, expected = read_data("force_py36")
1579         with cache_dir() as workspace:
1580             path = (workspace / "file.py").resolve()
1581             with open(path, "w") as fh:
1582                 fh.write(source)
1583             self.invokeBlack([str(path), *PY36_ARGS])
1584             with open(path, "r") as fh:
1585                 actual = fh.read()
1586             # verify cache with --target-version is separate
1587             py36_cache = black.read_cache(py36_mode)
1588             self.assertIn(path, py36_cache)
1589             normal_cache = black.read_cache(reg_mode)
1590             self.assertNotIn(path, normal_cache)
1591         self.assertEqual(actual, expected)
1592
1593     @event_loop()
1594     def test_multi_file_force_py36(self) -> None:
1595         reg_mode = DEFAULT_MODE
1596         py36_mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
1597         source, expected = read_data("force_py36")
1598         with cache_dir() as workspace:
1599             paths = [
1600                 (workspace / "file1.py").resolve(),
1601                 (workspace / "file2.py").resolve(),
1602             ]
1603             for path in paths:
1604                 with open(path, "w") as fh:
1605                     fh.write(source)
1606             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1607             for path in paths:
1608                 with open(path, "r") as fh:
1609                     actual = fh.read()
1610                 self.assertEqual(actual, expected)
1611             # verify cache with --target-version is separate
1612             pyi_cache = black.read_cache(py36_mode)
1613             normal_cache = black.read_cache(reg_mode)
1614             for path in paths:
1615                 self.assertIn(path, pyi_cache)
1616                 self.assertNotIn(path, normal_cache)
1617
1618     def test_collections(self) -> None:
1619         source, expected = read_data("collections")
1620         actual = fs(source)
1621         self.assertFormatEqual(expected, actual)
1622         black.assert_equivalent(source, actual)
1623         black.assert_stable(source, actual, DEFAULT_MODE)
1624
1625     def test_pipe_force_py36(self) -> None:
1626         source, expected = read_data("force_py36")
1627         result = CliRunner().invoke(
1628             black.main,
1629             ["-", "-q", "--target-version=py36"],
1630             input=BytesIO(source.encode("utf8")),
1631         )
1632         self.assertEqual(result.exit_code, 0)
1633         actual = result.output
1634         self.assertFormatEqual(actual, expected)
1635
1636     def test_include_exclude(self) -> None:
1637         path = THIS_DIR / "data" / "include_exclude_tests"
1638         include = re.compile(r"\.pyi?$")
1639         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1640         report = black.Report()
1641         gitignore = PathSpec.from_lines("gitwildmatch", [])
1642         sources: List[Path] = []
1643         expected = [
1644             Path(path / "b/dont_exclude/a.py"),
1645             Path(path / "b/dont_exclude/a.pyi"),
1646         ]
1647         this_abs = THIS_DIR.resolve()
1648         sources.extend(
1649             black.gen_python_files(
1650                 path.iterdir(), this_abs, include, exclude, None, report, gitignore
1651             )
1652         )
1653         self.assertEqual(sorted(expected), sorted(sources))
1654
1655     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1656     def test_exclude_for_issue_1572(self) -> None:
1657         # Exclude shouldn't touch files that were explicitly given to Black through the
1658         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
1659         # https://github.com/psf/black/issues/1572
1660         path = THIS_DIR / "data" / "include_exclude_tests"
1661         include = ""
1662         exclude = r"/exclude/|a\.py"
1663         src = str(path / "b/exclude/a.py")
1664         report = black.Report()
1665         expected = [Path(path / "b/exclude/a.py")]
1666         sources = list(
1667             black.get_sources(
1668                 ctx=FakeContext(),
1669                 src=(src,),
1670                 quiet=True,
1671                 verbose=False,
1672                 include=include,
1673                 exclude=exclude,
1674                 force_exclude=None,
1675                 report=report,
1676             )
1677         )
1678         self.assertEqual(sorted(expected), sorted(sources))
1679
1680     def test_gitignore_exclude(self) -> None:
1681         path = THIS_DIR / "data" / "include_exclude_tests"
1682         include = re.compile(r"\.pyi?$")
1683         exclude = re.compile(r"")
1684         report = black.Report()
1685         gitignore = PathSpec.from_lines(
1686             "gitwildmatch", ["exclude/", ".definitely_exclude"]
1687         )
1688         sources: List[Path] = []
1689         expected = [
1690             Path(path / "b/dont_exclude/a.py"),
1691             Path(path / "b/dont_exclude/a.pyi"),
1692         ]
1693         this_abs = THIS_DIR.resolve()
1694         sources.extend(
1695             black.gen_python_files(
1696                 path.iterdir(), this_abs, include, exclude, None, report, gitignore
1697             )
1698         )
1699         self.assertEqual(sorted(expected), sorted(sources))
1700
1701     def test_empty_include(self) -> None:
1702         path = THIS_DIR / "data" / "include_exclude_tests"
1703         report = black.Report()
1704         gitignore = PathSpec.from_lines("gitwildmatch", [])
1705         empty = re.compile(r"")
1706         sources: List[Path] = []
1707         expected = [
1708             Path(path / "b/exclude/a.pie"),
1709             Path(path / "b/exclude/a.py"),
1710             Path(path / "b/exclude/a.pyi"),
1711             Path(path / "b/dont_exclude/a.pie"),
1712             Path(path / "b/dont_exclude/a.py"),
1713             Path(path / "b/dont_exclude/a.pyi"),
1714             Path(path / "b/.definitely_exclude/a.pie"),
1715             Path(path / "b/.definitely_exclude/a.py"),
1716             Path(path / "b/.definitely_exclude/a.pyi"),
1717         ]
1718         this_abs = THIS_DIR.resolve()
1719         sources.extend(
1720             black.gen_python_files(
1721                 path.iterdir(),
1722                 this_abs,
1723                 empty,
1724                 re.compile(black.DEFAULT_EXCLUDES),
1725                 None,
1726                 report,
1727                 gitignore,
1728             )
1729         )
1730         self.assertEqual(sorted(expected), sorted(sources))
1731
1732     def test_empty_exclude(self) -> None:
1733         path = THIS_DIR / "data" / "include_exclude_tests"
1734         report = black.Report()
1735         gitignore = PathSpec.from_lines("gitwildmatch", [])
1736         empty = re.compile(r"")
1737         sources: List[Path] = []
1738         expected = [
1739             Path(path / "b/dont_exclude/a.py"),
1740             Path(path / "b/dont_exclude/a.pyi"),
1741             Path(path / "b/exclude/a.py"),
1742             Path(path / "b/exclude/a.pyi"),
1743             Path(path / "b/.definitely_exclude/a.py"),
1744             Path(path / "b/.definitely_exclude/a.pyi"),
1745         ]
1746         this_abs = THIS_DIR.resolve()
1747         sources.extend(
1748             black.gen_python_files(
1749                 path.iterdir(),
1750                 this_abs,
1751                 re.compile(black.DEFAULT_INCLUDES),
1752                 empty,
1753                 None,
1754                 report,
1755                 gitignore,
1756             )
1757         )
1758         self.assertEqual(sorted(expected), sorted(sources))
1759
1760     def test_invalid_include_exclude(self) -> None:
1761         for option in ["--include", "--exclude"]:
1762             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1763
1764     def test_preserves_line_endings(self) -> None:
1765         with TemporaryDirectory() as workspace:
1766             test_file = Path(workspace) / "test.py"
1767             for nl in ["\n", "\r\n"]:
1768                 contents = nl.join(["def f(  ):", "    pass"])
1769                 test_file.write_bytes(contents.encode())
1770                 ff(test_file, write_back=black.WriteBack.YES)
1771                 updated_contents: bytes = test_file.read_bytes()
1772                 self.assertIn(nl.encode(), updated_contents)
1773                 if nl == "\n":
1774                     self.assertNotIn(b"\r\n", updated_contents)
1775
1776     def test_preserves_line_endings_via_stdin(self) -> None:
1777         for nl in ["\n", "\r\n"]:
1778             contents = nl.join(["def f(  ):", "    pass"])
1779             runner = BlackRunner()
1780             result = runner.invoke(
1781                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1782             )
1783             self.assertEqual(result.exit_code, 0)
1784             output = runner.stdout_bytes
1785             self.assertIn(nl.encode("utf8"), output)
1786             if nl == "\n":
1787                 self.assertNotIn(b"\r\n", output)
1788
1789     def test_assert_equivalent_different_asts(self) -> None:
1790         with self.assertRaises(AssertionError):
1791             black.assert_equivalent("{}", "None")
1792
1793     def test_symlink_out_of_root_directory(self) -> None:
1794         path = MagicMock()
1795         root = THIS_DIR.resolve()
1796         child = MagicMock()
1797         include = re.compile(black.DEFAULT_INCLUDES)
1798         exclude = re.compile(black.DEFAULT_EXCLUDES)
1799         report = black.Report()
1800         gitignore = PathSpec.from_lines("gitwildmatch", [])
1801         # `child` should behave like a symlink which resolved path is clearly
1802         # outside of the `root` directory.
1803         path.iterdir.return_value = [child]
1804         child.resolve.return_value = Path("/a/b/c")
1805         child.as_posix.return_value = "/a/b/c"
1806         child.is_symlink.return_value = True
1807         try:
1808             list(
1809                 black.gen_python_files(
1810                     path.iterdir(), root, include, exclude, None, report, gitignore
1811                 )
1812             )
1813         except ValueError as ve:
1814             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1815         path.iterdir.assert_called_once()
1816         child.resolve.assert_called_once()
1817         child.is_symlink.assert_called_once()
1818         # `child` should behave like a strange file which resolved path is clearly
1819         # outside of the `root` directory.
1820         child.is_symlink.return_value = False
1821         with self.assertRaises(ValueError):
1822             list(
1823                 black.gen_python_files(
1824                     path.iterdir(), root, include, exclude, None, report, gitignore
1825                 )
1826             )
1827         path.iterdir.assert_called()
1828         self.assertEqual(path.iterdir.call_count, 2)
1829         child.resolve.assert_called()
1830         self.assertEqual(child.resolve.call_count, 2)
1831         child.is_symlink.assert_called()
1832         self.assertEqual(child.is_symlink.call_count, 2)
1833
1834     def test_shhh_click(self) -> None:
1835         try:
1836             from click import _unicodefun  # type: ignore
1837         except ModuleNotFoundError:
1838             self.skipTest("Incompatible Click version")
1839         if not hasattr(_unicodefun, "_verify_python3_env"):
1840             self.skipTest("Incompatible Click version")
1841         # First, let's see if Click is crashing with a preferred ASCII charset.
1842         with patch("locale.getpreferredencoding") as gpe:
1843             gpe.return_value = "ASCII"
1844             with self.assertRaises(RuntimeError):
1845                 _unicodefun._verify_python3_env()
1846         # Now, let's silence Click...
1847         black.patch_click()
1848         # ...and confirm it's silent.
1849         with patch("locale.getpreferredencoding") as gpe:
1850             gpe.return_value = "ASCII"
1851             try:
1852                 _unicodefun._verify_python3_env()
1853             except RuntimeError as re:
1854                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1855
1856     def test_root_logger_not_used_directly(self) -> None:
1857         def fail(*args: Any, **kwargs: Any) -> None:
1858             self.fail("Record created with root logger")
1859
1860         with patch.multiple(
1861             logging.root,
1862             debug=fail,
1863             info=fail,
1864             warning=fail,
1865             error=fail,
1866             critical=fail,
1867             log=fail,
1868         ):
1869             ff(THIS_FILE)
1870
1871     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1872     def test_blackd_main(self) -> None:
1873         with patch("blackd.web.run_app"):
1874             result = CliRunner().invoke(blackd.main, [])
1875             if result.exception is not None:
1876                 raise result.exception
1877             self.assertEqual(result.exit_code, 0)
1878
1879     def test_invalid_config_return_code(self) -> None:
1880         tmp_file = Path(black.dump_to_file())
1881         try:
1882             tmp_config = Path(black.dump_to_file())
1883             tmp_config.unlink()
1884             args = ["--config", str(tmp_config), str(tmp_file)]
1885             self.invokeBlack(args, exit_code=2, ignore_config=False)
1886         finally:
1887             tmp_file.unlink()
1888
1889     def test_parse_pyproject_toml(self) -> None:
1890         test_toml_file = THIS_DIR / "test.toml"
1891         config = black.parse_pyproject_toml(str(test_toml_file))
1892         self.assertEqual(config["verbose"], 1)
1893         self.assertEqual(config["check"], "no")
1894         self.assertEqual(config["diff"], "y")
1895         self.assertEqual(config["color"], True)
1896         self.assertEqual(config["line_length"], 79)
1897         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1898         self.assertEqual(config["exclude"], r"\.pyi?$")
1899         self.assertEqual(config["include"], r"\.py?$")
1900
1901     def test_read_pyproject_toml(self) -> None:
1902         test_toml_file = THIS_DIR / "test.toml"
1903         fake_ctx = FakeContext()
1904         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1905         config = fake_ctx.default_map
1906         self.assertEqual(config["verbose"], "1")
1907         self.assertEqual(config["check"], "no")
1908         self.assertEqual(config["diff"], "y")
1909         self.assertEqual(config["color"], "True")
1910         self.assertEqual(config["line_length"], "79")
1911         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1912         self.assertEqual(config["exclude"], r"\.pyi?$")
1913         self.assertEqual(config["include"], r"\.py?$")
1914
1915     def test_find_project_root(self) -> None:
1916         with TemporaryDirectory() as workspace:
1917             root = Path(workspace)
1918             test_dir = root / "test"
1919             test_dir.mkdir()
1920
1921             src_dir = root / "src"
1922             src_dir.mkdir()
1923
1924             root_pyproject = root / "pyproject.toml"
1925             root_pyproject.touch()
1926             src_pyproject = src_dir / "pyproject.toml"
1927             src_pyproject.touch()
1928             src_python = src_dir / "foo.py"
1929             src_python.touch()
1930
1931             self.assertEqual(
1932                 black.find_project_root((src_dir, test_dir)), root.resolve()
1933             )
1934             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
1935             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
1936
1937
# AioHTTPTestCase, unittest_run_loop, `web` and `blackd` are only bound when the
# optional blackd dependencies imported successfully (`has_blackd_deps`). The
# class body evaluates its base class, decorators and annotations at import
# time, so it can only be defined in that case. Without this guard, a missing
# dependency raises NameError and breaks import of the whole test module,
# instead of just leaving the blackd tests out.
if has_blackd_deps:

    class BlackDTestCase(AioHTTPTestCase):
        """End-to-end tests of blackd's HTTP handler via aiohttp's test client."""

        async def get_application(self) -> web.Application:
            # The application under test is the one built by blackd.make_app().
            return blackd.make_app()

        # TODO: remove these decorators once the below is released
        # https://github.com/aio-libs/aiohttp/pull/3727
        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_request_needs_formatting(self) -> None:
            response = await self.client.post("/", data=b"print('hello world')")
            self.assertEqual(response.status, 200)
            self.assertEqual(response.charset, "utf8")
            self.assertEqual(await response.read(), b'print("hello world")\n')

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_request_no_change(self) -> None:
            # Already-formatted input: expect 204 with an empty body.
            response = await self.client.post("/", data=b'print("hello world")\n')
            self.assertEqual(response.status, 204)
            self.assertEqual(await response.read(), b"")

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_request_syntax_error(self) -> None:
            response = await self.client.post("/", data=b"what even ( is")
            self.assertEqual(response.status, 400)
            content = await response.text()
            self.assertTrue(
                content.startswith("Cannot parse"),
                msg=(
                    "Expected error to start with 'Cannot parse', "
                    f"got {repr(content)}"
                ),
            )

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_unsupported_version(self) -> None:
            response = await self.client.post(
                "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "2"}
            )
            self.assertEqual(response.status, 501)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_supported_version(self) -> None:
            response = await self.client.post(
                "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "1"}
            )
            self.assertEqual(response.status, 200)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_invalid_python_variant(self) -> None:
            async def check(header_value: str, expected_status: int = 400) -> None:
                response = await self.client.post(
                    "/",
                    data=b"what",
                    headers={blackd.PYTHON_VARIANT_HEADER: header_value},
                )
                self.assertEqual(response.status, expected_status)

            await check("lol")
            await check("ruby3.5")
            await check("pyi3.6")
            await check("py1.5")
            await check("2.8")
            await check("py2.8")
            await check("3.0")
            await check("pypy3.0")
            await check("jython3.4")

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_pyi(self) -> None:
            source, expected = read_data("stub.pyi")
            response = await self.client.post(
                "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
            )
            self.assertEqual(response.status, 200)
            self.assertEqual(await response.text(), expected)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_diff(self) -> None:
            # Matches the timestamped "In"/"Out" header parts so they can be
            # replaced by DETERMINISTIC_HEADER before comparing.
            diff_header = re.compile(
                r"(In|Out)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d"
                r" \+\d\d\d\d"
            )

            source, _ = read_data("blackd_diff.py")
            expected, _ = read_data("blackd_diff.diff")

            response = await self.client.post(
                "/", data=source, headers={blackd.DIFF_HEADER: "true"}
            )
            self.assertEqual(response.status, 200)

            actual = await response.text()
            actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
            self.assertEqual(actual, expected)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_python_variant(self) -> None:
            code = (
                "def f(\n"
                "    and_has_a_bunch_of,\n"
                "    very_long_arguments_too,\n"
                "    and_lots_of_them_as_well_lol,\n"
                "    **and_very_long_keyword_arguments\n"
                "):\n"
                "    pass\n"
            )

            async def check(header_value: str, expected_status: int) -> None:
                response = await self.client.post(
                    "/",
                    data=code,
                    headers={blackd.PYTHON_VARIANT_HEADER: header_value},
                )
                self.assertEqual(
                    response.status, expected_status, msg=await response.text()
                )

            await check("3.6", 200)
            await check("py3.6", 200)
            await check("3.6,3.7", 200)
            await check("3.6,py3.7", 200)
            await check("py36,py37", 200)
            await check("36", 200)
            await check("3.6.4", 200)

            await check("2", 204)
            await check("2.7", 204)
            await check("py2.7", 204)
            await check("3.4", 204)
            await check("py3.4", 204)
            await check("py34,py36", 204)
            await check("34", 204)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_line_length(self) -> None:
            response = await self.client.post(
                "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "7"}
            )
            self.assertEqual(response.status, 200)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_invalid_line_length(self) -> None:
            response = await self.client.post(
                "/",
                data=b'print("hello")\n',
                headers={blackd.LINE_LENGTH_HEADER: "NaN"},
            )
            self.assertEqual(response.status, 400)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_response_black_version_header(self) -> None:
            response = await self.client.post("/")
            self.assertIsNotNone(response.headers.get(blackd.BLACK_VERSION_HEADER))
2104
2105
# Lines of black's main source file, read once at import time; tracefunc uses
# them to print the source text of the function being called.
with open(black.__file__, "r", encoding="utf-8") as _bf:
    black_source_lines = list(_bf)
2108
2109
def tracefunc(frame: types.FrameType, event: str, arg: Any) -> Callable:
    """Show function calls `from black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    For every "call" event whose code lives in ``black/__init__.py`` this
    prints ``<indent><lineno>:<def line>``, indented by call-stack depth.
    It always returns itself, so tracing continues into nested scopes.
    """
    if event != "call":
        return tracefunc

    filename = frame.f_code.co_filename
    # Filter by file BEFORE touching black_source_lines: those lines describe
    # black/__init__.py only, so indexing them with a line number from any
    # other file reads the wrong line or raises IndexError. Separators are
    # normalized so the check also matches Windows paths.
    if "black/__init__.py" not in filename.replace("\\", "/"):
        return tracefunc

    # Indent by stack depth. The 19-frame offset is kept from the original;
    # presumably it discounts test-runner frames (TODO confirm). context=0
    # yields the same frame list without reading source for every frame.
    stack = (len(inspect.stack(0)) - 19) * 2
    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1
    funcname = black_source_lines[func_sig_lineno].strip()
    # Skip decorator lines so the printed text is the `def` line itself.
    while funcname.startswith("@"):
        func_sig_lineno += 1
        funcname = black_source_lines[func_sig_lineno].strip()
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc
2130
2131
if __name__ == "__main__":
    # Run this file's tests with unittest. Tests are loaded from the module
    # named "test_black" rather than from __main__.
    # NOTE(review): confirm this works when executed directly, given the
    # relative `from .test_primer import ...` at the top of the file.
    unittest.main(module="test_black")