git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send patches to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

cf311f52e14d813de8eff44e6394f9463cd621d9
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 import logging
4 from concurrent.futures import ThreadPoolExecutor
5 from contextlib import contextmanager
6 from dataclasses import replace
7 from functools import partial
8 import inspect
9 from io import BytesIO, TextIOWrapper
10 import os
11 from pathlib import Path
12 import regex as re
13 import sys
14 from tempfile import TemporaryDirectory
15 import types
16 from typing import (
17     Any,
18     BinaryIO,
19     Callable,
20     Dict,
21     Generator,
22     List,
23     Tuple,
24     Iterator,
25     TypeVar,
26 )
27 import unittest
28 from unittest.mock import patch, MagicMock
29
30 import click
31 from click import unstyle
32 from click.testing import CliRunner
33
34 import black
35 from black import Feature, TargetVersion
36
37 try:
38     import blackd
39     from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop
40     from aiohttp import web
41 except ImportError:
42     has_blackd_deps = False
43 else:
44     has_blackd_deps = True
45
46 from pathspec import PathSpec
47
48 # Import other test classes
49 from .test_primer import PrimerCLITests  # noqa: F401
50
51
# Mode shared by most tests in this module: experimental string processing on.
DEFAULT_MODE = black.FileMode(experimental_string_processing=True)
# Shorthands: format a file in place (fast mode) / format a string, both
# pre-bound to DEFAULT_MODE.
ff = partial(black.format_file_in_place, mode=DEFAULT_MODE, fast=True)
fs = partial(black.format_str, mode=DEFAULT_MODE)
# Locations derived from this file: the tests directory and its parent.
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
PROJECT_ROOT = THIS_DIR.parent
# Text substituted for timestamped diff headers so diffs compare equal.
DETERMINISTIC_HEADER = "[Deterministic header]"
# Marker that read_data() removes from every fixture line it reads.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
# One --target-version CLI flag per entry in black.PY36_VERSIONS.
PY36_ARGS = [
    f"--target-version={version.name.lower()}" for version in black.PY36_VERSIONS
]
# Generic type variables for annotations.
T = TypeVar("T")
R = TypeVar("R")
65
66
def dump_to_stderr(*output: str) -> str:
    """Return *output* joined by newlines and wrapped in a leading and trailing newline.

    Despite the name, nothing is written anywhere; the text is only returned.
    """
    joined = "\n".join(output)
    return f"\n{joined}\n"
69
70
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    ``.py`` is appended to *name* unless it already ends in ``.py``, ``.pyi``,
    ``.out`` or ``.diff``.  The file is looked up under ``THIS_DIR / "data"``
    when *data* is true and under ``PROJECT_ROOT`` otherwise.  Lines before a
    ``# output`` marker line form the input, lines after it form the output;
    without a marker the input doubles as the output.  Both halves are
    stripped and terminated with a single newline.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    base_dir = THIS_DIR / "data" if data else PROJECT_ROOT
    source_part: List[str] = []
    output_part: List[str] = []
    current = source_part
    with open(base_dir / name, "r", encoding="utf8") as handle:
        raw_lines = handle.readlines()
    for raw in raw_lines:
        text = raw.replace(EMPTY_LINE, "")
        if text.rstrip() == "# output":
            # Everything after the marker belongs to the expected output.
            current = output_part
        else:
            current.append(text)
    if source_part and not output_part:
        # If there's no output marker, treat the entire file as already pre-formatted.
        output_part = list(source_part)
    formatted_input = "".join(source_part).strip() + "\n"
    formatted_output = "".join(output_part).strip() + "\n"
    return formatted_input, formatted_output
92
93
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch ``black.CACHE_DIR`` to a path inside a temporary directory.

    With ``exists=False`` the yielded path is a ``new`` entry inside the
    temporary directory, which this helper does not create.
    """
    with TemporaryDirectory() as workspace:
        location = Path(workspace) if exists else Path(workspace) / "new"
        with patch("black.CACHE_DIR", location):
            yield location
102
103
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a brand-new asyncio event loop for the duration of the block.

    The loop is obtained from the current event loop policy, set as the
    current loop, and closed again when the block exits (even on error).
    """
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        fresh_loop.close()
114
115
@contextmanager
def skip_if_exception(e: str) -> Iterator[None]:
    """Skip the running test when the block raises an exception named *e*.

    Args:
        e: Class name (not the class itself) of the exception that should
            turn into a skip.

    Raises:
        unittest.SkipTest: when the block raises an exception whose class
            name equals *e*.
        Exception: any other exception raised by the block propagates as is.
    """
    try:
        yield
    except Exception as exc:
        if exc.__class__.__name__ != e:
            raise
        # ``unittest.skip(reason)`` merely builds a decorator; to actually skip
        # from inside a running test, SkipTest has to be raised.
        raise unittest.SkipTest(
            f"Encountered expected exception {exc}, skipping"
        ) from exc
125
126
class FakeContext(click.Context):
    """A fake click Context for when calling functions that need it."""

    def __init__(self) -> None:
        # The base-class initializer is not called here; the only attribute
        # set up is ``default_map``, which starts out empty.
        self.default_map: Dict[str, Any] = {}
132
133
class FakeParameter(click.Parameter):
    """A fake click Parameter for when calling functions that need it."""

    def __init__(self) -> None:
        # Intentionally empty: the base-class initializer is not called and
        # no attributes are set.
        pass
139
140
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    This is a hack that can be removed once we depend on Click 7.x"""

    def __init__(self) -> None:
        # Buffer that backs the replacement ``sys.stderr`` installed by
        # isolation().
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        # Raw bytes of each stream; overwritten whenever isolation() exits.
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        # Wraps the parent's isolation(): while it is active, sys.stderr is
        # replaced by a text wrapper around ``self.stderrbuf``.
        with super().isolation(*args, **kwargs) as output:
            try:
                hold_stderr = sys.stderr
                sys.stderr = TextIOWrapper(self.stderrbuf, encoding=self.charset)
                yield output
            finally:
                # Snapshot both streams first, then put the previous
                # sys.stderr back.
                self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                self.stderr_bytes = sys.stderr.buffer.getvalue()  # type: ignore
                sys.stderr = hold_stderr
164
165
166 class BlackTestCase(unittest.TestCase):
167     maxDiff = None
168     _diffThreshold = 2 ** 20
169
170     def assertFormatEqual(self, expected: str, actual: str) -> None:
171         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
172             bdv: black.DebugVisitor[Any]
173             black.out("Expected tree:", fg="green")
174             try:
175                 exp_node = black.lib2to3_parse(expected)
176                 bdv = black.DebugVisitor()
177                 list(bdv.visit(exp_node))
178             except Exception as ve:
179                 black.err(str(ve))
180             black.out("Actual tree:", fg="red")
181             try:
182                 exp_node = black.lib2to3_parse(actual)
183                 bdv = black.DebugVisitor()
184                 list(bdv.visit(exp_node))
185             except Exception as ve:
186                 black.err(str(ve))
187         self.assertMultiLineEqual(expected, actual)
188
189     def invokeBlack(
190         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
191     ) -> None:
192         runner = BlackRunner()
193         if ignore_config:
194             args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
195         result = runner.invoke(black.main, args)
196         self.assertEqual(
197             result.exit_code,
198             exit_code,
199             msg=(
200                 f"Failed with args: {args}\n"
201                 f"stdout: {runner.stdout_bytes.decode()!r}\n"
202                 f"stderr: {runner.stderr_bytes.decode()!r}\n"
203                 f"exception: {result.exception}"
204             ),
205         )
206
207     @patch("black.dump_to_file", dump_to_stderr)
208     def checkSourceFile(self, name: str, mode: black.FileMode = DEFAULT_MODE) -> None:
209         path = THIS_DIR.parent / name
210         source, expected = read_data(str(path), data=False)
211         actual = fs(source, mode=mode)
212         self.assertFormatEqual(expected, actual)
213         black.assert_equivalent(source, actual)
214         black.assert_stable(source, actual, mode)
215         self.assertFalse(ff(path))
216
217     @patch("black.dump_to_file", dump_to_stderr)
218     def test_empty(self) -> None:
219         source = expected = ""
220         actual = fs(source)
221         self.assertFormatEqual(expected, actual)
222         black.assert_equivalent(source, actual)
223         black.assert_stable(source, actual, DEFAULT_MODE)
224
225     def test_empty_ff(self) -> None:
226         expected = ""
227         tmp_file = Path(black.dump_to_file())
228         try:
229             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
230             with open(tmp_file, encoding="utf8") as f:
231                 actual = f.read()
232         finally:
233             os.unlink(tmp_file)
234         self.assertFormatEqual(expected, actual)
235
236     def test_self(self) -> None:
237         self.checkSourceFile("tests/test_black.py")
238
239     def test_black(self) -> None:
240         self.checkSourceFile("src/black/__init__.py")
241
242     def test_pygram(self) -> None:
243         self.checkSourceFile("src/blib2to3/pygram.py")
244
245     def test_pytree(self) -> None:
246         self.checkSourceFile("src/blib2to3/pytree.py")
247
248     def test_conv(self) -> None:
249         self.checkSourceFile("src/blib2to3/pgen2/conv.py")
250
251     def test_driver(self) -> None:
252         self.checkSourceFile("src/blib2to3/pgen2/driver.py")
253
254     def test_grammar(self) -> None:
255         self.checkSourceFile("src/blib2to3/pgen2/grammar.py")
256
257     def test_literals(self) -> None:
258         self.checkSourceFile("src/blib2to3/pgen2/literals.py")
259
260     def test_parse(self) -> None:
261         self.checkSourceFile("src/blib2to3/pgen2/parse.py")
262
263     def test_pgen(self) -> None:
264         self.checkSourceFile("src/blib2to3/pgen2/pgen.py")
265
266     def test_tokenize(self) -> None:
267         self.checkSourceFile("src/blib2to3/pgen2/tokenize.py")
268
269     def test_token(self) -> None:
270         self.checkSourceFile("src/blib2to3/pgen2/token.py")
271
272     def test_setup(self) -> None:
273         self.checkSourceFile("setup.py")
274
275     def test_piping(self) -> None:
276         source, expected = read_data("src/black/__init__", data=False)
277         result = BlackRunner().invoke(
278             black.main,
279             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
280             input=BytesIO(source.encode("utf8")),
281         )
282         self.assertEqual(result.exit_code, 0)
283         self.assertFormatEqual(expected, result.output)
284         black.assert_equivalent(source, result.output)
285         black.assert_stable(source, result.output, DEFAULT_MODE)
286
287     def test_piping_diff(self) -> None:
288         diff_header = re.compile(
289             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
290             r"\+\d\d\d\d"
291         )
292         source, _ = read_data("expression.py")
293         expected, _ = read_data("expression.diff")
294         config = THIS_DIR / "data" / "empty_pyproject.toml"
295         args = [
296             "-",
297             "--fast",
298             f"--line-length={black.DEFAULT_LINE_LENGTH}",
299             "--diff",
300             f"--config={config}",
301         ]
302         result = BlackRunner().invoke(
303             black.main, args, input=BytesIO(source.encode("utf8"))
304         )
305         self.assertEqual(result.exit_code, 0)
306         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
307         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
308         self.assertEqual(expected, actual)
309
310     def test_piping_diff_with_color(self) -> None:
311         source, _ = read_data("expression.py")
312         config = THIS_DIR / "data" / "empty_pyproject.toml"
313         args = [
314             "-",
315             "--fast",
316             f"--line-length={black.DEFAULT_LINE_LENGTH}",
317             "--diff",
318             "--color",
319             f"--config={config}",
320         ]
321         result = BlackRunner().invoke(
322             black.main, args, input=BytesIO(source.encode("utf8"))
323         )
324         actual = result.output
325         # Again, the contents are checked in a different test, so only look for colors.
326         self.assertIn("\033[1;37m", actual)
327         self.assertIn("\033[36m", actual)
328         self.assertIn("\033[32m", actual)
329         self.assertIn("\033[31m", actual)
330         self.assertIn("\033[0m", actual)
331
332     @patch("black.dump_to_file", dump_to_stderr)
333     def test_function(self) -> None:
334         source, expected = read_data("function")
335         actual = fs(source)
336         self.assertFormatEqual(expected, actual)
337         black.assert_equivalent(source, actual)
338         black.assert_stable(source, actual, DEFAULT_MODE)
339
340     @patch("black.dump_to_file", dump_to_stderr)
341     def test_function2(self) -> None:
342         source, expected = read_data("function2")
343         actual = fs(source)
344         self.assertFormatEqual(expected, actual)
345         black.assert_equivalent(source, actual)
346         black.assert_stable(source, actual, DEFAULT_MODE)
347
348     @patch("black.dump_to_file", dump_to_stderr)
349     def _test_wip(self) -> None:
350         source, expected = read_data("wip")
351         sys.settrace(tracefunc)
352         mode = replace(
353             DEFAULT_MODE,
354             experimental_string_processing=False,
355             target_versions={black.TargetVersion.PY38},
356         )
357         actual = fs(source, mode=mode)
358         sys.settrace(None)
359         self.assertFormatEqual(expected, actual)
360         black.assert_equivalent(source, actual)
361         black.assert_stable(source, actual, black.FileMode())
362
363     @patch("black.dump_to_file", dump_to_stderr)
364     def test_function_trailing_comma(self) -> None:
365         source, expected = read_data("function_trailing_comma")
366         actual = fs(source)
367         self.assertFormatEqual(expected, actual)
368         black.assert_equivalent(source, actual)
369         black.assert_stable(source, actual, DEFAULT_MODE)
370
371     @patch("black.dump_to_file", dump_to_stderr)
372     def test_expression(self) -> None:
373         source, expected = read_data("expression")
374         actual = fs(source)
375         self.assertFormatEqual(expected, actual)
376         black.assert_equivalent(source, actual)
377         black.assert_stable(source, actual, DEFAULT_MODE)
378
379     @patch("black.dump_to_file", dump_to_stderr)
380     def test_pep_572(self) -> None:
381         source, expected = read_data("pep_572")
382         actual = fs(source)
383         self.assertFormatEqual(expected, actual)
384         black.assert_stable(source, actual, DEFAULT_MODE)
385         if sys.version_info >= (3, 8):
386             black.assert_equivalent(source, actual)
387
388     def test_pep_572_version_detection(self) -> None:
389         source, _ = read_data("pep_572")
390         root = black.lib2to3_parse(source)
391         features = black.get_features_used(root)
392         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
393         versions = black.detect_target_versions(root)
394         self.assertIn(black.TargetVersion.PY38, versions)
395
396     def test_expression_ff(self) -> None:
397         source, expected = read_data("expression")
398         tmp_file = Path(black.dump_to_file(source))
399         try:
400             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
401             with open(tmp_file, encoding="utf8") as f:
402                 actual = f.read()
403         finally:
404             os.unlink(tmp_file)
405         self.assertFormatEqual(expected, actual)
406         with patch("black.dump_to_file", dump_to_stderr):
407             black.assert_equivalent(source, actual)
408             black.assert_stable(source, actual, DEFAULT_MODE)
409
410     def test_expression_diff(self) -> None:
411         source, _ = read_data("expression.py")
412         expected, _ = read_data("expression.diff")
413         tmp_file = Path(black.dump_to_file(source))
414         diff_header = re.compile(
415             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
416             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
417         )
418         try:
419             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
420             self.assertEqual(result.exit_code, 0)
421         finally:
422             os.unlink(tmp_file)
423         actual = result.output
424         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
425         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
426         if expected != actual:
427             dump = black.dump_to_file(actual)
428             msg = (
429                 "Expected diff isn't equal to the actual. If you made changes to"
430                 " expression.py and this is an anticipated difference, overwrite"
431                 f" tests/data/expression.diff with {dump}"
432             )
433             self.assertEqual(expected, actual, msg)
434
435     def test_expression_diff_with_color(self) -> None:
436         source, _ = read_data("expression.py")
437         expected, _ = read_data("expression.diff")
438         tmp_file = Path(black.dump_to_file(source))
439         try:
440             result = BlackRunner().invoke(
441                 black.main, ["--diff", "--color", str(tmp_file)]
442             )
443         finally:
444             os.unlink(tmp_file)
445         actual = result.output
446         # We check the contents of the diff in `test_expression_diff`. All
447         # we need to check here is that color codes exist in the result.
448         self.assertIn("\033[1;37m", actual)
449         self.assertIn("\033[36m", actual)
450         self.assertIn("\033[32m", actual)
451         self.assertIn("\033[31m", actual)
452         self.assertIn("\033[0m", actual)
453
454     @patch("black.dump_to_file", dump_to_stderr)
455     def test_fstring(self) -> None:
456         source, expected = read_data("fstring")
457         actual = fs(source)
458         self.assertFormatEqual(expected, actual)
459         black.assert_equivalent(source, actual)
460         black.assert_stable(source, actual, DEFAULT_MODE)
461
462     @patch("black.dump_to_file", dump_to_stderr)
463     def test_pep_570(self) -> None:
464         source, expected = read_data("pep_570")
465         actual = fs(source)
466         self.assertFormatEqual(expected, actual)
467         black.assert_stable(source, actual, DEFAULT_MODE)
468         if sys.version_info >= (3, 8):
469             black.assert_equivalent(source, actual)
470
471     def test_detect_pos_only_arguments(self) -> None:
472         source, _ = read_data("pep_570")
473         root = black.lib2to3_parse(source)
474         features = black.get_features_used(root)
475         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
476         versions = black.detect_target_versions(root)
477         self.assertIn(black.TargetVersion.PY38, versions)
478
479     @patch("black.dump_to_file", dump_to_stderr)
480     def test_string_quotes(self) -> None:
481         source, expected = read_data("string_quotes")
482         actual = fs(source)
483         self.assertFormatEqual(expected, actual)
484         black.assert_equivalent(source, actual)
485         black.assert_stable(source, actual, DEFAULT_MODE)
486         mode = replace(DEFAULT_MODE, string_normalization=False)
487         not_normalized = fs(source, mode=mode)
488         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
489         black.assert_equivalent(source, not_normalized)
490         black.assert_stable(source, not_normalized, mode=mode)
491
492     @patch("black.dump_to_file", dump_to_stderr)
493     def test_docstring(self) -> None:
494         source, expected = read_data("docstring")
495         actual = fs(source)
496         self.assertFormatEqual(expected, actual)
497         black.assert_equivalent(source, actual)
498         black.assert_stable(source, actual, DEFAULT_MODE)
499         mode = replace(DEFAULT_MODE, string_normalization=False)
500         not_normalized = fs(source, mode=mode)
501         self.assertFormatEqual(expected, not_normalized)
502         black.assert_equivalent(source, not_normalized)
503         black.assert_stable(source, not_normalized, mode=mode)
504
505     def test_long_strings(self) -> None:
506         """Tests for splitting long strings."""
507         source, expected = read_data("long_strings")
508         actual = fs(source)
509         self.assertFormatEqual(expected, actual)
510         black.assert_equivalent(source, actual)
511         black.assert_stable(source, actual, DEFAULT_MODE)
512
513     def test_long_strings_flag_disabled(self) -> None:
514         """Tests for turning off the string processing logic."""
515         source, expected = read_data("long_strings_flag_disabled")
516         mode = replace(DEFAULT_MODE, experimental_string_processing=False)
517         actual = fs(source, mode=mode)
518         self.assertFormatEqual(expected, actual)
519         black.assert_stable(expected, actual, mode)
520
521     @patch("black.dump_to_file", dump_to_stderr)
522     def test_long_strings__edge_case(self) -> None:
523         """Edge-case tests for splitting long strings."""
524         source, expected = read_data("long_strings__edge_case")
525         actual = fs(source)
526         self.assertFormatEqual(expected, actual)
527         black.assert_equivalent(source, actual)
528         black.assert_stable(source, actual, DEFAULT_MODE)
529
530     @patch("black.dump_to_file", dump_to_stderr)
531     def test_long_strings__regression(self) -> None:
532         """Regression tests for splitting long strings."""
533         source, expected = read_data("long_strings__regression")
534         actual = fs(source)
535         self.assertFormatEqual(expected, actual)
536         black.assert_equivalent(source, actual)
537         black.assert_stable(source, actual, DEFAULT_MODE)
538
539     @patch("black.dump_to_file", dump_to_stderr)
540     def test_slices(self) -> None:
541         source, expected = read_data("slices")
542         actual = fs(source)
543         self.assertFormatEqual(expected, actual)
544         black.assert_equivalent(source, actual)
545         black.assert_stable(source, actual, DEFAULT_MODE)
546
547     @patch("black.dump_to_file", dump_to_stderr)
548     def test_percent_precedence(self) -> None:
549         source, expected = read_data("percent_precedence")
550         actual = fs(source)
551         self.assertFormatEqual(expected, actual)
552         black.assert_equivalent(source, actual)
553         black.assert_stable(source, actual, DEFAULT_MODE)
554
555     @patch("black.dump_to_file", dump_to_stderr)
556     def test_comments(self) -> None:
557         source, expected = read_data("comments")
558         actual = fs(source)
559         self.assertFormatEqual(expected, actual)
560         black.assert_equivalent(source, actual)
561         black.assert_stable(source, actual, DEFAULT_MODE)
562
563     @patch("black.dump_to_file", dump_to_stderr)
564     def test_comments2(self) -> None:
565         source, expected = read_data("comments2")
566         actual = fs(source)
567         self.assertFormatEqual(expected, actual)
568         black.assert_equivalent(source, actual)
569         black.assert_stable(source, actual, DEFAULT_MODE)
570
571     @patch("black.dump_to_file", dump_to_stderr)
572     def test_comments3(self) -> None:
573         source, expected = read_data("comments3")
574         actual = fs(source)
575         self.assertFormatEqual(expected, actual)
576         black.assert_equivalent(source, actual)
577         black.assert_stable(source, actual, DEFAULT_MODE)
578
579     @patch("black.dump_to_file", dump_to_stderr)
580     def test_comments4(self) -> None:
581         source, expected = read_data("comments4")
582         actual = fs(source)
583         self.assertFormatEqual(expected, actual)
584         black.assert_equivalent(source, actual)
585         black.assert_stable(source, actual, DEFAULT_MODE)
586
587     @patch("black.dump_to_file", dump_to_stderr)
588     def test_comments5(self) -> None:
589         source, expected = read_data("comments5")
590         actual = fs(source)
591         self.assertFormatEqual(expected, actual)
592         black.assert_equivalent(source, actual)
593         black.assert_stable(source, actual, DEFAULT_MODE)
594
595     @patch("black.dump_to_file", dump_to_stderr)
596     def test_comments6(self) -> None:
597         source, expected = read_data("comments6")
598         actual = fs(source)
599         self.assertFormatEqual(expected, actual)
600         black.assert_equivalent(source, actual)
601         black.assert_stable(source, actual, DEFAULT_MODE)
602
603     @patch("black.dump_to_file", dump_to_stderr)
604     def test_comments7(self) -> None:
605         source, expected = read_data("comments7")
606         mode = replace(DEFAULT_MODE, target_versions={black.TargetVersion.PY38})
607         actual = fs(source, mode=mode)
608         self.assertFormatEqual(expected, actual)
609         black.assert_equivalent(source, actual)
610         black.assert_stable(source, actual, DEFAULT_MODE)
611
612     @patch("black.dump_to_file", dump_to_stderr)
613     def test_comment_after_escaped_newline(self) -> None:
614         source, expected = read_data("comment_after_escaped_newline")
615         actual = fs(source)
616         self.assertFormatEqual(expected, actual)
617         black.assert_equivalent(source, actual)
618         black.assert_stable(source, actual, DEFAULT_MODE)
619
620     @patch("black.dump_to_file", dump_to_stderr)
621     def test_cantfit(self) -> None:
622         source, expected = read_data("cantfit")
623         actual = fs(source)
624         self.assertFormatEqual(expected, actual)
625         black.assert_equivalent(source, actual)
626         black.assert_stable(source, actual, DEFAULT_MODE)
627
628     @patch("black.dump_to_file", dump_to_stderr)
629     def test_import_spacing(self) -> None:
630         source, expected = read_data("import_spacing")
631         actual = fs(source)
632         self.assertFormatEqual(expected, actual)
633         black.assert_equivalent(source, actual)
634         black.assert_stable(source, actual, DEFAULT_MODE)
635
636     @patch("black.dump_to_file", dump_to_stderr)
637     def test_composition(self) -> None:
638         source, expected = read_data("composition")
639         actual = fs(source)
640         self.assertFormatEqual(expected, actual)
641         black.assert_equivalent(source, actual)
642         black.assert_stable(source, actual, DEFAULT_MODE)
643
644     @patch("black.dump_to_file", dump_to_stderr)
645     def test_composition_no_trailing_comma(self) -> None:
646         source, expected = read_data("composition_no_trailing_comma")
647         mode = replace(DEFAULT_MODE, target_versions={black.TargetVersion.PY38})
648         actual = fs(source, mode=mode)
649         self.assertFormatEqual(expected, actual)
650         black.assert_equivalent(source, actual)
651         black.assert_stable(source, actual, DEFAULT_MODE)
652
653     @patch("black.dump_to_file", dump_to_stderr)
654     def test_empty_lines(self) -> None:
655         source, expected = read_data("empty_lines")
656         actual = fs(source)
657         self.assertFormatEqual(expected, actual)
658         black.assert_equivalent(source, actual)
659         black.assert_stable(source, actual, DEFAULT_MODE)
660
661     @patch("black.dump_to_file", dump_to_stderr)
662     def test_remove_parens(self) -> None:
663         source, expected = read_data("remove_parens")
664         actual = fs(source)
665         self.assertFormatEqual(expected, actual)
666         black.assert_equivalent(source, actual)
667         black.assert_stable(source, actual, DEFAULT_MODE)
668
669     @patch("black.dump_to_file", dump_to_stderr)
670     def test_string_prefixes(self) -> None:
671         source, expected = read_data("string_prefixes")
672         actual = fs(source)
673         self.assertFormatEqual(expected, actual)
674         black.assert_equivalent(source, actual)
675         black.assert_stable(source, actual, DEFAULT_MODE)
676
677     @patch("black.dump_to_file", dump_to_stderr)
678     def test_numeric_literals(self) -> None:
679         source, expected = read_data("numeric_literals")
680         mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
681         actual = fs(source, mode=mode)
682         self.assertFormatEqual(expected, actual)
683         black.assert_equivalent(source, actual)
684         black.assert_stable(source, actual, mode)
685
686     @patch("black.dump_to_file", dump_to_stderr)
687     def test_numeric_literals_ignoring_underscores(self) -> None:
688         source, expected = read_data("numeric_literals_skip_underscores")
689         mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
690         actual = fs(source, mode=mode)
691         self.assertFormatEqual(expected, actual)
692         black.assert_equivalent(source, actual)
693         black.assert_stable(source, actual, mode)
694
695     @patch("black.dump_to_file", dump_to_stderr)
696     def test_numeric_literals_py2(self) -> None:
697         source, expected = read_data("numeric_literals_py2")
698         actual = fs(source)
699         self.assertFormatEqual(expected, actual)
700         black.assert_stable(source, actual, DEFAULT_MODE)
701
702     @patch("black.dump_to_file", dump_to_stderr)
703     def test_python2(self) -> None:
704         source, expected = read_data("python2")
705         actual = fs(source)
706         self.assertFormatEqual(expected, actual)
707         black.assert_equivalent(source, actual)
708         black.assert_stable(source, actual, DEFAULT_MODE)
709
710     @patch("black.dump_to_file", dump_to_stderr)
711     def test_python2_print_function(self) -> None:
712         source, expected = read_data("python2_print_function")
713         mode = replace(DEFAULT_MODE, target_versions={TargetVersion.PY27})
714         actual = fs(source, mode=mode)
715         self.assertFormatEqual(expected, actual)
716         black.assert_equivalent(source, actual)
717         black.assert_stable(source, actual, mode)
718
719     @patch("black.dump_to_file", dump_to_stderr)
720     def test_python2_unicode_literals(self) -> None:
721         source, expected = read_data("python2_unicode_literals")
722         actual = fs(source)
723         self.assertFormatEqual(expected, actual)
724         black.assert_equivalent(source, actual)
725         black.assert_stable(source, actual, DEFAULT_MODE)
726
727     @patch("black.dump_to_file", dump_to_stderr)
728     def test_stub(self) -> None:
729         mode = replace(DEFAULT_MODE, is_pyi=True)
730         source, expected = read_data("stub.pyi")
731         actual = fs(source, mode=mode)
732         self.assertFormatEqual(expected, actual)
733         black.assert_stable(source, actual, mode)
734
735     @patch("black.dump_to_file", dump_to_stderr)
736     def test_async_as_identifier(self) -> None:
737         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
738         source, expected = read_data("async_as_identifier")
739         actual = fs(source)
740         self.assertFormatEqual(expected, actual)
741         major, minor = sys.version_info[:2]
742         if major < 3 or (major <= 3 and minor < 7):
743             black.assert_equivalent(source, actual)
744         black.assert_stable(source, actual, DEFAULT_MODE)
745         # ensure black can parse this when the target is 3.6
746         self.invokeBlack([str(source_path), "--target-version", "py36"])
747         # but not on 3.7, because async/await is no longer an identifier
748         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
749
750     @patch("black.dump_to_file", dump_to_stderr)
751     def test_python37(self) -> None:
752         source_path = (THIS_DIR / "data" / "python37.py").resolve()
753         source, expected = read_data("python37")
754         actual = fs(source)
755         self.assertFormatEqual(expected, actual)
756         major, minor = sys.version_info[:2]
757         if major > 3 or (major == 3 and minor >= 7):
758             black.assert_equivalent(source, actual)
759         black.assert_stable(source, actual, DEFAULT_MODE)
760         # ensure black can parse this when the target is 3.7
761         self.invokeBlack([str(source_path), "--target-version", "py37"])
762         # but not on 3.6, because we use async as a reserved keyword
763         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
764
765     @patch("black.dump_to_file", dump_to_stderr)
766     def test_python38(self) -> None:
767         source, expected = read_data("python38")
768         actual = fs(source)
769         self.assertFormatEqual(expected, actual)
770         major, minor = sys.version_info[:2]
771         if major > 3 or (major == 3 and minor >= 8):
772             black.assert_equivalent(source, actual)
773         black.assert_stable(source, actual, DEFAULT_MODE)
774
775     @patch("black.dump_to_file", dump_to_stderr)
776     def test_fmtonoff(self) -> None:
777         source, expected = read_data("fmtonoff")
778         actual = fs(source)
779         self.assertFormatEqual(expected, actual)
780         black.assert_equivalent(source, actual)
781         black.assert_stable(source, actual, DEFAULT_MODE)
782
783     @patch("black.dump_to_file", dump_to_stderr)
784     def test_fmtonoff2(self) -> None:
785         source, expected = read_data("fmtonoff2")
786         actual = fs(source)
787         self.assertFormatEqual(expected, actual)
788         black.assert_equivalent(source, actual)
789         black.assert_stable(source, actual, DEFAULT_MODE)
790
791     @patch("black.dump_to_file", dump_to_stderr)
792     def test_fmtonoff3(self) -> None:
793         source, expected = read_data("fmtonoff3")
794         actual = fs(source)
795         self.assertFormatEqual(expected, actual)
796         black.assert_equivalent(source, actual)
797         black.assert_stable(source, actual, DEFAULT_MODE)
798
799     @patch("black.dump_to_file", dump_to_stderr)
800     def test_fmtonoff4(self) -> None:
801         source, expected = read_data("fmtonoff4")
802         actual = fs(source)
803         self.assertFormatEqual(expected, actual)
804         black.assert_equivalent(source, actual)
805         black.assert_stable(source, actual, DEFAULT_MODE)
806
807     @patch("black.dump_to_file", dump_to_stderr)
808     def test_remove_empty_parentheses_after_class(self) -> None:
809         source, expected = read_data("class_blank_parentheses")
810         actual = fs(source)
811         self.assertFormatEqual(expected, actual)
812         black.assert_equivalent(source, actual)
813         black.assert_stable(source, actual, DEFAULT_MODE)
814
815     @patch("black.dump_to_file", dump_to_stderr)
816     def test_new_line_between_class_and_code(self) -> None:
817         source, expected = read_data("class_methods_new_line")
818         actual = fs(source)
819         self.assertFormatEqual(expected, actual)
820         black.assert_equivalent(source, actual)
821         black.assert_stable(source, actual, DEFAULT_MODE)
822
823     @patch("black.dump_to_file", dump_to_stderr)
824     def test_bracket_match(self) -> None:
825         source, expected = read_data("bracketmatch")
826         actual = fs(source)
827         self.assertFormatEqual(expected, actual)
828         black.assert_equivalent(source, actual)
829         black.assert_stable(source, actual, DEFAULT_MODE)
830
831     @patch("black.dump_to_file", dump_to_stderr)
832     def test_tuple_assign(self) -> None:
833         source, expected = read_data("tupleassign")
834         actual = fs(source)
835         self.assertFormatEqual(expected, actual)
836         black.assert_equivalent(source, actual)
837         black.assert_stable(source, actual, DEFAULT_MODE)
838
839     @patch("black.dump_to_file", dump_to_stderr)
840     def test_beginning_backslash(self) -> None:
841         source, expected = read_data("beginning_backslash")
842         actual = fs(source)
843         self.assertFormatEqual(expected, actual)
844         black.assert_equivalent(source, actual)
845         black.assert_stable(source, actual, DEFAULT_MODE)
846
847     def test_tab_comment_indentation(self) -> None:
848         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
849         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
850         self.assertFormatEqual(contents_spc, fs(contents_spc))
851         self.assertFormatEqual(contents_spc, fs(contents_tab))
852
853         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
854         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
855         self.assertFormatEqual(contents_spc, fs(contents_spc))
856         self.assertFormatEqual(contents_spc, fs(contents_tab))
857
858         # mixed tabs and spaces (valid Python 2 code)
859         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
860         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
861         self.assertFormatEqual(contents_spc, fs(contents_spc))
862         self.assertFormatEqual(contents_spc, fs(contents_tab))
863
864         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
865         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
866         self.assertFormatEqual(contents_spc, fs(contents_spc))
867         self.assertFormatEqual(contents_spc, fs(contents_tab))
868
869     def test_report_verbose(self) -> None:
870         report = black.Report(verbose=True)
871         out_lines = []
872         err_lines = []
873
874         def out(msg: str, **kwargs: Any) -> None:
875             out_lines.append(msg)
876
877         def err(msg: str, **kwargs: Any) -> None:
878             err_lines.append(msg)
879
880         with patch("black.out", out), patch("black.err", err):
881             report.done(Path("f1"), black.Changed.NO)
882             self.assertEqual(len(out_lines), 1)
883             self.assertEqual(len(err_lines), 0)
884             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
885             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
886             self.assertEqual(report.return_code, 0)
887             report.done(Path("f2"), black.Changed.YES)
888             self.assertEqual(len(out_lines), 2)
889             self.assertEqual(len(err_lines), 0)
890             self.assertEqual(out_lines[-1], "reformatted f2")
891             self.assertEqual(
892                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
893             )
894             report.done(Path("f3"), black.Changed.CACHED)
895             self.assertEqual(len(out_lines), 3)
896             self.assertEqual(len(err_lines), 0)
897             self.assertEqual(
898                 out_lines[-1], "f3 wasn't modified on disk since last run."
899             )
900             self.assertEqual(
901                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
902             )
903             self.assertEqual(report.return_code, 0)
904             report.check = True
905             self.assertEqual(report.return_code, 1)
906             report.check = False
907             report.failed(Path("e1"), "boom")
908             self.assertEqual(len(out_lines), 3)
909             self.assertEqual(len(err_lines), 1)
910             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
911             self.assertEqual(
912                 unstyle(str(report)),
913                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
914                 " reformat.",
915             )
916             self.assertEqual(report.return_code, 123)
917             report.done(Path("f3"), black.Changed.YES)
918             self.assertEqual(len(out_lines), 4)
919             self.assertEqual(len(err_lines), 1)
920             self.assertEqual(out_lines[-1], "reformatted f3")
921             self.assertEqual(
922                 unstyle(str(report)),
923                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
924                 " reformat.",
925             )
926             self.assertEqual(report.return_code, 123)
927             report.failed(Path("e2"), "boom")
928             self.assertEqual(len(out_lines), 4)
929             self.assertEqual(len(err_lines), 2)
930             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
931             self.assertEqual(
932                 unstyle(str(report)),
933                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
934                 " reformat.",
935             )
936             self.assertEqual(report.return_code, 123)
937             report.path_ignored(Path("wat"), "no match")
938             self.assertEqual(len(out_lines), 5)
939             self.assertEqual(len(err_lines), 2)
940             self.assertEqual(out_lines[-1], "wat ignored: no match")
941             self.assertEqual(
942                 unstyle(str(report)),
943                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
944                 " reformat.",
945             )
946             self.assertEqual(report.return_code, 123)
947             report.done(Path("f4"), black.Changed.NO)
948             self.assertEqual(len(out_lines), 6)
949             self.assertEqual(len(err_lines), 2)
950             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
951             self.assertEqual(
952                 unstyle(str(report)),
953                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
954                 " reformat.",
955             )
956             self.assertEqual(report.return_code, 123)
957             report.check = True
958             self.assertEqual(
959                 unstyle(str(report)),
960                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
961                 " would fail to reformat.",
962             )
963             report.check = False
964             report.diff = True
965             self.assertEqual(
966                 unstyle(str(report)),
967                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
968                 " would fail to reformat.",
969             )
970
971     def test_report_quiet(self) -> None:
972         report = black.Report(quiet=True)
973         out_lines = []
974         err_lines = []
975
976         def out(msg: str, **kwargs: Any) -> None:
977             out_lines.append(msg)
978
979         def err(msg: str, **kwargs: Any) -> None:
980             err_lines.append(msg)
981
982         with patch("black.out", out), patch("black.err", err):
983             report.done(Path("f1"), black.Changed.NO)
984             self.assertEqual(len(out_lines), 0)
985             self.assertEqual(len(err_lines), 0)
986             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
987             self.assertEqual(report.return_code, 0)
988             report.done(Path("f2"), black.Changed.YES)
989             self.assertEqual(len(out_lines), 0)
990             self.assertEqual(len(err_lines), 0)
991             self.assertEqual(
992                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
993             )
994             report.done(Path("f3"), black.Changed.CACHED)
995             self.assertEqual(len(out_lines), 0)
996             self.assertEqual(len(err_lines), 0)
997             self.assertEqual(
998                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
999             )
1000             self.assertEqual(report.return_code, 0)
1001             report.check = True
1002             self.assertEqual(report.return_code, 1)
1003             report.check = False
1004             report.failed(Path("e1"), "boom")
1005             self.assertEqual(len(out_lines), 0)
1006             self.assertEqual(len(err_lines), 1)
1007             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
1008             self.assertEqual(
1009                 unstyle(str(report)),
1010                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
1011                 " reformat.",
1012             )
1013             self.assertEqual(report.return_code, 123)
1014             report.done(Path("f3"), black.Changed.YES)
1015             self.assertEqual(len(out_lines), 0)
1016             self.assertEqual(len(err_lines), 1)
1017             self.assertEqual(
1018                 unstyle(str(report)),
1019                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
1020                 " reformat.",
1021             )
1022             self.assertEqual(report.return_code, 123)
1023             report.failed(Path("e2"), "boom")
1024             self.assertEqual(len(out_lines), 0)
1025             self.assertEqual(len(err_lines), 2)
1026             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
1027             self.assertEqual(
1028                 unstyle(str(report)),
1029                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
1030                 " reformat.",
1031             )
1032             self.assertEqual(report.return_code, 123)
1033             report.path_ignored(Path("wat"), "no match")
1034             self.assertEqual(len(out_lines), 0)
1035             self.assertEqual(len(err_lines), 2)
1036             self.assertEqual(
1037                 unstyle(str(report)),
1038                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
1039                 " reformat.",
1040             )
1041             self.assertEqual(report.return_code, 123)
1042             report.done(Path("f4"), black.Changed.NO)
1043             self.assertEqual(len(out_lines), 0)
1044             self.assertEqual(len(err_lines), 2)
1045             self.assertEqual(
1046                 unstyle(str(report)),
1047                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
1048                 " reformat.",
1049             )
1050             self.assertEqual(report.return_code, 123)
1051             report.check = True
1052             self.assertEqual(
1053                 unstyle(str(report)),
1054                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
1055                 " would fail to reformat.",
1056             )
1057             report.check = False
1058             report.diff = True
1059             self.assertEqual(
1060                 unstyle(str(report)),
1061                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
1062                 " would fail to reformat.",
1063             )
1064
1065     def test_report_normal(self) -> None:
1066         report = black.Report()
1067         out_lines = []
1068         err_lines = []
1069
1070         def out(msg: str, **kwargs: Any) -> None:
1071             out_lines.append(msg)
1072
1073         def err(msg: str, **kwargs: Any) -> None:
1074             err_lines.append(msg)
1075
1076         with patch("black.out", out), patch("black.err", err):
1077             report.done(Path("f1"), black.Changed.NO)
1078             self.assertEqual(len(out_lines), 0)
1079             self.assertEqual(len(err_lines), 0)
1080             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
1081             self.assertEqual(report.return_code, 0)
1082             report.done(Path("f2"), black.Changed.YES)
1083             self.assertEqual(len(out_lines), 1)
1084             self.assertEqual(len(err_lines), 0)
1085             self.assertEqual(out_lines[-1], "reformatted f2")
1086             self.assertEqual(
1087                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
1088             )
1089             report.done(Path("f3"), black.Changed.CACHED)
1090             self.assertEqual(len(out_lines), 1)
1091             self.assertEqual(len(err_lines), 0)
1092             self.assertEqual(out_lines[-1], "reformatted f2")
1093             self.assertEqual(
1094                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
1095             )
1096             self.assertEqual(report.return_code, 0)
1097             report.check = True
1098             self.assertEqual(report.return_code, 1)
1099             report.check = False
1100             report.failed(Path("e1"), "boom")
1101             self.assertEqual(len(out_lines), 1)
1102             self.assertEqual(len(err_lines), 1)
1103             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
1104             self.assertEqual(
1105                 unstyle(str(report)),
1106                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
1107                 " reformat.",
1108             )
1109             self.assertEqual(report.return_code, 123)
1110             report.done(Path("f3"), black.Changed.YES)
1111             self.assertEqual(len(out_lines), 2)
1112             self.assertEqual(len(err_lines), 1)
1113             self.assertEqual(out_lines[-1], "reformatted f3")
1114             self.assertEqual(
1115                 unstyle(str(report)),
1116                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
1117                 " reformat.",
1118             )
1119             self.assertEqual(report.return_code, 123)
1120             report.failed(Path("e2"), "boom")
1121             self.assertEqual(len(out_lines), 2)
1122             self.assertEqual(len(err_lines), 2)
1123             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
1124             self.assertEqual(
1125                 unstyle(str(report)),
1126                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
1127                 " reformat.",
1128             )
1129             self.assertEqual(report.return_code, 123)
1130             report.path_ignored(Path("wat"), "no match")
1131             self.assertEqual(len(out_lines), 2)
1132             self.assertEqual(len(err_lines), 2)
1133             self.assertEqual(
1134                 unstyle(str(report)),
1135                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
1136                 " reformat.",
1137             )
1138             self.assertEqual(report.return_code, 123)
1139             report.done(Path("f4"), black.Changed.NO)
1140             self.assertEqual(len(out_lines), 2)
1141             self.assertEqual(len(err_lines), 2)
1142             self.assertEqual(
1143                 unstyle(str(report)),
1144                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
1145                 " reformat.",
1146             )
1147             self.assertEqual(report.return_code, 123)
1148             report.check = True
1149             self.assertEqual(
1150                 unstyle(str(report)),
1151                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
1152                 " would fail to reformat.",
1153             )
1154             report.check = False
1155             report.diff = True
1156             self.assertEqual(
1157                 unstyle(str(report)),
1158                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
1159                 " would fail to reformat.",
1160             )
1161
1162     def test_lib2to3_parse(self) -> None:
1163         with self.assertRaises(black.InvalidInput):
1164             black.lib2to3_parse("invalid syntax")
1165
1166         straddling = "x + y"
1167         black.lib2to3_parse(straddling)
1168         black.lib2to3_parse(straddling, {TargetVersion.PY27})
1169         black.lib2to3_parse(straddling, {TargetVersion.PY36})
1170         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
1171
1172         py2_only = "print x"
1173         black.lib2to3_parse(py2_only)
1174         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
1175         with self.assertRaises(black.InvalidInput):
1176             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
1177         with self.assertRaises(black.InvalidInput):
1178             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
1179
1180         py3_only = "exec(x, end=y)"
1181         black.lib2to3_parse(py3_only)
1182         with self.assertRaises(black.InvalidInput):
1183             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
1184         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
1185         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
1186
1187     def test_get_features_used(self) -> None:
1188         node = black.lib2to3_parse("def f(*, arg): ...\n")
1189         self.assertEqual(black.get_features_used(node), set())
1190         node = black.lib2to3_parse("def f(*, arg,): ...\n")
1191         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
1192         node = black.lib2to3_parse("f(*arg,)\n")
1193         self.assertEqual(
1194             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
1195         )
1196         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
1197         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
1198         node = black.lib2to3_parse("123_456\n")
1199         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
1200         node = black.lib2to3_parse("123456\n")
1201         self.assertEqual(black.get_features_used(node), set())
1202         source, expected = read_data("function")
1203         node = black.lib2to3_parse(source)
1204         expected_features = {
1205             Feature.TRAILING_COMMA_IN_CALL,
1206             Feature.TRAILING_COMMA_IN_DEF,
1207             Feature.F_STRINGS,
1208         }
1209         self.assertEqual(black.get_features_used(node), expected_features)
1210         node = black.lib2to3_parse(expected)
1211         self.assertEqual(black.get_features_used(node), expected_features)
1212         source, expected = read_data("expression")
1213         node = black.lib2to3_parse(source)
1214         self.assertEqual(black.get_features_used(node), set())
1215         node = black.lib2to3_parse(expected)
1216         self.assertEqual(black.get_features_used(node), set())
1217
1218     def test_get_future_imports(self) -> None:
1219         node = black.lib2to3_parse("\n")
1220         self.assertEqual(set(), black.get_future_imports(node))
1221         node = black.lib2to3_parse("from __future__ import black\n")
1222         self.assertEqual({"black"}, black.get_future_imports(node))
1223         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
1224         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
1225         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
1226         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
1227         node = black.lib2to3_parse(
1228             "from __future__ import multiple\nfrom __future__ import imports\n"
1229         )
1230         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
1231         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
1232         self.assertEqual({"black"}, black.get_future_imports(node))
1233         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
1234         self.assertEqual({"black"}, black.get_future_imports(node))
1235         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
1236         self.assertEqual(set(), black.get_future_imports(node))
1237         node = black.lib2to3_parse("from some.module import black\n")
1238         self.assertEqual(set(), black.get_future_imports(node))
1239         node = black.lib2to3_parse(
1240             "from __future__ import unicode_literals as _unicode_literals"
1241         )
1242         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
1243         node = black.lib2to3_parse(
1244             "from __future__ import unicode_literals as _lol, print"
1245         )
1246         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
1247
1248     def test_debug_visitor(self) -> None:
1249         source, _ = read_data("debug_visitor.py")
1250         expected, _ = read_data("debug_visitor.out")
1251         out_lines = []
1252         err_lines = []
1253
1254         def out(msg: str, **kwargs: Any) -> None:
1255             out_lines.append(msg)
1256
1257         def err(msg: str, **kwargs: Any) -> None:
1258             err_lines.append(msg)
1259
1260         with patch("black.out", out), patch("black.err", err):
1261             black.DebugVisitor.show(source)
1262         actual = "\n".join(out_lines) + "\n"
1263         log_name = ""
1264         if expected != actual:
1265             log_name = black.dump_to_file(*out_lines)
1266         self.assertEqual(
1267             expected,
1268             actual,
1269             f"AST print out is different. Actual version dumped to {log_name}",
1270         )
1271
1272     def test_format_file_contents(self) -> None:
1273         empty = ""
1274         mode = DEFAULT_MODE
1275         with self.assertRaises(black.NothingChanged):
1276             black.format_file_contents(empty, mode=mode, fast=False)
1277         just_nl = "\n"
1278         with self.assertRaises(black.NothingChanged):
1279             black.format_file_contents(just_nl, mode=mode, fast=False)
1280         same = "j = [1, 2, 3]\n"
1281         with self.assertRaises(black.NothingChanged):
1282             black.format_file_contents(same, mode=mode, fast=False)
1283         different = "j = [1,2,3]"
1284         expected = same
1285         actual = black.format_file_contents(different, mode=mode, fast=False)
1286         self.assertEqual(expected, actual)
1287         invalid = "return if you can"
1288         with self.assertRaises(black.InvalidInput) as e:
1289             black.format_file_contents(invalid, mode=mode, fast=False)
1290         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1291
1292     def test_endmarker(self) -> None:
1293         n = black.lib2to3_parse("\n")
1294         self.assertEqual(n.type, black.syms.file_input)
1295         self.assertEqual(len(n.children), 1)
1296         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1297
1298     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1299     def test_assertFormatEqual(self) -> None:
1300         out_lines = []
1301         err_lines = []
1302
1303         def out(msg: str, **kwargs: Any) -> None:
1304             out_lines.append(msg)
1305
1306         def err(msg: str, **kwargs: Any) -> None:
1307             err_lines.append(msg)
1308
1309         with patch("black.out", out), patch("black.err", err):
1310             with self.assertRaises(AssertionError):
1311                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1312
1313         out_str = "".join(out_lines)
1314         self.assertTrue("Expected tree:" in out_str)
1315         self.assertTrue("Actual tree:" in out_str)
1316         self.assertEqual("".join(err_lines), "")
1317
1318     def test_cache_broken_file(self) -> None:
1319         mode = DEFAULT_MODE
1320         with cache_dir() as workspace:
1321             cache_file = black.get_cache_file(mode)
1322             with cache_file.open("w") as fobj:
1323                 fobj.write("this is not a pickle")
1324             self.assertEqual(black.read_cache(mode), {})
1325             src = (workspace / "test.py").resolve()
1326             with src.open("w") as fobj:
1327                 fobj.write("print('hello')")
1328             self.invokeBlack([str(src)])
1329             cache = black.read_cache(mode)
1330             self.assertIn(src, cache)
1331
1332     def test_cache_single_file_already_cached(self) -> None:
1333         mode = DEFAULT_MODE
1334         with cache_dir() as workspace:
1335             src = (workspace / "test.py").resolve()
1336             with src.open("w") as fobj:
1337                 fobj.write("print('hello')")
1338             black.write_cache({}, [src], mode)
1339             self.invokeBlack([str(src)])
1340             with src.open("r") as fobj:
1341                 self.assertEqual(fobj.read(), "print('hello')")
1342
1343     @event_loop()
1344     def test_cache_multiple_files(self) -> None:
1345         mode = DEFAULT_MODE
1346         with cache_dir() as workspace, patch(
1347             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1348         ):
1349             one = (workspace / "one.py").resolve()
1350             with one.open("w") as fobj:
1351                 fobj.write("print('hello')")
1352             two = (workspace / "two.py").resolve()
1353             with two.open("w") as fobj:
1354                 fobj.write("print('hello')")
1355             black.write_cache({}, [one], mode)
1356             self.invokeBlack([str(workspace)])
1357             with one.open("r") as fobj:
1358                 self.assertEqual(fobj.read(), "print('hello')")
1359             with two.open("r") as fobj:
1360                 self.assertEqual(fobj.read(), 'print("hello")\n')
1361             cache = black.read_cache(mode)
1362             self.assertIn(one, cache)
1363             self.assertIn(two, cache)
1364
1365     def test_no_cache_when_writeback_diff(self) -> None:
1366         mode = DEFAULT_MODE
1367         with cache_dir() as workspace:
1368             src = (workspace / "test.py").resolve()
1369             with src.open("w") as fobj:
1370                 fobj.write("print('hello')")
1371             self.invokeBlack([str(src), "--diff"])
1372             cache_file = black.get_cache_file(mode)
1373             self.assertFalse(cache_file.exists())
1374
1375     def test_no_cache_when_stdin(self) -> None:
1376         mode = DEFAULT_MODE
1377         with cache_dir():
1378             result = CliRunner().invoke(
1379                 black.main, ["-"], input=BytesIO(b"print('hello')")
1380             )
1381             self.assertEqual(result.exit_code, 0)
1382             cache_file = black.get_cache_file(mode)
1383             self.assertFalse(cache_file.exists())
1384
1385     def test_read_cache_no_cachefile(self) -> None:
1386         mode = DEFAULT_MODE
1387         with cache_dir():
1388             self.assertEqual(black.read_cache(mode), {})
1389
1390     def test_write_cache_read_cache(self) -> None:
1391         mode = DEFAULT_MODE
1392         with cache_dir() as workspace:
1393             src = (workspace / "test.py").resolve()
1394             src.touch()
1395             black.write_cache({}, [src], mode)
1396             cache = black.read_cache(mode)
1397             self.assertIn(src, cache)
1398             self.assertEqual(cache[src], black.get_cache_info(src))
1399
1400     def test_filter_cached(self) -> None:
1401         with TemporaryDirectory() as workspace:
1402             path = Path(workspace)
1403             uncached = (path / "uncached").resolve()
1404             cached = (path / "cached").resolve()
1405             cached_but_changed = (path / "changed").resolve()
1406             uncached.touch()
1407             cached.touch()
1408             cached_but_changed.touch()
1409             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
1410             todo, done = black.filter_cached(
1411                 cache, {uncached, cached, cached_but_changed}
1412             )
1413             self.assertEqual(todo, {uncached, cached_but_changed})
1414             self.assertEqual(done, {cached})
1415
1416     def test_write_cache_creates_directory_if_needed(self) -> None:
1417         mode = DEFAULT_MODE
1418         with cache_dir(exists=False) as workspace:
1419             self.assertFalse(workspace.exists())
1420             black.write_cache({}, [], mode)
1421             self.assertTrue(workspace.exists())
1422
1423     @event_loop()
1424     def test_failed_formatting_does_not_get_cached(self) -> None:
1425         mode = DEFAULT_MODE
1426         with cache_dir() as workspace, patch(
1427             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1428         ):
1429             failing = (workspace / "failing.py").resolve()
1430             with failing.open("w") as fobj:
1431                 fobj.write("not actually python")
1432             clean = (workspace / "clean.py").resolve()
1433             with clean.open("w") as fobj:
1434                 fobj.write('print("hello")\n')
1435             self.invokeBlack([str(workspace)], exit_code=123)
1436             cache = black.read_cache(mode)
1437             self.assertNotIn(failing, cache)
1438             self.assertIn(clean, cache)
1439
1440     def test_write_cache_write_fail(self) -> None:
1441         mode = DEFAULT_MODE
1442         with cache_dir(), patch.object(Path, "open") as mock:
1443             mock.side_effect = OSError
1444             black.write_cache({}, [], mode)
1445
1446     @event_loop()
1447     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
1448     def test_works_in_mono_process_only_environment(self) -> None:
1449         with cache_dir() as workspace:
1450             for f in [
1451                 (workspace / "one.py").resolve(),
1452                 (workspace / "two.py").resolve(),
1453             ]:
1454                 f.write_text('print("hello")\n')
1455             self.invokeBlack([str(workspace)])
1456
1457     @event_loop()
1458     def test_check_diff_use_together(self) -> None:
1459         with cache_dir():
1460             # Files which will be reformatted.
1461             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1462             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1463             # Files which will not be reformatted.
1464             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1465             self.invokeBlack([str(src2), "--diff", "--check"])
1466             # Multi file command.
1467             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1468
1469     def test_no_files(self) -> None:
1470         with cache_dir():
1471             # Without an argument, black exits with error code 0.
1472             self.invokeBlack([])
1473
1474     def test_broken_symlink(self) -> None:
1475         with cache_dir() as workspace:
1476             symlink = workspace / "broken_link.py"
1477             try:
1478                 symlink.symlink_to("nonexistent.py")
1479             except OSError as e:
1480                 self.skipTest(f"Can't create symlinks: {e}")
1481             self.invokeBlack([str(workspace.resolve())])
1482
1483     def test_read_cache_line_lengths(self) -> None:
1484         mode = DEFAULT_MODE
1485         short_mode = replace(DEFAULT_MODE, line_length=1)
1486         with cache_dir() as workspace:
1487             path = (workspace / "file.py").resolve()
1488             path.touch()
1489             black.write_cache({}, [path], mode)
1490             one = black.read_cache(mode)
1491             self.assertIn(path, one)
1492             two = black.read_cache(short_mode)
1493             self.assertNotIn(path, two)
1494
1495     def test_tricky_unicode_symbols(self) -> None:
1496         source, expected = read_data("tricky_unicode_symbols")
1497         actual = fs(source)
1498         self.assertFormatEqual(expected, actual)
1499         black.assert_equivalent(source, actual)
1500         black.assert_stable(source, actual, DEFAULT_MODE)
1501
1502     def test_single_file_force_pyi(self) -> None:
1503         reg_mode = DEFAULT_MODE
1504         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1505         contents, expected = read_data("force_pyi")
1506         with cache_dir() as workspace:
1507             path = (workspace / "file.py").resolve()
1508             with open(path, "w") as fh:
1509                 fh.write(contents)
1510             self.invokeBlack([str(path), "--pyi"])
1511             with open(path, "r") as fh:
1512                 actual = fh.read()
1513             # verify cache with --pyi is separate
1514             pyi_cache = black.read_cache(pyi_mode)
1515             self.assertIn(path, pyi_cache)
1516             normal_cache = black.read_cache(reg_mode)
1517             self.assertNotIn(path, normal_cache)
1518         self.assertEqual(actual, expected)
1519
1520     @event_loop()
1521     def test_multi_file_force_pyi(self) -> None:
1522         reg_mode = DEFAULT_MODE
1523         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1524         contents, expected = read_data("force_pyi")
1525         with cache_dir() as workspace:
1526             paths = [
1527                 (workspace / "file1.py").resolve(),
1528                 (workspace / "file2.py").resolve(),
1529             ]
1530             for path in paths:
1531                 with open(path, "w") as fh:
1532                     fh.write(contents)
1533             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1534             for path in paths:
1535                 with open(path, "r") as fh:
1536                     actual = fh.read()
1537                 self.assertEqual(actual, expected)
1538             # verify cache with --pyi is separate
1539             pyi_cache = black.read_cache(pyi_mode)
1540             normal_cache = black.read_cache(reg_mode)
1541             for path in paths:
1542                 self.assertIn(path, pyi_cache)
1543                 self.assertNotIn(path, normal_cache)
1544
1545     def test_pipe_force_pyi(self) -> None:
1546         source, expected = read_data("force_pyi")
1547         result = CliRunner().invoke(
1548             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1549         )
1550         self.assertEqual(result.exit_code, 0)
1551         actual = result.output
1552         self.assertFormatEqual(actual, expected)
1553
1554     def test_single_file_force_py36(self) -> None:
1555         reg_mode = DEFAULT_MODE
1556         py36_mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
1557         source, expected = read_data("force_py36")
1558         with cache_dir() as workspace:
1559             path = (workspace / "file.py").resolve()
1560             with open(path, "w") as fh:
1561                 fh.write(source)
1562             self.invokeBlack([str(path), *PY36_ARGS])
1563             with open(path, "r") as fh:
1564                 actual = fh.read()
1565             # verify cache with --target-version is separate
1566             py36_cache = black.read_cache(py36_mode)
1567             self.assertIn(path, py36_cache)
1568             normal_cache = black.read_cache(reg_mode)
1569             self.assertNotIn(path, normal_cache)
1570         self.assertEqual(actual, expected)
1571
1572     @event_loop()
1573     def test_multi_file_force_py36(self) -> None:
1574         reg_mode = DEFAULT_MODE
1575         py36_mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
1576         source, expected = read_data("force_py36")
1577         with cache_dir() as workspace:
1578             paths = [
1579                 (workspace / "file1.py").resolve(),
1580                 (workspace / "file2.py").resolve(),
1581             ]
1582             for path in paths:
1583                 with open(path, "w") as fh:
1584                     fh.write(source)
1585             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1586             for path in paths:
1587                 with open(path, "r") as fh:
1588                     actual = fh.read()
1589                 self.assertEqual(actual, expected)
1590             # verify cache with --target-version is separate
1591             pyi_cache = black.read_cache(py36_mode)
1592             normal_cache = black.read_cache(reg_mode)
1593             for path in paths:
1594                 self.assertIn(path, pyi_cache)
1595                 self.assertNotIn(path, normal_cache)
1596
1597     def test_collections(self) -> None:
1598         source, expected = read_data("collections")
1599         actual = fs(source)
1600         self.assertFormatEqual(expected, actual)
1601         black.assert_equivalent(source, actual)
1602         black.assert_stable(source, actual, DEFAULT_MODE)
1603
1604     def test_pipe_force_py36(self) -> None:
1605         source, expected = read_data("force_py36")
1606         result = CliRunner().invoke(
1607             black.main,
1608             ["-", "-q", "--target-version=py36"],
1609             input=BytesIO(source.encode("utf8")),
1610         )
1611         self.assertEqual(result.exit_code, 0)
1612         actual = result.output
1613         self.assertFormatEqual(actual, expected)
1614
1615     def test_include_exclude(self) -> None:
1616         path = THIS_DIR / "data" / "include_exclude_tests"
1617         include = re.compile(r"\.pyi?$")
1618         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1619         report = black.Report()
1620         gitignore = PathSpec.from_lines("gitwildmatch", [])
1621         sources: List[Path] = []
1622         expected = [
1623             Path(path / "b/dont_exclude/a.py"),
1624             Path(path / "b/dont_exclude/a.pyi"),
1625         ]
1626         this_abs = THIS_DIR.resolve()
1627         sources.extend(
1628             black.gen_python_files(
1629                 path.iterdir(), this_abs, include, exclude, None, report, gitignore
1630             )
1631         )
1632         self.assertEqual(sorted(expected), sorted(sources))
1633
1634     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1635     def test_exclude_for_issue_1572(self) -> None:
1636         # Exclude shouldn't touch files that were explicitly given to Black through the
1637         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
1638         # https://github.com/psf/black/issues/1572
1639         path = THIS_DIR / "data" / "include_exclude_tests"
1640         include = ""
1641         exclude = r"/exclude/|a\.py"
1642         src = str(path / "b/exclude/a.py")
1643         report = black.Report()
1644         expected = [Path(path / "b/exclude/a.py")]
1645         sources = list(
1646             black.get_sources(
1647                 ctx=FakeContext(),
1648                 src=(src,),
1649                 quiet=True,
1650                 verbose=False,
1651                 include=include,
1652                 exclude=exclude,
1653                 force_exclude=None,
1654                 report=report,
1655             )
1656         )
1657         self.assertEqual(sorted(expected), sorted(sources))
1658
1659     def test_gitignore_exclude(self) -> None:
1660         path = THIS_DIR / "data" / "include_exclude_tests"
1661         include = re.compile(r"\.pyi?$")
1662         exclude = re.compile(r"")
1663         report = black.Report()
1664         gitignore = PathSpec.from_lines(
1665             "gitwildmatch", ["exclude/", ".definitely_exclude"]
1666         )
1667         sources: List[Path] = []
1668         expected = [
1669             Path(path / "b/dont_exclude/a.py"),
1670             Path(path / "b/dont_exclude/a.pyi"),
1671         ]
1672         this_abs = THIS_DIR.resolve()
1673         sources.extend(
1674             black.gen_python_files(
1675                 path.iterdir(), this_abs, include, exclude, None, report, gitignore
1676             )
1677         )
1678         self.assertEqual(sorted(expected), sorted(sources))
1679
1680     def test_empty_include(self) -> None:
1681         path = THIS_DIR / "data" / "include_exclude_tests"
1682         report = black.Report()
1683         gitignore = PathSpec.from_lines("gitwildmatch", [])
1684         empty = re.compile(r"")
1685         sources: List[Path] = []
1686         expected = [
1687             Path(path / "b/exclude/a.pie"),
1688             Path(path / "b/exclude/a.py"),
1689             Path(path / "b/exclude/a.pyi"),
1690             Path(path / "b/dont_exclude/a.pie"),
1691             Path(path / "b/dont_exclude/a.py"),
1692             Path(path / "b/dont_exclude/a.pyi"),
1693             Path(path / "b/.definitely_exclude/a.pie"),
1694             Path(path / "b/.definitely_exclude/a.py"),
1695             Path(path / "b/.definitely_exclude/a.pyi"),
1696         ]
1697         this_abs = THIS_DIR.resolve()
1698         sources.extend(
1699             black.gen_python_files(
1700                 path.iterdir(),
1701                 this_abs,
1702                 empty,
1703                 re.compile(black.DEFAULT_EXCLUDES),
1704                 None,
1705                 report,
1706                 gitignore,
1707             )
1708         )
1709         self.assertEqual(sorted(expected), sorted(sources))
1710
1711     def test_empty_exclude(self) -> None:
1712         path = THIS_DIR / "data" / "include_exclude_tests"
1713         report = black.Report()
1714         gitignore = PathSpec.from_lines("gitwildmatch", [])
1715         empty = re.compile(r"")
1716         sources: List[Path] = []
1717         expected = [
1718             Path(path / "b/dont_exclude/a.py"),
1719             Path(path / "b/dont_exclude/a.pyi"),
1720             Path(path / "b/exclude/a.py"),
1721             Path(path / "b/exclude/a.pyi"),
1722             Path(path / "b/.definitely_exclude/a.py"),
1723             Path(path / "b/.definitely_exclude/a.pyi"),
1724         ]
1725         this_abs = THIS_DIR.resolve()
1726         sources.extend(
1727             black.gen_python_files(
1728                 path.iterdir(),
1729                 this_abs,
1730                 re.compile(black.DEFAULT_INCLUDES),
1731                 empty,
1732                 None,
1733                 report,
1734                 gitignore,
1735             )
1736         )
1737         self.assertEqual(sorted(expected), sorted(sources))
1738
1739     def test_invalid_include_exclude(self) -> None:
1740         for option in ["--include", "--exclude"]:
1741             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1742
1743     def test_preserves_line_endings(self) -> None:
1744         with TemporaryDirectory() as workspace:
1745             test_file = Path(workspace) / "test.py"
1746             for nl in ["\n", "\r\n"]:
1747                 contents = nl.join(["def f(  ):", "    pass"])
1748                 test_file.write_bytes(contents.encode())
1749                 ff(test_file, write_back=black.WriteBack.YES)
1750                 updated_contents: bytes = test_file.read_bytes()
1751                 self.assertIn(nl.encode(), updated_contents)
1752                 if nl == "\n":
1753                     self.assertNotIn(b"\r\n", updated_contents)
1754
1755     def test_preserves_line_endings_via_stdin(self) -> None:
1756         for nl in ["\n", "\r\n"]:
1757             contents = nl.join(["def f(  ):", "    pass"])
1758             runner = BlackRunner()
1759             result = runner.invoke(
1760                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1761             )
1762             self.assertEqual(result.exit_code, 0)
1763             output = runner.stdout_bytes
1764             self.assertIn(nl.encode("utf8"), output)
1765             if nl == "\n":
1766                 self.assertNotIn(b"\r\n", output)
1767
1768     def test_assert_equivalent_different_asts(self) -> None:
1769         with self.assertRaises(AssertionError):
1770             black.assert_equivalent("{}", "None")
1771
1772     def test_symlink_out_of_root_directory(self) -> None:
1773         path = MagicMock()
1774         root = THIS_DIR.resolve()
1775         child = MagicMock()
1776         include = re.compile(black.DEFAULT_INCLUDES)
1777         exclude = re.compile(black.DEFAULT_EXCLUDES)
1778         report = black.Report()
1779         gitignore = PathSpec.from_lines("gitwildmatch", [])
1780         # `child` should behave like a symlink which resolved path is clearly
1781         # outside of the `root` directory.
1782         path.iterdir.return_value = [child]
1783         child.resolve.return_value = Path("/a/b/c")
1784         child.as_posix.return_value = "/a/b/c"
1785         child.is_symlink.return_value = True
1786         try:
1787             list(
1788                 black.gen_python_files(
1789                     path.iterdir(), root, include, exclude, None, report, gitignore
1790                 )
1791             )
1792         except ValueError as ve:
1793             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1794         path.iterdir.assert_called_once()
1795         child.resolve.assert_called_once()
1796         child.is_symlink.assert_called_once()
1797         # `child` should behave like a strange file which resolved path is clearly
1798         # outside of the `root` directory.
1799         child.is_symlink.return_value = False
1800         with self.assertRaises(ValueError):
1801             list(
1802                 black.gen_python_files(
1803                     path.iterdir(), root, include, exclude, None, report, gitignore
1804                 )
1805             )
1806         path.iterdir.assert_called()
1807         self.assertEqual(path.iterdir.call_count, 2)
1808         child.resolve.assert_called()
1809         self.assertEqual(child.resolve.call_count, 2)
1810         child.is_symlink.assert_called()
1811         self.assertEqual(child.is_symlink.call_count, 2)
1812
1813     def test_shhh_click(self) -> None:
1814         try:
1815             from click import _unicodefun  # type: ignore
1816         except ModuleNotFoundError:
1817             self.skipTest("Incompatible Click version")
1818         if not hasattr(_unicodefun, "_verify_python3_env"):
1819             self.skipTest("Incompatible Click version")
1820         # First, let's see if Click is crashing with a preferred ASCII charset.
1821         with patch("locale.getpreferredencoding") as gpe:
1822             gpe.return_value = "ASCII"
1823             with self.assertRaises(RuntimeError):
1824                 _unicodefun._verify_python3_env()
1825         # Now, let's silence Click...
1826         black.patch_click()
1827         # ...and confirm it's silent.
1828         with patch("locale.getpreferredencoding") as gpe:
1829             gpe.return_value = "ASCII"
1830             try:
1831                 _unicodefun._verify_python3_env()
1832             except RuntimeError as re:
1833                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1834
1835     def test_root_logger_not_used_directly(self) -> None:
1836         def fail(*args: Any, **kwargs: Any) -> None:
1837             self.fail("Record created with root logger")
1838
1839         with patch.multiple(
1840             logging.root,
1841             debug=fail,
1842             info=fail,
1843             warning=fail,
1844             error=fail,
1845             critical=fail,
1846             log=fail,
1847         ):
1848             ff(THIS_FILE)
1849
1850     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1851     def test_blackd_main(self) -> None:
1852         with patch("blackd.web.run_app"):
1853             result = CliRunner().invoke(blackd.main, [])
1854             if result.exception is not None:
1855                 raise result.exception
1856             self.assertEqual(result.exit_code, 0)
1857
1858     def test_invalid_config_return_code(self) -> None:
1859         tmp_file = Path(black.dump_to_file())
1860         try:
1861             tmp_config = Path(black.dump_to_file())
1862             tmp_config.unlink()
1863             args = ["--config", str(tmp_config), str(tmp_file)]
1864             self.invokeBlack(args, exit_code=2, ignore_config=False)
1865         finally:
1866             tmp_file.unlink()
1867
1868     def test_parse_pyproject_toml(self) -> None:
1869         test_toml_file = THIS_DIR / "test.toml"
1870         config = black.parse_pyproject_toml(str(test_toml_file))
1871         self.assertEqual(config["verbose"], 1)
1872         self.assertEqual(config["check"], "no")
1873         self.assertEqual(config["diff"], "y")
1874         self.assertEqual(config["color"], True)
1875         self.assertEqual(config["line_length"], 79)
1876         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1877         self.assertEqual(config["exclude"], r"\.pyi?$")
1878         self.assertEqual(config["include"], r"\.py?$")
1879
1880     def test_read_pyproject_toml(self) -> None:
1881         test_toml_file = THIS_DIR / "test.toml"
1882         fake_ctx = FakeContext()
1883         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1884         config = fake_ctx.default_map
1885         self.assertEqual(config["verbose"], "1")
1886         self.assertEqual(config["check"], "no")
1887         self.assertEqual(config["diff"], "y")
1888         self.assertEqual(config["color"], "True")
1889         self.assertEqual(config["line_length"], "79")
1890         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1891         self.assertEqual(config["exclude"], r"\.pyi?$")
1892         self.assertEqual(config["include"], r"\.py?$")
1893
1894     def test_find_project_root(self) -> None:
1895         with TemporaryDirectory() as workspace:
1896             root = Path(workspace)
1897             test_dir = root / "test"
1898             test_dir.mkdir()
1899
1900             src_dir = root / "src"
1901             src_dir.mkdir()
1902
1903             root_pyproject = root / "pyproject.toml"
1904             root_pyproject.touch()
1905             src_pyproject = src_dir / "pyproject.toml"
1906             src_pyproject.touch()
1907             src_python = src_dir / "foo.py"
1908             src_python.touch()
1909
1910             self.assertEqual(
1911                 black.find_project_root((src_dir, test_dir)), root.resolve()
1912             )
1913             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
1914             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
1915
1916
if has_blackd_deps:
    # The base class, the `web.Application` annotation and `unittest_run_loop`
    # are all evaluated while the class is being defined. They are only bound
    # when the optional blackd/aiohttp imports at the top of this file succeed,
    # so the class is defined only in that case; otherwise importing this module
    # would fail with NameError. Per-test `skipUnless(has_blackd_deps, ...)`
    # decorators are therefore not needed inside this guard.

    class BlackDTestCase(AioHTTPTestCase):
        """HTTP-level tests that post code to the application from `blackd.make_app`."""

        async def get_application(self) -> web.Application:
            """Return the blackd application under test."""
            return blackd.make_app()

        # TODO: remove these decorators once the below is released
        # https://github.com/aio-libs/aiohttp/pull/3727
        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_request_needs_formatting(self) -> None:
            """Unformatted code comes back formatted with status 200."""
            response = await self.client.post("/", data=b"print('hello world')")
            self.assertEqual(response.status, 200)
            self.assertEqual(response.charset, "utf8")
            self.assertEqual(await response.read(), b'print("hello world")\n')

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_request_no_change(self) -> None:
            """Already formatted code yields status 204 and an empty body."""
            response = await self.client.post("/", data=b'print("hello world")\n')
            self.assertEqual(response.status, 204)
            self.assertEqual(await response.read(), b"")

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_request_syntax_error(self) -> None:
            """Unparsable code yields status 400 and a 'Cannot parse' message."""
            response = await self.client.post("/", data=b"what even ( is")
            self.assertEqual(response.status, 400)
            content = await response.text()
            self.assertTrue(
                content.startswith("Cannot parse"),
                msg=f"Expected error to start with 'Cannot parse', got {repr(content)}",
            )

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_unsupported_version(self) -> None:
            """Protocol version ``2`` is answered with status 501."""
            response = await self.client.post(
                "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "2"}
            )
            self.assertEqual(response.status, 501)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_supported_version(self) -> None:
            """Protocol version ``1`` is accepted with status 200."""
            response = await self.client.post(
                "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "1"}
            )
            self.assertEqual(response.status, 200)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_invalid_python_variant(self) -> None:
            """Malformed or unknown Python variant headers yield status 400."""

            async def check(header_value: str, expected_status: int = 400) -> None:
                response = await self.client.post(
                    "/",
                    data=b"what",
                    headers={blackd.PYTHON_VARIANT_HEADER: header_value},
                )
                self.assertEqual(response.status, expected_status)

            await check("lol")
            await check("ruby3.5")
            await check("pyi3.6")
            await check("py1.5")
            await check("2.8")
            await check("py2.8")
            await check("3.0")
            await check("pypy3.0")
            await check("jython3.4")

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_pyi(self) -> None:
            """The ``pyi`` variant header formats the ``stub.pyi`` fixture as expected."""
            source, expected = read_data("stub.pyi")
            response = await self.client.post(
                "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
            )
            self.assertEqual(response.status, 200)
            self.assertEqual(await response.text(), expected)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_diff(self) -> None:
            """With the diff header set, the response body is the expected diff."""
            # Timestamps in the diff header lines vary per run; they are replaced
            # with DETERMINISTIC_HEADER before comparing against the fixture.
            diff_header = re.compile(
                r"(In|Out)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
            )

            source, _ = read_data("blackd_diff.py")
            expected, _ = read_data("blackd_diff.diff")

            response = await self.client.post(
                "/", data=source, headers={blackd.DIFF_HEADER: "true"}
            )
            self.assertEqual(response.status, 200)

            actual = await response.text()
            actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
            self.assertEqual(actual, expected)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_python_variant(self) -> None:
            """The variant header decides whether the sample gets reformatted.

            The 3.6+ variants below expect status 200; the older or mixed ones
            expect status 204.
            """
            code = (
                "def f(\n"
                "    and_has_a_bunch_of,\n"
                "    very_long_arguments_too,\n"
                "    and_lots_of_them_as_well_lol,\n"
                "    **and_very_long_keyword_arguments\n"
                "):\n"
                "    pass\n"
            )

            async def check(header_value: str, expected_status: int) -> None:
                response = await self.client.post(
                    "/", data=code, headers={blackd.PYTHON_VARIANT_HEADER: header_value}
                )
                self.assertEqual(
                    response.status, expected_status, msg=await response.text()
                )

            await check("3.6", 200)
            await check("py3.6", 200)
            await check("3.6,3.7", 200)
            await check("3.6,py3.7", 200)
            await check("py36,py37", 200)
            await check("36", 200)
            await check("3.6.4", 200)

            await check("2", 204)
            await check("2.7", 204)
            await check("py2.7", 204)
            await check("3.4", 204)
            await check("py3.4", 204)
            await check("py34,py36", 204)
            await check("34", 204)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_line_length(self) -> None:
            """A numeric line-length header is accepted with status 200."""
            response = await self.client.post(
                "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "7"}
            )
            self.assertEqual(response.status, 200)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_invalid_line_length(self) -> None:
            """A non-numeric line-length header yields status 400."""
            response = await self.client.post(
                "/",
                data=b'print("hello")\n',
                headers={blackd.LINE_LENGTH_HEADER: "NaN"},
            )
            self.assertEqual(response.status, 400)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_response_black_version_header(self) -> None:
            """Every response carries the Black version header."""
            response = await self.client.post("/")
            self.assertIsNotNone(response.headers.get(blackd.BLACK_VERSION_HEADER))
2083
2084
# Keep black's own source in memory, one list entry per line; `tracefunc` below
# indexes into it to print the line a traced call starts on.
with open(black.__file__, "r", encoding="utf-8") as _bf:
    black_source_lines = list(_bf)
2087
2088
def tracefunc(frame: types.FrameType, event: str, arg: Any) -> Callable:
    """Show function calls `from black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    Only "call" events for frames whose code lives in ``black/__init__.py`` are
    printed, as ``<lineno>:<source line>`` indented by the current stack depth.
    Every other event or file is ignored. The function always returns itself so
    tracing stays installed.
    """
    if event != "call":
        return tracefunc

    filename = frame.f_code.co_filename
    if "black/__init__.py" not in filename:
        # Frames from other files are skipped before any lookup: their line
        # numbers do not refer to `black_source_lines` and could run past its
        # end, and walking the stack for them would be wasted work.
        return tracefunc

    # Indentation is derived from the stack depth; 19 is the fixed offset
    # subtracted from that depth, and each remaining level indents by two spaces.
    stack = len(inspect.stack()) - 19
    stack *= 2
    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1
    funcname = black_source_lines[func_sig_lineno].strip()
    # Skip decorator lines so the `def` line itself is shown.
    while funcname.startswith("@"):
        func_sig_lineno += 1
        funcname = black_source_lines[func_sig_lineno].strip()
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc
2109
2110
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main(module="test_black")