git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

629afc5b0ad2c86fc340f4921bd5f24b3ba112da
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 import logging
4 from concurrent.futures import ThreadPoolExecutor
5 from contextlib import contextmanager
6 from dataclasses import replace
7 from functools import partial
8 import inspect
9 from io import BytesIO, TextIOWrapper
10 import os
11 from pathlib import Path
12 import regex as re
13 import sys
14 from tempfile import TemporaryDirectory
15 import types
16 from typing import (
17     Any,
18     BinaryIO,
19     Callable,
20     Dict,
21     Generator,
22     List,
23     Tuple,
24     Iterator,
25     TypeVar,
26 )
27 import unittest
28 from unittest.mock import patch, MagicMock
29
30 import click
31 from click import unstyle
32 from click.testing import CliRunner
33
34 import black
35 from black import Feature, TargetVersion
36
37 try:
38     import blackd
39     from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop
40     from aiohttp import web
41 except ImportError:
42     has_blackd_deps = False
43 else:
44     has_blackd_deps = True
45
46 from pathspec import PathSpec
47
48 # Import other test classes
49 from .test_primer import PrimerCLITests  # noqa: F401
50
51
DEFAULT_MODE = black.FileMode(experimental_string_processing=True)
fs = partial(black.format_str, mode=DEFAULT_MODE)
ff = partial(black.format_file_in_place, mode=DEFAULT_MODE, fast=True)
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
PROJECT_ROOT = THIS_DIR.parent
DETERMINISTIC_HEADER = "[Deterministic header]"
# Marker that read_data() strips from fixture lines. It is deliberately built
# from two pieces so the full marker never appears verbatim in this file, which
# test_self also feeds through read_data().
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE " + "(this comment will be removed)"
PY36_ARGS = [f"--target-version={v.name.lower()}" for v in black.PY36_VERSIONS]
T = TypeVar("T")
R = TypeVar("R")
65
66
def dump_to_stderr(*output: str) -> str:
    """Stand-in for black.dump_to_file that returns the text instead of a path.

    Every piece of *output* goes on its own line, framed by a leading and a
    trailing newline (presumably so failure messages embed the content inline).
    """
    body = "\n".join(output)
    return f"\n{body}\n"
69
70
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """Load a test fixture and split it into ``(input, expected_output)``.

    ``name`` gets a ``.py`` suffix unless it already ends in .py/.pyi/.out/.diff.
    It is resolved under ``THIS_DIR / "data"`` when *data* is true, otherwise
    under PROJECT_ROOT. Lines before the output marker form the input and lines
    after it the expected output; without a marker the input doubles as the
    expected output. EMPTY_LINE is removed from every line, and both halves are
    whitespace-stripped and terminated by a single newline.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    base_dir = THIS_DIR / "data" if data else PROJECT_ROOT
    with open(base_dir / name, "r", encoding="utf8") as fixture:
        raw_lines = fixture.readlines()
    before: List[str] = []
    after: List[str] = []
    seen_marker = False
    for raw in raw_lines:
        cleaned = raw.replace(EMPTY_LINE, "")
        if cleaned.rstrip() == "# output":
            seen_marker = True
        elif seen_marker:
            after.append(cleaned)
        else:
            before.append(cleaned)
    if before and not after:
        # No output section: the fixture is expected to be formatted already.
        after = list(before)
    source = "".join(before).strip() + "\n"
    expected = "".join(after).strip() + "\n"
    return source, expected
92
93
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point ``black.CACHE_DIR`` at a throwaway directory while the block runs.

    With *exists* false, the yielded path is a not-yet-created ``new``
    subdirectory of the temporary directory.
    """
    with TemporaryDirectory() as workspace:
        temp_root = Path(workspace)
        target = temp_root if exists else temp_root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
102
103
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a brand-new asyncio event loop as the current one.

    The loop is closed on exit; it is not unset, so the closed loop remains
    the current loop afterwards.
    """
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        fresh_loop.close()
114
115
@contextmanager
def skip_if_exception(e: str) -> Iterator[None]:
    """Skip the running test if the wrapped code raises an exception named *e*.

    *e* is compared with the exception's bare class name. Any other exception
    propagates unchanged.

    Raises:
        unittest.SkipTest: when an exception whose class name equals *e* is
            raised inside the ``with`` block.
    """
    try:
        yield
    except Exception as exc:
        if exc.__class__.__name__ == e:
            # unittest.skip() only builds a decorator; raising SkipTest is what
            # actually marks the running test as skipped.
            raise unittest.SkipTest(
                f"Encountered expected exception {exc}, skipping"
            ) from exc
        raise
125
126
class FakeContext(click.Context):
    """Stand-in click Context that bypasses click's own initialisation.

    Only ``default_map`` is set, to an empty dict.
    """

    def __init__(self) -> None:
        self.default_map: Dict[str, Any] = dict()
132
133
class FakeParameter(click.Parameter):
    """Placeholder click Parameter for calling functions that require one."""

    def __init__(self) -> None:
        """Deliberately bypass click.Parameter.__init__ and set no attributes."""
139
140
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    This is a hack that can be removed once we depend on Click 7.x"""

    def __init__(self) -> None:
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Run Click's isolation while capturing stderr separately.

        After the block exits, ``stdout_bytes`` and ``stderr_bytes`` hold the
        bytes written to the two streams while isolated.
        """
        with super().isolation(*args, **kwargs) as output:
            # Save before swapping so `finally` can always restore it, and so a
            # failure while building the wrapper is not masked by that clause.
            hold_stderr = sys.stderr
            sys.stderr = TextIOWrapper(self.stderrbuf, encoding=self.charset)
            try:
                yield output
            finally:
                # TextIOWrapper buffers text internally; flush so writes that
                # were never explicitly flushed still reach stderrbuf.
                sys.stderr.flush()
                self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                self.stderr_bytes = sys.stderr.buffer.getvalue()  # type: ignore
                sys.stderr = hold_stderr
164
165
166 class BlackTestCase(unittest.TestCase):
167     maxDiff = None
168     _diffThreshold = 2 ** 20
169
170     def assertFormatEqual(self, expected: str, actual: str) -> None:
171         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
172             bdv: black.DebugVisitor[Any]
173             black.out("Expected tree:", fg="green")
174             try:
175                 exp_node = black.lib2to3_parse(expected)
176                 bdv = black.DebugVisitor()
177                 list(bdv.visit(exp_node))
178             except Exception as ve:
179                 black.err(str(ve))
180             black.out("Actual tree:", fg="red")
181             try:
182                 exp_node = black.lib2to3_parse(actual)
183                 bdv = black.DebugVisitor()
184                 list(bdv.visit(exp_node))
185             except Exception as ve:
186                 black.err(str(ve))
187         self.assertMultiLineEqual(expected, actual)
188
189     def invokeBlack(
190         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
191     ) -> None:
192         runner = BlackRunner()
193         if ignore_config:
194             args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
195         result = runner.invoke(black.main, args)
196         self.assertEqual(
197             result.exit_code,
198             exit_code,
199             msg=(
200                 f"Failed with args: {args}\n"
201                 f"stdout: {runner.stdout_bytes.decode()!r}\n"
202                 f"stderr: {runner.stderr_bytes.decode()!r}\n"
203                 f"exception: {result.exception}"
204             ),
205         )
206
207     @patch("black.dump_to_file", dump_to_stderr)
208     def checkSourceFile(self, name: str, mode: black.FileMode = DEFAULT_MODE) -> None:
209         path = THIS_DIR.parent / name
210         source, expected = read_data(str(path), data=False)
211         actual = fs(source, mode=mode)
212         self.assertFormatEqual(expected, actual)
213         black.assert_equivalent(source, actual)
214         black.assert_stable(source, actual, mode)
215         self.assertFalse(ff(path))
216
217     @patch("black.dump_to_file", dump_to_stderr)
218     def test_empty(self) -> None:
219         source = expected = ""
220         actual = fs(source)
221         self.assertFormatEqual(expected, actual)
222         black.assert_equivalent(source, actual)
223         black.assert_stable(source, actual, DEFAULT_MODE)
224
225     def test_empty_ff(self) -> None:
226         expected = ""
227         tmp_file = Path(black.dump_to_file())
228         try:
229             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
230             with open(tmp_file, encoding="utf8") as f:
231                 actual = f.read()
232         finally:
233             os.unlink(tmp_file)
234         self.assertFormatEqual(expected, actual)
235
236     def test_self(self) -> None:
237         self.checkSourceFile("tests/test_black.py")
238
239     def test_black(self) -> None:
240         self.checkSourceFile("src/black/__init__.py")
241
242     def test_pygram(self) -> None:
243         self.checkSourceFile("src/blib2to3/pygram.py")
244
245     def test_pytree(self) -> None:
246         self.checkSourceFile("src/blib2to3/pytree.py")
247
248     def test_conv(self) -> None:
249         self.checkSourceFile("src/blib2to3/pgen2/conv.py")
250
251     def test_driver(self) -> None:
252         self.checkSourceFile("src/blib2to3/pgen2/driver.py")
253
254     def test_grammar(self) -> None:
255         self.checkSourceFile("src/blib2to3/pgen2/grammar.py")
256
257     def test_literals(self) -> None:
258         self.checkSourceFile("src/blib2to3/pgen2/literals.py")
259
260     def test_parse(self) -> None:
261         self.checkSourceFile("src/blib2to3/pgen2/parse.py")
262
263     def test_pgen(self) -> None:
264         self.checkSourceFile("src/blib2to3/pgen2/pgen.py")
265
266     def test_tokenize(self) -> None:
267         self.checkSourceFile("src/blib2to3/pgen2/tokenize.py")
268
269     def test_token(self) -> None:
270         self.checkSourceFile("src/blib2to3/pgen2/token.py")
271
272     def test_setup(self) -> None:
273         self.checkSourceFile("setup.py")
274
275     def test_piping(self) -> None:
276         source, expected = read_data("src/black/__init__", data=False)
277         result = BlackRunner().invoke(
278             black.main,
279             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
280             input=BytesIO(source.encode("utf8")),
281         )
282         self.assertEqual(result.exit_code, 0)
283         self.assertFormatEqual(expected, result.output)
284         black.assert_equivalent(source, result.output)
285         black.assert_stable(source, result.output, DEFAULT_MODE)
286
287     def test_piping_diff(self) -> None:
288         diff_header = re.compile(
289             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
290             r"\+\d\d\d\d"
291         )
292         source, _ = read_data("expression.py")
293         expected, _ = read_data("expression.diff")
294         config = THIS_DIR / "data" / "empty_pyproject.toml"
295         args = [
296             "-",
297             "--fast",
298             f"--line-length={black.DEFAULT_LINE_LENGTH}",
299             "--diff",
300             f"--config={config}",
301         ]
302         result = BlackRunner().invoke(
303             black.main, args, input=BytesIO(source.encode("utf8"))
304         )
305         self.assertEqual(result.exit_code, 0)
306         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
307         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
308         self.assertEqual(expected, actual)
309
310     def test_piping_diff_with_color(self) -> None:
311         source, _ = read_data("expression.py")
312         config = THIS_DIR / "data" / "empty_pyproject.toml"
313         args = [
314             "-",
315             "--fast",
316             f"--line-length={black.DEFAULT_LINE_LENGTH}",
317             "--diff",
318             "--color",
319             f"--config={config}",
320         ]
321         result = BlackRunner().invoke(
322             black.main, args, input=BytesIO(source.encode("utf8"))
323         )
324         actual = result.output
325         # Again, the contents are checked in a different test, so only look for colors.
326         self.assertIn("\033[1;37m", actual)
327         self.assertIn("\033[36m", actual)
328         self.assertIn("\033[32m", actual)
329         self.assertIn("\033[31m", actual)
330         self.assertIn("\033[0m", actual)
331
332     @patch("black.dump_to_file", dump_to_stderr)
333     def test_function(self) -> None:
334         source, expected = read_data("function")
335         actual = fs(source)
336         self.assertFormatEqual(expected, actual)
337         black.assert_equivalent(source, actual)
338         black.assert_stable(source, actual, DEFAULT_MODE)
339
340     @patch("black.dump_to_file", dump_to_stderr)
341     def test_function2(self) -> None:
342         source, expected = read_data("function2")
343         actual = fs(source)
344         self.assertFormatEqual(expected, actual)
345         black.assert_equivalent(source, actual)
346         black.assert_stable(source, actual, DEFAULT_MODE)
347
348     @patch("black.dump_to_file", dump_to_stderr)
349     def _test_wip(self) -> None:
350         source, expected = read_data("wip")
351         sys.settrace(tracefunc)
352         mode = replace(
353             DEFAULT_MODE,
354             experimental_string_processing=False,
355             target_versions={black.TargetVersion.PY38},
356         )
357         actual = fs(source, mode=mode)
358         sys.settrace(None)
359         self.assertFormatEqual(expected, actual)
360         black.assert_equivalent(source, actual)
361         black.assert_stable(source, actual, black.FileMode())
362
363     @patch("black.dump_to_file", dump_to_stderr)
364     def test_function_trailing_comma(self) -> None:
365         source, expected = read_data("function_trailing_comma")
366         actual = fs(source)
367         self.assertFormatEqual(expected, actual)
368         black.assert_equivalent(source, actual)
369         black.assert_stable(source, actual, DEFAULT_MODE)
370
371     @unittest.expectedFailure
372     @patch("black.dump_to_file", dump_to_stderr)
373     def test_trailing_comma_optional_parens_stability1(self) -> None:
374         source, _expected = read_data("trailing_comma_optional_parens1")
375         actual = fs(source)
376         black.assert_stable(source, actual, DEFAULT_MODE)
377
378     @unittest.expectedFailure
379     @patch("black.dump_to_file", dump_to_stderr)
380     def test_trailing_comma_optional_parens_stability2(self) -> None:
381         source, _expected = read_data("trailing_comma_optional_parens2")
382         actual = fs(source)
383         black.assert_stable(source, actual, DEFAULT_MODE)
384
385     @unittest.expectedFailure
386     @patch("black.dump_to_file", dump_to_stderr)
387     def test_trailing_comma_optional_parens_stability3(self) -> None:
388         source, _expected = read_data("trailing_comma_optional_parens3")
389         actual = fs(source)
390         black.assert_stable(source, actual, DEFAULT_MODE)
391
392     @patch("black.dump_to_file", dump_to_stderr)
393     def test_expression(self) -> None:
394         source, expected = read_data("expression")
395         actual = fs(source)
396         self.assertFormatEqual(expected, actual)
397         black.assert_equivalent(source, actual)
398         black.assert_stable(source, actual, DEFAULT_MODE)
399
400     @patch("black.dump_to_file", dump_to_stderr)
401     def test_pep_572(self) -> None:
402         source, expected = read_data("pep_572")
403         actual = fs(source)
404         self.assertFormatEqual(expected, actual)
405         black.assert_stable(source, actual, DEFAULT_MODE)
406         if sys.version_info >= (3, 8):
407             black.assert_equivalent(source, actual)
408
409     def test_pep_572_version_detection(self) -> None:
410         source, _ = read_data("pep_572")
411         root = black.lib2to3_parse(source)
412         features = black.get_features_used(root)
413         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
414         versions = black.detect_target_versions(root)
415         self.assertIn(black.TargetVersion.PY38, versions)
416
417     def test_expression_ff(self) -> None:
418         source, expected = read_data("expression")
419         tmp_file = Path(black.dump_to_file(source))
420         try:
421             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
422             with open(tmp_file, encoding="utf8") as f:
423                 actual = f.read()
424         finally:
425             os.unlink(tmp_file)
426         self.assertFormatEqual(expected, actual)
427         with patch("black.dump_to_file", dump_to_stderr):
428             black.assert_equivalent(source, actual)
429             black.assert_stable(source, actual, DEFAULT_MODE)
430
431     def test_expression_diff(self) -> None:
432         source, _ = read_data("expression.py")
433         expected, _ = read_data("expression.diff")
434         tmp_file = Path(black.dump_to_file(source))
435         diff_header = re.compile(
436             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
437             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
438         )
439         try:
440             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
441             self.assertEqual(result.exit_code, 0)
442         finally:
443             os.unlink(tmp_file)
444         actual = result.output
445         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
446         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
447         if expected != actual:
448             dump = black.dump_to_file(actual)
449             msg = (
450                 "Expected diff isn't equal to the actual. If you made changes to"
451                 " expression.py and this is an anticipated difference, overwrite"
452                 f" tests/data/expression.diff with {dump}"
453             )
454             self.assertEqual(expected, actual, msg)
455
456     def test_expression_diff_with_color(self) -> None:
457         source, _ = read_data("expression.py")
458         expected, _ = read_data("expression.diff")
459         tmp_file = Path(black.dump_to_file(source))
460         try:
461             result = BlackRunner().invoke(
462                 black.main, ["--diff", "--color", str(tmp_file)]
463             )
464         finally:
465             os.unlink(tmp_file)
466         actual = result.output
467         # We check the contents of the diff in `test_expression_diff`. All
468         # we need to check here is that color codes exist in the result.
469         self.assertIn("\033[1;37m", actual)
470         self.assertIn("\033[36m", actual)
471         self.assertIn("\033[32m", actual)
472         self.assertIn("\033[31m", actual)
473         self.assertIn("\033[0m", actual)
474
475     @patch("black.dump_to_file", dump_to_stderr)
476     def test_fstring(self) -> None:
477         source, expected = read_data("fstring")
478         actual = fs(source)
479         self.assertFormatEqual(expected, actual)
480         black.assert_equivalent(source, actual)
481         black.assert_stable(source, actual, DEFAULT_MODE)
482
483     @patch("black.dump_to_file", dump_to_stderr)
484     def test_pep_570(self) -> None:
485         source, expected = read_data("pep_570")
486         actual = fs(source)
487         self.assertFormatEqual(expected, actual)
488         black.assert_stable(source, actual, DEFAULT_MODE)
489         if sys.version_info >= (3, 8):
490             black.assert_equivalent(source, actual)
491
492     def test_detect_pos_only_arguments(self) -> None:
493         source, _ = read_data("pep_570")
494         root = black.lib2to3_parse(source)
495         features = black.get_features_used(root)
496         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
497         versions = black.detect_target_versions(root)
498         self.assertIn(black.TargetVersion.PY38, versions)
499
500     @patch("black.dump_to_file", dump_to_stderr)
501     def test_string_quotes(self) -> None:
502         source, expected = read_data("string_quotes")
503         actual = fs(source)
504         self.assertFormatEqual(expected, actual)
505         black.assert_equivalent(source, actual)
506         black.assert_stable(source, actual, DEFAULT_MODE)
507         mode = replace(DEFAULT_MODE, string_normalization=False)
508         not_normalized = fs(source, mode=mode)
509         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
510         black.assert_equivalent(source, not_normalized)
511         black.assert_stable(source, not_normalized, mode=mode)
512
513     @patch("black.dump_to_file", dump_to_stderr)
514     def test_docstring(self) -> None:
515         source, expected = read_data("docstring")
516         actual = fs(source)
517         self.assertFormatEqual(expected, actual)
518         black.assert_equivalent(source, actual)
519         black.assert_stable(source, actual, DEFAULT_MODE)
520
521     @patch("black.dump_to_file", dump_to_stderr)
522     def test_docstring_no_string_normalization(self) -> None:
523         """Like test_docstring but with string normalization off."""
524         source, expected = read_data("docstring_no_string_normalization")
525         mode = replace(DEFAULT_MODE, string_normalization=False)
526         actual = fs(source, mode=mode)
527         self.assertFormatEqual(expected, actual)
528         black.assert_equivalent(source, actual)
529         black.assert_stable(source, actual, mode)
530
531     def test_long_strings(self) -> None:
532         """Tests for splitting long strings."""
533         source, expected = read_data("long_strings")
534         actual = fs(source)
535         self.assertFormatEqual(expected, actual)
536         black.assert_equivalent(source, actual)
537         black.assert_stable(source, actual, DEFAULT_MODE)
538
539     def test_long_strings_flag_disabled(self) -> None:
540         """Tests for turning off the string processing logic."""
541         source, expected = read_data("long_strings_flag_disabled")
542         mode = replace(DEFAULT_MODE, experimental_string_processing=False)
543         actual = fs(source, mode=mode)
544         self.assertFormatEqual(expected, actual)
545         black.assert_stable(expected, actual, mode)
546
547     @patch("black.dump_to_file", dump_to_stderr)
548     def test_long_strings__edge_case(self) -> None:
549         """Edge-case tests for splitting long strings."""
550         source, expected = read_data("long_strings__edge_case")
551         actual = fs(source)
552         self.assertFormatEqual(expected, actual)
553         black.assert_equivalent(source, actual)
554         black.assert_stable(source, actual, DEFAULT_MODE)
555
556     @patch("black.dump_to_file", dump_to_stderr)
557     def test_long_strings__regression(self) -> None:
558         """Regression tests for splitting long strings."""
559         source, expected = read_data("long_strings__regression")
560         actual = fs(source)
561         self.assertFormatEqual(expected, actual)
562         black.assert_equivalent(source, actual)
563         black.assert_stable(source, actual, DEFAULT_MODE)
564
565     @patch("black.dump_to_file", dump_to_stderr)
566     def test_slices(self) -> None:
567         source, expected = read_data("slices")
568         actual = fs(source)
569         self.assertFormatEqual(expected, actual)
570         black.assert_equivalent(source, actual)
571         black.assert_stable(source, actual, DEFAULT_MODE)
572
573     @patch("black.dump_to_file", dump_to_stderr)
574     def test_percent_precedence(self) -> None:
575         source, expected = read_data("percent_precedence")
576         actual = fs(source)
577         self.assertFormatEqual(expected, actual)
578         black.assert_equivalent(source, actual)
579         black.assert_stable(source, actual, DEFAULT_MODE)
580
581     @patch("black.dump_to_file", dump_to_stderr)
582     def test_comments(self) -> None:
583         source, expected = read_data("comments")
584         actual = fs(source)
585         self.assertFormatEqual(expected, actual)
586         black.assert_equivalent(source, actual)
587         black.assert_stable(source, actual, DEFAULT_MODE)
588
589     @patch("black.dump_to_file", dump_to_stderr)
590     def test_comments2(self) -> None:
591         source, expected = read_data("comments2")
592         actual = fs(source)
593         self.assertFormatEqual(expected, actual)
594         black.assert_equivalent(source, actual)
595         black.assert_stable(source, actual, DEFAULT_MODE)
596
597     @patch("black.dump_to_file", dump_to_stderr)
598     def test_comments3(self) -> None:
599         source, expected = read_data("comments3")
600         actual = fs(source)
601         self.assertFormatEqual(expected, actual)
602         black.assert_equivalent(source, actual)
603         black.assert_stable(source, actual, DEFAULT_MODE)
604
605     @patch("black.dump_to_file", dump_to_stderr)
606     def test_comments4(self) -> None:
607         source, expected = read_data("comments4")
608         actual = fs(source)
609         self.assertFormatEqual(expected, actual)
610         black.assert_equivalent(source, actual)
611         black.assert_stable(source, actual, DEFAULT_MODE)
612
613     @patch("black.dump_to_file", dump_to_stderr)
614     def test_comments5(self) -> None:
615         source, expected = read_data("comments5")
616         actual = fs(source)
617         self.assertFormatEqual(expected, actual)
618         black.assert_equivalent(source, actual)
619         black.assert_stable(source, actual, DEFAULT_MODE)
620
621     @patch("black.dump_to_file", dump_to_stderr)
622     def test_comments6(self) -> None:
623         source, expected = read_data("comments6")
624         actual = fs(source)
625         self.assertFormatEqual(expected, actual)
626         black.assert_equivalent(source, actual)
627         black.assert_stable(source, actual, DEFAULT_MODE)
628
629     @patch("black.dump_to_file", dump_to_stderr)
630     def test_comments7(self) -> None:
631         source, expected = read_data("comments7")
632         mode = replace(DEFAULT_MODE, target_versions={black.TargetVersion.PY38})
633         actual = fs(source, mode=mode)
634         self.assertFormatEqual(expected, actual)
635         black.assert_equivalent(source, actual)
636         black.assert_stable(source, actual, DEFAULT_MODE)
637
638     @patch("black.dump_to_file", dump_to_stderr)
639     def test_comment_after_escaped_newline(self) -> None:
640         source, expected = read_data("comment_after_escaped_newline")
641         actual = fs(source)
642         self.assertFormatEqual(expected, actual)
643         black.assert_equivalent(source, actual)
644         black.assert_stable(source, actual, DEFAULT_MODE)
645
646     @patch("black.dump_to_file", dump_to_stderr)
647     def test_cantfit(self) -> None:
648         source, expected = read_data("cantfit")
649         actual = fs(source)
650         self.assertFormatEqual(expected, actual)
651         black.assert_equivalent(source, actual)
652         black.assert_stable(source, actual, DEFAULT_MODE)
653
654     @patch("black.dump_to_file", dump_to_stderr)
655     def test_import_spacing(self) -> None:
656         source, expected = read_data("import_spacing")
657         actual = fs(source)
658         self.assertFormatEqual(expected, actual)
659         black.assert_equivalent(source, actual)
660         black.assert_stable(source, actual, DEFAULT_MODE)
661
662     @patch("black.dump_to_file", dump_to_stderr)
663     def test_composition(self) -> None:
664         source, expected = read_data("composition")
665         actual = fs(source)
666         self.assertFormatEqual(expected, actual)
667         black.assert_equivalent(source, actual)
668         black.assert_stable(source, actual, DEFAULT_MODE)
669
670     @patch("black.dump_to_file", dump_to_stderr)
671     def test_composition_no_trailing_comma(self) -> None:
672         source, expected = read_data("composition_no_trailing_comma")
673         mode = replace(DEFAULT_MODE, target_versions={black.TargetVersion.PY38})
674         actual = fs(source, mode=mode)
675         self.assertFormatEqual(expected, actual)
676         black.assert_equivalent(source, actual)
677         black.assert_stable(source, actual, DEFAULT_MODE)
678
679     @patch("black.dump_to_file", dump_to_stderr)
680     def test_empty_lines(self) -> None:
681         source, expected = read_data("empty_lines")
682         actual = fs(source)
683         self.assertFormatEqual(expected, actual)
684         black.assert_equivalent(source, actual)
685         black.assert_stable(source, actual, DEFAULT_MODE)
686
687     @patch("black.dump_to_file", dump_to_stderr)
688     def test_remove_parens(self) -> None:
689         source, expected = read_data("remove_parens")
690         actual = fs(source)
691         self.assertFormatEqual(expected, actual)
692         black.assert_equivalent(source, actual)
693         black.assert_stable(source, actual, DEFAULT_MODE)
694
695     @patch("black.dump_to_file", dump_to_stderr)
696     def test_string_prefixes(self) -> None:
697         source, expected = read_data("string_prefixes")
698         actual = fs(source)
699         self.assertFormatEqual(expected, actual)
700         black.assert_equivalent(source, actual)
701         black.assert_stable(source, actual, DEFAULT_MODE)
702
703     @patch("black.dump_to_file", dump_to_stderr)
704     def test_numeric_literals(self) -> None:
705         source, expected = read_data("numeric_literals")
706         mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
707         actual = fs(source, mode=mode)
708         self.assertFormatEqual(expected, actual)
709         black.assert_equivalent(source, actual)
710         black.assert_stable(source, actual, mode)
711
712     @patch("black.dump_to_file", dump_to_stderr)
713     def test_numeric_literals_ignoring_underscores(self) -> None:
714         source, expected = read_data("numeric_literals_skip_underscores")
715         mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
716         actual = fs(source, mode=mode)
717         self.assertFormatEqual(expected, actual)
718         black.assert_equivalent(source, actual)
719         black.assert_stable(source, actual, mode)
720
721     @patch("black.dump_to_file", dump_to_stderr)
722     def test_numeric_literals_py2(self) -> None:
723         source, expected = read_data("numeric_literals_py2")
724         actual = fs(source)
725         self.assertFormatEqual(expected, actual)
726         black.assert_stable(source, actual, DEFAULT_MODE)
727
728     @patch("black.dump_to_file", dump_to_stderr)
729     def test_python2(self) -> None:
730         source, expected = read_data("python2")
731         actual = fs(source)
732         self.assertFormatEqual(expected, actual)
733         black.assert_equivalent(source, actual)
734         black.assert_stable(source, actual, DEFAULT_MODE)
735
736     @patch("black.dump_to_file", dump_to_stderr)
737     def test_python2_print_function(self) -> None:
738         source, expected = read_data("python2_print_function")
739         mode = replace(DEFAULT_MODE, target_versions={TargetVersion.PY27})
740         actual = fs(source, mode=mode)
741         self.assertFormatEqual(expected, actual)
742         black.assert_equivalent(source, actual)
743         black.assert_stable(source, actual, mode)
744
745     @patch("black.dump_to_file", dump_to_stderr)
746     def test_python2_unicode_literals(self) -> None:
747         source, expected = read_data("python2_unicode_literals")
748         actual = fs(source)
749         self.assertFormatEqual(expected, actual)
750         black.assert_equivalent(source, actual)
751         black.assert_stable(source, actual, DEFAULT_MODE)
752
753     @patch("black.dump_to_file", dump_to_stderr)
754     def test_stub(self) -> None:
755         mode = replace(DEFAULT_MODE, is_pyi=True)
756         source, expected = read_data("stub.pyi")
757         actual = fs(source, mode=mode)
758         self.assertFormatEqual(expected, actual)
759         black.assert_stable(source, actual, mode)
760
761     @patch("black.dump_to_file", dump_to_stderr)
762     def test_async_as_identifier(self) -> None:
763         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
764         source, expected = read_data("async_as_identifier")
765         actual = fs(source)
766         self.assertFormatEqual(expected, actual)
767         major, minor = sys.version_info[:2]
768         if major < 3 or (major <= 3 and minor < 7):
769             black.assert_equivalent(source, actual)
770         black.assert_stable(source, actual, DEFAULT_MODE)
771         # ensure black can parse this when the target is 3.6
772         self.invokeBlack([str(source_path), "--target-version", "py36"])
773         # but not on 3.7, because async/await is no longer an identifier
774         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
775
776     @patch("black.dump_to_file", dump_to_stderr)
777     def test_python37(self) -> None:
778         source_path = (THIS_DIR / "data" / "python37.py").resolve()
779         source, expected = read_data("python37")
780         actual = fs(source)
781         self.assertFormatEqual(expected, actual)
782         major, minor = sys.version_info[:2]
783         if major > 3 or (major == 3 and minor >= 7):
784             black.assert_equivalent(source, actual)
785         black.assert_stable(source, actual, DEFAULT_MODE)
786         # ensure black can parse this when the target is 3.7
787         self.invokeBlack([str(source_path), "--target-version", "py37"])
788         # but not on 3.6, because we use async as a reserved keyword
789         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
790
791     @patch("black.dump_to_file", dump_to_stderr)
792     def test_python38(self) -> None:
793         source, expected = read_data("python38")
794         actual = fs(source)
795         self.assertFormatEqual(expected, actual)
796         major, minor = sys.version_info[:2]
797         if major > 3 or (major == 3 and minor >= 8):
798             black.assert_equivalent(source, actual)
799         black.assert_stable(source, actual, DEFAULT_MODE)
800
801     @patch("black.dump_to_file", dump_to_stderr)
802     def test_fmtonoff(self) -> None:
803         source, expected = read_data("fmtonoff")
804         actual = fs(source)
805         self.assertFormatEqual(expected, actual)
806         black.assert_equivalent(source, actual)
807         black.assert_stable(source, actual, DEFAULT_MODE)
808
809     @patch("black.dump_to_file", dump_to_stderr)
810     def test_fmtonoff2(self) -> None:
811         source, expected = read_data("fmtonoff2")
812         actual = fs(source)
813         self.assertFormatEqual(expected, actual)
814         black.assert_equivalent(source, actual)
815         black.assert_stable(source, actual, DEFAULT_MODE)
816
817     @patch("black.dump_to_file", dump_to_stderr)
818     def test_fmtonoff3(self) -> None:
819         source, expected = read_data("fmtonoff3")
820         actual = fs(source)
821         self.assertFormatEqual(expected, actual)
822         black.assert_equivalent(source, actual)
823         black.assert_stable(source, actual, DEFAULT_MODE)
824
825     @patch("black.dump_to_file", dump_to_stderr)
826     def test_fmtonoff4(self) -> None:
827         source, expected = read_data("fmtonoff4")
828         actual = fs(source)
829         self.assertFormatEqual(expected, actual)
830         black.assert_equivalent(source, actual)
831         black.assert_stable(source, actual, DEFAULT_MODE)
832
833     @patch("black.dump_to_file", dump_to_stderr)
834     def test_remove_empty_parentheses_after_class(self) -> None:
835         source, expected = read_data("class_blank_parentheses")
836         actual = fs(source)
837         self.assertFormatEqual(expected, actual)
838         black.assert_equivalent(source, actual)
839         black.assert_stable(source, actual, DEFAULT_MODE)
840
841     @patch("black.dump_to_file", dump_to_stderr)
842     def test_new_line_between_class_and_code(self) -> None:
843         source, expected = read_data("class_methods_new_line")
844         actual = fs(source)
845         self.assertFormatEqual(expected, actual)
846         black.assert_equivalent(source, actual)
847         black.assert_stable(source, actual, DEFAULT_MODE)
848
849     @patch("black.dump_to_file", dump_to_stderr)
850     def test_bracket_match(self) -> None:
851         source, expected = read_data("bracketmatch")
852         actual = fs(source)
853         self.assertFormatEqual(expected, actual)
854         black.assert_equivalent(source, actual)
855         black.assert_stable(source, actual, DEFAULT_MODE)
856
857     @patch("black.dump_to_file", dump_to_stderr)
858     def test_tuple_assign(self) -> None:
859         source, expected = read_data("tupleassign")
860         actual = fs(source)
861         self.assertFormatEqual(expected, actual)
862         black.assert_equivalent(source, actual)
863         black.assert_stable(source, actual, DEFAULT_MODE)
864
865     @patch("black.dump_to_file", dump_to_stderr)
866     def test_beginning_backslash(self) -> None:
867         source, expected = read_data("beginning_backslash")
868         actual = fs(source)
869         self.assertFormatEqual(expected, actual)
870         black.assert_equivalent(source, actual)
871         black.assert_stable(source, actual, DEFAULT_MODE)
872
873     def test_tab_comment_indentation(self) -> None:
874         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
875         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
876         self.assertFormatEqual(contents_spc, fs(contents_spc))
877         self.assertFormatEqual(contents_spc, fs(contents_tab))
878
879         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
880         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
881         self.assertFormatEqual(contents_spc, fs(contents_spc))
882         self.assertFormatEqual(contents_spc, fs(contents_tab))
883
884         # mixed tabs and spaces (valid Python 2 code)
885         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
886         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
887         self.assertFormatEqual(contents_spc, fs(contents_spc))
888         self.assertFormatEqual(contents_spc, fs(contents_tab))
889
890         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
891         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
892         self.assertFormatEqual(contents_spc, fs(contents_spc))
893         self.assertFormatEqual(contents_spc, fs(contents_tab))
894
895     def test_report_verbose(self) -> None:
896         report = black.Report(verbose=True)
897         out_lines = []
898         err_lines = []
899
900         def out(msg: str, **kwargs: Any) -> None:
901             out_lines.append(msg)
902
903         def err(msg: str, **kwargs: Any) -> None:
904             err_lines.append(msg)
905
906         with patch("black.out", out), patch("black.err", err):
907             report.done(Path("f1"), black.Changed.NO)
908             self.assertEqual(len(out_lines), 1)
909             self.assertEqual(len(err_lines), 0)
910             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
911             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
912             self.assertEqual(report.return_code, 0)
913             report.done(Path("f2"), black.Changed.YES)
914             self.assertEqual(len(out_lines), 2)
915             self.assertEqual(len(err_lines), 0)
916             self.assertEqual(out_lines[-1], "reformatted f2")
917             self.assertEqual(
918                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
919             )
920             report.done(Path("f3"), black.Changed.CACHED)
921             self.assertEqual(len(out_lines), 3)
922             self.assertEqual(len(err_lines), 0)
923             self.assertEqual(
924                 out_lines[-1], "f3 wasn't modified on disk since last run."
925             )
926             self.assertEqual(
927                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
928             )
929             self.assertEqual(report.return_code, 0)
930             report.check = True
931             self.assertEqual(report.return_code, 1)
932             report.check = False
933             report.failed(Path("e1"), "boom")
934             self.assertEqual(len(out_lines), 3)
935             self.assertEqual(len(err_lines), 1)
936             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
937             self.assertEqual(
938                 unstyle(str(report)),
939                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
940                 " reformat.",
941             )
942             self.assertEqual(report.return_code, 123)
943             report.done(Path("f3"), black.Changed.YES)
944             self.assertEqual(len(out_lines), 4)
945             self.assertEqual(len(err_lines), 1)
946             self.assertEqual(out_lines[-1], "reformatted f3")
947             self.assertEqual(
948                 unstyle(str(report)),
949                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
950                 " reformat.",
951             )
952             self.assertEqual(report.return_code, 123)
953             report.failed(Path("e2"), "boom")
954             self.assertEqual(len(out_lines), 4)
955             self.assertEqual(len(err_lines), 2)
956             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
957             self.assertEqual(
958                 unstyle(str(report)),
959                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
960                 " reformat.",
961             )
962             self.assertEqual(report.return_code, 123)
963             report.path_ignored(Path("wat"), "no match")
964             self.assertEqual(len(out_lines), 5)
965             self.assertEqual(len(err_lines), 2)
966             self.assertEqual(out_lines[-1], "wat ignored: no match")
967             self.assertEqual(
968                 unstyle(str(report)),
969                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
970                 " reformat.",
971             )
972             self.assertEqual(report.return_code, 123)
973             report.done(Path("f4"), black.Changed.NO)
974             self.assertEqual(len(out_lines), 6)
975             self.assertEqual(len(err_lines), 2)
976             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
977             self.assertEqual(
978                 unstyle(str(report)),
979                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
980                 " reformat.",
981             )
982             self.assertEqual(report.return_code, 123)
983             report.check = True
984             self.assertEqual(
985                 unstyle(str(report)),
986                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
987                 " would fail to reformat.",
988             )
989             report.check = False
990             report.diff = True
991             self.assertEqual(
992                 unstyle(str(report)),
993                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
994                 " would fail to reformat.",
995             )
996
997     def test_report_quiet(self) -> None:
998         report = black.Report(quiet=True)
999         out_lines = []
1000         err_lines = []
1001
1002         def out(msg: str, **kwargs: Any) -> None:
1003             out_lines.append(msg)
1004
1005         def err(msg: str, **kwargs: Any) -> None:
1006             err_lines.append(msg)
1007
1008         with patch("black.out", out), patch("black.err", err):
1009             report.done(Path("f1"), black.Changed.NO)
1010             self.assertEqual(len(out_lines), 0)
1011             self.assertEqual(len(err_lines), 0)
1012             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
1013             self.assertEqual(report.return_code, 0)
1014             report.done(Path("f2"), black.Changed.YES)
1015             self.assertEqual(len(out_lines), 0)
1016             self.assertEqual(len(err_lines), 0)
1017             self.assertEqual(
1018                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
1019             )
1020             report.done(Path("f3"), black.Changed.CACHED)
1021             self.assertEqual(len(out_lines), 0)
1022             self.assertEqual(len(err_lines), 0)
1023             self.assertEqual(
1024                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
1025             )
1026             self.assertEqual(report.return_code, 0)
1027             report.check = True
1028             self.assertEqual(report.return_code, 1)
1029             report.check = False
1030             report.failed(Path("e1"), "boom")
1031             self.assertEqual(len(out_lines), 0)
1032             self.assertEqual(len(err_lines), 1)
1033             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
1034             self.assertEqual(
1035                 unstyle(str(report)),
1036                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
1037                 " reformat.",
1038             )
1039             self.assertEqual(report.return_code, 123)
1040             report.done(Path("f3"), black.Changed.YES)
1041             self.assertEqual(len(out_lines), 0)
1042             self.assertEqual(len(err_lines), 1)
1043             self.assertEqual(
1044                 unstyle(str(report)),
1045                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
1046                 " reformat.",
1047             )
1048             self.assertEqual(report.return_code, 123)
1049             report.failed(Path("e2"), "boom")
1050             self.assertEqual(len(out_lines), 0)
1051             self.assertEqual(len(err_lines), 2)
1052             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
1053             self.assertEqual(
1054                 unstyle(str(report)),
1055                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
1056                 " reformat.",
1057             )
1058             self.assertEqual(report.return_code, 123)
1059             report.path_ignored(Path("wat"), "no match")
1060             self.assertEqual(len(out_lines), 0)
1061             self.assertEqual(len(err_lines), 2)
1062             self.assertEqual(
1063                 unstyle(str(report)),
1064                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
1065                 " reformat.",
1066             )
1067             self.assertEqual(report.return_code, 123)
1068             report.done(Path("f4"), black.Changed.NO)
1069             self.assertEqual(len(out_lines), 0)
1070             self.assertEqual(len(err_lines), 2)
1071             self.assertEqual(
1072                 unstyle(str(report)),
1073                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
1074                 " reformat.",
1075             )
1076             self.assertEqual(report.return_code, 123)
1077             report.check = True
1078             self.assertEqual(
1079                 unstyle(str(report)),
1080                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
1081                 " would fail to reformat.",
1082             )
1083             report.check = False
1084             report.diff = True
1085             self.assertEqual(
1086                 unstyle(str(report)),
1087                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
1088                 " would fail to reformat.",
1089             )
1090
1091     def test_report_normal(self) -> None:
1092         report = black.Report()
1093         out_lines = []
1094         err_lines = []
1095
1096         def out(msg: str, **kwargs: Any) -> None:
1097             out_lines.append(msg)
1098
1099         def err(msg: str, **kwargs: Any) -> None:
1100             err_lines.append(msg)
1101
1102         with patch("black.out", out), patch("black.err", err):
1103             report.done(Path("f1"), black.Changed.NO)
1104             self.assertEqual(len(out_lines), 0)
1105             self.assertEqual(len(err_lines), 0)
1106             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
1107             self.assertEqual(report.return_code, 0)
1108             report.done(Path("f2"), black.Changed.YES)
1109             self.assertEqual(len(out_lines), 1)
1110             self.assertEqual(len(err_lines), 0)
1111             self.assertEqual(out_lines[-1], "reformatted f2")
1112             self.assertEqual(
1113                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
1114             )
1115             report.done(Path("f3"), black.Changed.CACHED)
1116             self.assertEqual(len(out_lines), 1)
1117             self.assertEqual(len(err_lines), 0)
1118             self.assertEqual(out_lines[-1], "reformatted f2")
1119             self.assertEqual(
1120                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
1121             )
1122             self.assertEqual(report.return_code, 0)
1123             report.check = True
1124             self.assertEqual(report.return_code, 1)
1125             report.check = False
1126             report.failed(Path("e1"), "boom")
1127             self.assertEqual(len(out_lines), 1)
1128             self.assertEqual(len(err_lines), 1)
1129             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
1130             self.assertEqual(
1131                 unstyle(str(report)),
1132                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
1133                 " reformat.",
1134             )
1135             self.assertEqual(report.return_code, 123)
1136             report.done(Path("f3"), black.Changed.YES)
1137             self.assertEqual(len(out_lines), 2)
1138             self.assertEqual(len(err_lines), 1)
1139             self.assertEqual(out_lines[-1], "reformatted f3")
1140             self.assertEqual(
1141                 unstyle(str(report)),
1142                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
1143                 " reformat.",
1144             )
1145             self.assertEqual(report.return_code, 123)
1146             report.failed(Path("e2"), "boom")
1147             self.assertEqual(len(out_lines), 2)
1148             self.assertEqual(len(err_lines), 2)
1149             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
1150             self.assertEqual(
1151                 unstyle(str(report)),
1152                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
1153                 " reformat.",
1154             )
1155             self.assertEqual(report.return_code, 123)
1156             report.path_ignored(Path("wat"), "no match")
1157             self.assertEqual(len(out_lines), 2)
1158             self.assertEqual(len(err_lines), 2)
1159             self.assertEqual(
1160                 unstyle(str(report)),
1161                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
1162                 " reformat.",
1163             )
1164             self.assertEqual(report.return_code, 123)
1165             report.done(Path("f4"), black.Changed.NO)
1166             self.assertEqual(len(out_lines), 2)
1167             self.assertEqual(len(err_lines), 2)
1168             self.assertEqual(
1169                 unstyle(str(report)),
1170                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
1171                 " reformat.",
1172             )
1173             self.assertEqual(report.return_code, 123)
1174             report.check = True
1175             self.assertEqual(
1176                 unstyle(str(report)),
1177                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
1178                 " would fail to reformat.",
1179             )
1180             report.check = False
1181             report.diff = True
1182             self.assertEqual(
1183                 unstyle(str(report)),
1184                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
1185                 " would fail to reformat.",
1186             )
1187
1188     def test_lib2to3_parse(self) -> None:
1189         with self.assertRaises(black.InvalidInput):
1190             black.lib2to3_parse("invalid syntax")
1191
1192         straddling = "x + y"
1193         black.lib2to3_parse(straddling)
1194         black.lib2to3_parse(straddling, {TargetVersion.PY27})
1195         black.lib2to3_parse(straddling, {TargetVersion.PY36})
1196         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
1197
1198         py2_only = "print x"
1199         black.lib2to3_parse(py2_only)
1200         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
1201         with self.assertRaises(black.InvalidInput):
1202             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
1203         with self.assertRaises(black.InvalidInput):
1204             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
1205
1206         py3_only = "exec(x, end=y)"
1207         black.lib2to3_parse(py3_only)
1208         with self.assertRaises(black.InvalidInput):
1209             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
1210         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
1211         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
1212
1213     def test_get_features_used(self) -> None:
1214         node = black.lib2to3_parse("def f(*, arg): ...\n")
1215         self.assertEqual(black.get_features_used(node), set())
1216         node = black.lib2to3_parse("def f(*, arg,): ...\n")
1217         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
1218         node = black.lib2to3_parse("f(*arg,)\n")
1219         self.assertEqual(
1220             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
1221         )
1222         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
1223         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
1224         node = black.lib2to3_parse("123_456\n")
1225         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
1226         node = black.lib2to3_parse("123456\n")
1227         self.assertEqual(black.get_features_used(node), set())
1228         source, expected = read_data("function")
1229         node = black.lib2to3_parse(source)
1230         expected_features = {
1231             Feature.TRAILING_COMMA_IN_CALL,
1232             Feature.TRAILING_COMMA_IN_DEF,
1233             Feature.F_STRINGS,
1234         }
1235         self.assertEqual(black.get_features_used(node), expected_features)
1236         node = black.lib2to3_parse(expected)
1237         self.assertEqual(black.get_features_used(node), expected_features)
1238         source, expected = read_data("expression")
1239         node = black.lib2to3_parse(source)
1240         self.assertEqual(black.get_features_used(node), set())
1241         node = black.lib2to3_parse(expected)
1242         self.assertEqual(black.get_features_used(node), set())
1243
1244     def test_get_future_imports(self) -> None:
1245         node = black.lib2to3_parse("\n")
1246         self.assertEqual(set(), black.get_future_imports(node))
1247         node = black.lib2to3_parse("from __future__ import black\n")
1248         self.assertEqual({"black"}, black.get_future_imports(node))
1249         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
1250         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
1251         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
1252         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
1253         node = black.lib2to3_parse(
1254             "from __future__ import multiple\nfrom __future__ import imports\n"
1255         )
1256         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
1257         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
1258         self.assertEqual({"black"}, black.get_future_imports(node))
1259         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
1260         self.assertEqual({"black"}, black.get_future_imports(node))
1261         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
1262         self.assertEqual(set(), black.get_future_imports(node))
1263         node = black.lib2to3_parse("from some.module import black\n")
1264         self.assertEqual(set(), black.get_future_imports(node))
1265         node = black.lib2to3_parse(
1266             "from __future__ import unicode_literals as _unicode_literals"
1267         )
1268         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
1269         node = black.lib2to3_parse(
1270             "from __future__ import unicode_literals as _lol, print"
1271         )
1272         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
1273
1274     def test_debug_visitor(self) -> None:
1275         source, _ = read_data("debug_visitor.py")
1276         expected, _ = read_data("debug_visitor.out")
1277         out_lines = []
1278         err_lines = []
1279
1280         def out(msg: str, **kwargs: Any) -> None:
1281             out_lines.append(msg)
1282
1283         def err(msg: str, **kwargs: Any) -> None:
1284             err_lines.append(msg)
1285
1286         with patch("black.out", out), patch("black.err", err):
1287             black.DebugVisitor.show(source)
1288         actual = "\n".join(out_lines) + "\n"
1289         log_name = ""
1290         if expected != actual:
1291             log_name = black.dump_to_file(*out_lines)
1292         self.assertEqual(
1293             expected,
1294             actual,
1295             f"AST print out is different. Actual version dumped to {log_name}",
1296         )
1297
1298     def test_format_file_contents(self) -> None:
1299         empty = ""
1300         mode = DEFAULT_MODE
1301         with self.assertRaises(black.NothingChanged):
1302             black.format_file_contents(empty, mode=mode, fast=False)
1303         just_nl = "\n"
1304         with self.assertRaises(black.NothingChanged):
1305             black.format_file_contents(just_nl, mode=mode, fast=False)
1306         same = "j = [1, 2, 3]\n"
1307         with self.assertRaises(black.NothingChanged):
1308             black.format_file_contents(same, mode=mode, fast=False)
1309         different = "j = [1,2,3]"
1310         expected = same
1311         actual = black.format_file_contents(different, mode=mode, fast=False)
1312         self.assertEqual(expected, actual)
1313         invalid = "return if you can"
1314         with self.assertRaises(black.InvalidInput) as e:
1315             black.format_file_contents(invalid, mode=mode, fast=False)
1316         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1317
1318     def test_endmarker(self) -> None:
1319         n = black.lib2to3_parse("\n")
1320         self.assertEqual(n.type, black.syms.file_input)
1321         self.assertEqual(len(n.children), 1)
1322         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1323
1324     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1325     def test_assertFormatEqual(self) -> None:
1326         out_lines = []
1327         err_lines = []
1328
1329         def out(msg: str, **kwargs: Any) -> None:
1330             out_lines.append(msg)
1331
1332         def err(msg: str, **kwargs: Any) -> None:
1333             err_lines.append(msg)
1334
1335         with patch("black.out", out), patch("black.err", err):
1336             with self.assertRaises(AssertionError):
1337                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1338
1339         out_str = "".join(out_lines)
1340         self.assertTrue("Expected tree:" in out_str)
1341         self.assertTrue("Actual tree:" in out_str)
1342         self.assertEqual("".join(err_lines), "")
1343
1344     def test_cache_broken_file(self) -> None:
1345         mode = DEFAULT_MODE
1346         with cache_dir() as workspace:
1347             cache_file = black.get_cache_file(mode)
1348             with cache_file.open("w") as fobj:
1349                 fobj.write("this is not a pickle")
1350             self.assertEqual(black.read_cache(mode), {})
1351             src = (workspace / "test.py").resolve()
1352             with src.open("w") as fobj:
1353                 fobj.write("print('hello')")
1354             self.invokeBlack([str(src)])
1355             cache = black.read_cache(mode)
1356             self.assertIn(src, cache)
1357
1358     def test_cache_single_file_already_cached(self) -> None:
1359         mode = DEFAULT_MODE
1360         with cache_dir() as workspace:
1361             src = (workspace / "test.py").resolve()
1362             with src.open("w") as fobj:
1363                 fobj.write("print('hello')")
1364             black.write_cache({}, [src], mode)
1365             self.invokeBlack([str(src)])
1366             with src.open("r") as fobj:
1367                 self.assertEqual(fobj.read(), "print('hello')")
1368
1369     @event_loop()
1370     def test_cache_multiple_files(self) -> None:
1371         mode = DEFAULT_MODE
1372         with cache_dir() as workspace, patch(
1373             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1374         ):
1375             one = (workspace / "one.py").resolve()
1376             with one.open("w") as fobj:
1377                 fobj.write("print('hello')")
1378             two = (workspace / "two.py").resolve()
1379             with two.open("w") as fobj:
1380                 fobj.write("print('hello')")
1381             black.write_cache({}, [one], mode)
1382             self.invokeBlack([str(workspace)])
1383             with one.open("r") as fobj:
1384                 self.assertEqual(fobj.read(), "print('hello')")
1385             with two.open("r") as fobj:
1386                 self.assertEqual(fobj.read(), 'print("hello")\n')
1387             cache = black.read_cache(mode)
1388             self.assertIn(one, cache)
1389             self.assertIn(two, cache)
1390
1391     def test_no_cache_when_writeback_diff(self) -> None:
1392         mode = DEFAULT_MODE
1393         with cache_dir() as workspace:
1394             src = (workspace / "test.py").resolve()
1395             with src.open("w") as fobj:
1396                 fobj.write("print('hello')")
1397             self.invokeBlack([str(src), "--diff"])
1398             cache_file = black.get_cache_file(mode)
1399             self.assertFalse(cache_file.exists())
1400
1401     def test_no_cache_when_stdin(self) -> None:
1402         mode = DEFAULT_MODE
1403         with cache_dir():
1404             result = CliRunner().invoke(
1405                 black.main, ["-"], input=BytesIO(b"print('hello')")
1406             )
1407             self.assertEqual(result.exit_code, 0)
1408             cache_file = black.get_cache_file(mode)
1409             self.assertFalse(cache_file.exists())
1410
1411     def test_read_cache_no_cachefile(self) -> None:
1412         mode = DEFAULT_MODE
1413         with cache_dir():
1414             self.assertEqual(black.read_cache(mode), {})
1415
1416     def test_write_cache_read_cache(self) -> None:
1417         mode = DEFAULT_MODE
1418         with cache_dir() as workspace:
1419             src = (workspace / "test.py").resolve()
1420             src.touch()
1421             black.write_cache({}, [src], mode)
1422             cache = black.read_cache(mode)
1423             self.assertIn(src, cache)
1424             self.assertEqual(cache[src], black.get_cache_info(src))
1425
1426     def test_filter_cached(self) -> None:
1427         with TemporaryDirectory() as workspace:
1428             path = Path(workspace)
1429             uncached = (path / "uncached").resolve()
1430             cached = (path / "cached").resolve()
1431             cached_but_changed = (path / "changed").resolve()
1432             uncached.touch()
1433             cached.touch()
1434             cached_but_changed.touch()
1435             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
1436             todo, done = black.filter_cached(
1437                 cache, {uncached, cached, cached_but_changed}
1438             )
1439             self.assertEqual(todo, {uncached, cached_but_changed})
1440             self.assertEqual(done, {cached})
1441
1442     def test_write_cache_creates_directory_if_needed(self) -> None:
1443         mode = DEFAULT_MODE
1444         with cache_dir(exists=False) as workspace:
1445             self.assertFalse(workspace.exists())
1446             black.write_cache({}, [], mode)
1447             self.assertTrue(workspace.exists())
1448
1449     @event_loop()
1450     def test_failed_formatting_does_not_get_cached(self) -> None:
1451         mode = DEFAULT_MODE
1452         with cache_dir() as workspace, patch(
1453             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1454         ):
1455             failing = (workspace / "failing.py").resolve()
1456             with failing.open("w") as fobj:
1457                 fobj.write("not actually python")
1458             clean = (workspace / "clean.py").resolve()
1459             with clean.open("w") as fobj:
1460                 fobj.write('print("hello")\n')
1461             self.invokeBlack([str(workspace)], exit_code=123)
1462             cache = black.read_cache(mode)
1463             self.assertNotIn(failing, cache)
1464             self.assertIn(clean, cache)
1465
1466     def test_write_cache_write_fail(self) -> None:
1467         mode = DEFAULT_MODE
1468         with cache_dir(), patch.object(Path, "open") as mock:
1469             mock.side_effect = OSError
1470             black.write_cache({}, [], mode)
1471
1472     @event_loop()
1473     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
1474     def test_works_in_mono_process_only_environment(self) -> None:
1475         with cache_dir() as workspace:
1476             for f in [
1477                 (workspace / "one.py").resolve(),
1478                 (workspace / "two.py").resolve(),
1479             ]:
1480                 f.write_text('print("hello")\n')
1481             self.invokeBlack([str(workspace)])
1482
1483     @event_loop()
1484     def test_check_diff_use_together(self) -> None:
1485         with cache_dir():
1486             # Files which will be reformatted.
1487             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1488             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1489             # Files which will not be reformatted.
1490             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1491             self.invokeBlack([str(src2), "--diff", "--check"])
1492             # Multi file command.
1493             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1494
1495     def test_no_files(self) -> None:
1496         with cache_dir():
1497             # Without an argument, black exits with error code 0.
1498             self.invokeBlack([])
1499
1500     def test_broken_symlink(self) -> None:
1501         with cache_dir() as workspace:
1502             symlink = workspace / "broken_link.py"
1503             try:
1504                 symlink.symlink_to("nonexistent.py")
1505             except OSError as e:
1506                 self.skipTest(f"Can't create symlinks: {e}")
1507             self.invokeBlack([str(workspace.resolve())])
1508
1509     def test_read_cache_line_lengths(self) -> None:
1510         mode = DEFAULT_MODE
1511         short_mode = replace(DEFAULT_MODE, line_length=1)
1512         with cache_dir() as workspace:
1513             path = (workspace / "file.py").resolve()
1514             path.touch()
1515             black.write_cache({}, [path], mode)
1516             one = black.read_cache(mode)
1517             self.assertIn(path, one)
1518             two = black.read_cache(short_mode)
1519             self.assertNotIn(path, two)
1520
1521     def test_tricky_unicode_symbols(self) -> None:
1522         source, expected = read_data("tricky_unicode_symbols")
1523         actual = fs(source)
1524         self.assertFormatEqual(expected, actual)
1525         black.assert_equivalent(source, actual)
1526         black.assert_stable(source, actual, DEFAULT_MODE)
1527
1528     def test_single_file_force_pyi(self) -> None:
1529         reg_mode = DEFAULT_MODE
1530         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1531         contents, expected = read_data("force_pyi")
1532         with cache_dir() as workspace:
1533             path = (workspace / "file.py").resolve()
1534             with open(path, "w") as fh:
1535                 fh.write(contents)
1536             self.invokeBlack([str(path), "--pyi"])
1537             with open(path, "r") as fh:
1538                 actual = fh.read()
1539             # verify cache with --pyi is separate
1540             pyi_cache = black.read_cache(pyi_mode)
1541             self.assertIn(path, pyi_cache)
1542             normal_cache = black.read_cache(reg_mode)
1543             self.assertNotIn(path, normal_cache)
1544         self.assertEqual(actual, expected)
1545
1546     @event_loop()
1547     def test_multi_file_force_pyi(self) -> None:
1548         reg_mode = DEFAULT_MODE
1549         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1550         contents, expected = read_data("force_pyi")
1551         with cache_dir() as workspace:
1552             paths = [
1553                 (workspace / "file1.py").resolve(),
1554                 (workspace / "file2.py").resolve(),
1555             ]
1556             for path in paths:
1557                 with open(path, "w") as fh:
1558                     fh.write(contents)
1559             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1560             for path in paths:
1561                 with open(path, "r") as fh:
1562                     actual = fh.read()
1563                 self.assertEqual(actual, expected)
1564             # verify cache with --pyi is separate
1565             pyi_cache = black.read_cache(pyi_mode)
1566             normal_cache = black.read_cache(reg_mode)
1567             for path in paths:
1568                 self.assertIn(path, pyi_cache)
1569                 self.assertNotIn(path, normal_cache)
1570
1571     def test_pipe_force_pyi(self) -> None:
1572         source, expected = read_data("force_pyi")
1573         result = CliRunner().invoke(
1574             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1575         )
1576         self.assertEqual(result.exit_code, 0)
1577         actual = result.output
1578         self.assertFormatEqual(actual, expected)
1579
1580     def test_single_file_force_py36(self) -> None:
1581         reg_mode = DEFAULT_MODE
1582         py36_mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
1583         source, expected = read_data("force_py36")
1584         with cache_dir() as workspace:
1585             path = (workspace / "file.py").resolve()
1586             with open(path, "w") as fh:
1587                 fh.write(source)
1588             self.invokeBlack([str(path), *PY36_ARGS])
1589             with open(path, "r") as fh:
1590                 actual = fh.read()
1591             # verify cache with --target-version is separate
1592             py36_cache = black.read_cache(py36_mode)
1593             self.assertIn(path, py36_cache)
1594             normal_cache = black.read_cache(reg_mode)
1595             self.assertNotIn(path, normal_cache)
1596         self.assertEqual(actual, expected)
1597
1598     @event_loop()
1599     def test_multi_file_force_py36(self) -> None:
1600         reg_mode = DEFAULT_MODE
1601         py36_mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
1602         source, expected = read_data("force_py36")
1603         with cache_dir() as workspace:
1604             paths = [
1605                 (workspace / "file1.py").resolve(),
1606                 (workspace / "file2.py").resolve(),
1607             ]
1608             for path in paths:
1609                 with open(path, "w") as fh:
1610                     fh.write(source)
1611             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1612             for path in paths:
1613                 with open(path, "r") as fh:
1614                     actual = fh.read()
1615                 self.assertEqual(actual, expected)
1616             # verify cache with --target-version is separate
1617             pyi_cache = black.read_cache(py36_mode)
1618             normal_cache = black.read_cache(reg_mode)
1619             for path in paths:
1620                 self.assertIn(path, pyi_cache)
1621                 self.assertNotIn(path, normal_cache)
1622
1623     def test_collections(self) -> None:
1624         source, expected = read_data("collections")
1625         actual = fs(source)
1626         self.assertFormatEqual(expected, actual)
1627         black.assert_equivalent(source, actual)
1628         black.assert_stable(source, actual, DEFAULT_MODE)
1629
1630     def test_pipe_force_py36(self) -> None:
1631         source, expected = read_data("force_py36")
1632         result = CliRunner().invoke(
1633             black.main,
1634             ["-", "-q", "--target-version=py36"],
1635             input=BytesIO(source.encode("utf8")),
1636         )
1637         self.assertEqual(result.exit_code, 0)
1638         actual = result.output
1639         self.assertFormatEqual(actual, expected)
1640
1641     def test_include_exclude(self) -> None:
1642         path = THIS_DIR / "data" / "include_exclude_tests"
1643         include = re.compile(r"\.pyi?$")
1644         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1645         report = black.Report()
1646         gitignore = PathSpec.from_lines("gitwildmatch", [])
1647         sources: List[Path] = []
1648         expected = [
1649             Path(path / "b/dont_exclude/a.py"),
1650             Path(path / "b/dont_exclude/a.pyi"),
1651         ]
1652         this_abs = THIS_DIR.resolve()
1653         sources.extend(
1654             black.gen_python_files(
1655                 path.iterdir(), this_abs, include, exclude, None, report, gitignore
1656             )
1657         )
1658         self.assertEqual(sorted(expected), sorted(sources))
1659
1660     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1661     def test_exclude_for_issue_1572(self) -> None:
1662         # Exclude shouldn't touch files that were explicitly given to Black through the
1663         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
1664         # https://github.com/psf/black/issues/1572
1665         path = THIS_DIR / "data" / "include_exclude_tests"
1666         include = ""
1667         exclude = r"/exclude/|a\.py"
1668         src = str(path / "b/exclude/a.py")
1669         report = black.Report()
1670         expected = [Path(path / "b/exclude/a.py")]
1671         sources = list(
1672             black.get_sources(
1673                 ctx=FakeContext(),
1674                 src=(src,),
1675                 quiet=True,
1676                 verbose=False,
1677                 include=include,
1678                 exclude=exclude,
1679                 force_exclude=None,
1680                 report=report,
1681             )
1682         )
1683         self.assertEqual(sorted(expected), sorted(sources))
1684
1685     def test_gitignore_exclude(self) -> None:
1686         path = THIS_DIR / "data" / "include_exclude_tests"
1687         include = re.compile(r"\.pyi?$")
1688         exclude = re.compile(r"")
1689         report = black.Report()
1690         gitignore = PathSpec.from_lines(
1691             "gitwildmatch", ["exclude/", ".definitely_exclude"]
1692         )
1693         sources: List[Path] = []
1694         expected = [
1695             Path(path / "b/dont_exclude/a.py"),
1696             Path(path / "b/dont_exclude/a.pyi"),
1697         ]
1698         this_abs = THIS_DIR.resolve()
1699         sources.extend(
1700             black.gen_python_files(
1701                 path.iterdir(), this_abs, include, exclude, None, report, gitignore
1702             )
1703         )
1704         self.assertEqual(sorted(expected), sorted(sources))
1705
1706     def test_empty_include(self) -> None:
1707         path = THIS_DIR / "data" / "include_exclude_tests"
1708         report = black.Report()
1709         gitignore = PathSpec.from_lines("gitwildmatch", [])
1710         empty = re.compile(r"")
1711         sources: List[Path] = []
1712         expected = [
1713             Path(path / "b/exclude/a.pie"),
1714             Path(path / "b/exclude/a.py"),
1715             Path(path / "b/exclude/a.pyi"),
1716             Path(path / "b/dont_exclude/a.pie"),
1717             Path(path / "b/dont_exclude/a.py"),
1718             Path(path / "b/dont_exclude/a.pyi"),
1719             Path(path / "b/.definitely_exclude/a.pie"),
1720             Path(path / "b/.definitely_exclude/a.py"),
1721             Path(path / "b/.definitely_exclude/a.pyi"),
1722         ]
1723         this_abs = THIS_DIR.resolve()
1724         sources.extend(
1725             black.gen_python_files(
1726                 path.iterdir(),
1727                 this_abs,
1728                 empty,
1729                 re.compile(black.DEFAULT_EXCLUDES),
1730                 None,
1731                 report,
1732                 gitignore,
1733             )
1734         )
1735         self.assertEqual(sorted(expected), sorted(sources))
1736
1737     def test_empty_exclude(self) -> None:
1738         path = THIS_DIR / "data" / "include_exclude_tests"
1739         report = black.Report()
1740         gitignore = PathSpec.from_lines("gitwildmatch", [])
1741         empty = re.compile(r"")
1742         sources: List[Path] = []
1743         expected = [
1744             Path(path / "b/dont_exclude/a.py"),
1745             Path(path / "b/dont_exclude/a.pyi"),
1746             Path(path / "b/exclude/a.py"),
1747             Path(path / "b/exclude/a.pyi"),
1748             Path(path / "b/.definitely_exclude/a.py"),
1749             Path(path / "b/.definitely_exclude/a.pyi"),
1750         ]
1751         this_abs = THIS_DIR.resolve()
1752         sources.extend(
1753             black.gen_python_files(
1754                 path.iterdir(),
1755                 this_abs,
1756                 re.compile(black.DEFAULT_INCLUDES),
1757                 empty,
1758                 None,
1759                 report,
1760                 gitignore,
1761             )
1762         )
1763         self.assertEqual(sorted(expected), sorted(sources))
1764
1765     def test_invalid_include_exclude(self) -> None:
1766         for option in ["--include", "--exclude"]:
1767             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1768
1769     def test_preserves_line_endings(self) -> None:
1770         with TemporaryDirectory() as workspace:
1771             test_file = Path(workspace) / "test.py"
1772             for nl in ["\n", "\r\n"]:
1773                 contents = nl.join(["def f(  ):", "    pass"])
1774                 test_file.write_bytes(contents.encode())
1775                 ff(test_file, write_back=black.WriteBack.YES)
1776                 updated_contents: bytes = test_file.read_bytes()
1777                 self.assertIn(nl.encode(), updated_contents)
1778                 if nl == "\n":
1779                     self.assertNotIn(b"\r\n", updated_contents)
1780
1781     def test_preserves_line_endings_via_stdin(self) -> None:
1782         for nl in ["\n", "\r\n"]:
1783             contents = nl.join(["def f(  ):", "    pass"])
1784             runner = BlackRunner()
1785             result = runner.invoke(
1786                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1787             )
1788             self.assertEqual(result.exit_code, 0)
1789             output = runner.stdout_bytes
1790             self.assertIn(nl.encode("utf8"), output)
1791             if nl == "\n":
1792                 self.assertNotIn(b"\r\n", output)
1793
1794     def test_assert_equivalent_different_asts(self) -> None:
1795         with self.assertRaises(AssertionError):
1796             black.assert_equivalent("{}", "None")
1797
1798     def test_symlink_out_of_root_directory(self) -> None:
1799         path = MagicMock()
1800         root = THIS_DIR.resolve()
1801         child = MagicMock()
1802         include = re.compile(black.DEFAULT_INCLUDES)
1803         exclude = re.compile(black.DEFAULT_EXCLUDES)
1804         report = black.Report()
1805         gitignore = PathSpec.from_lines("gitwildmatch", [])
1806         # `child` should behave like a symlink which resolved path is clearly
1807         # outside of the `root` directory.
1808         path.iterdir.return_value = [child]
1809         child.resolve.return_value = Path("/a/b/c")
1810         child.as_posix.return_value = "/a/b/c"
1811         child.is_symlink.return_value = True
1812         try:
1813             list(
1814                 black.gen_python_files(
1815                     path.iterdir(), root, include, exclude, None, report, gitignore
1816                 )
1817             )
1818         except ValueError as ve:
1819             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1820         path.iterdir.assert_called_once()
1821         child.resolve.assert_called_once()
1822         child.is_symlink.assert_called_once()
1823         # `child` should behave like a strange file which resolved path is clearly
1824         # outside of the `root` directory.
1825         child.is_symlink.return_value = False
1826         with self.assertRaises(ValueError):
1827             list(
1828                 black.gen_python_files(
1829                     path.iterdir(), root, include, exclude, None, report, gitignore
1830                 )
1831             )
1832         path.iterdir.assert_called()
1833         self.assertEqual(path.iterdir.call_count, 2)
1834         child.resolve.assert_called()
1835         self.assertEqual(child.resolve.call_count, 2)
1836         child.is_symlink.assert_called()
1837         self.assertEqual(child.is_symlink.call_count, 2)
1838
1839     def test_shhh_click(self) -> None:
1840         try:
1841             from click import _unicodefun  # type: ignore
1842         except ModuleNotFoundError:
1843             self.skipTest("Incompatible Click version")
1844         if not hasattr(_unicodefun, "_verify_python3_env"):
1845             self.skipTest("Incompatible Click version")
1846         # First, let's see if Click is crashing with a preferred ASCII charset.
1847         with patch("locale.getpreferredencoding") as gpe:
1848             gpe.return_value = "ASCII"
1849             with self.assertRaises(RuntimeError):
1850                 _unicodefun._verify_python3_env()
1851         # Now, let's silence Click...
1852         black.patch_click()
1853         # ...and confirm it's silent.
1854         with patch("locale.getpreferredencoding") as gpe:
1855             gpe.return_value = "ASCII"
1856             try:
1857                 _unicodefun._verify_python3_env()
1858             except RuntimeError as re:
1859                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1860
1861     def test_root_logger_not_used_directly(self) -> None:
1862         def fail(*args: Any, **kwargs: Any) -> None:
1863             self.fail("Record created with root logger")
1864
1865         with patch.multiple(
1866             logging.root,
1867             debug=fail,
1868             info=fail,
1869             warning=fail,
1870             error=fail,
1871             critical=fail,
1872             log=fail,
1873         ):
1874             ff(THIS_FILE)
1875
1876     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1877     def test_blackd_main(self) -> None:
1878         with patch("blackd.web.run_app"):
1879             result = CliRunner().invoke(blackd.main, [])
1880             if result.exception is not None:
1881                 raise result.exception
1882             self.assertEqual(result.exit_code, 0)
1883
1884     def test_invalid_config_return_code(self) -> None:
1885         tmp_file = Path(black.dump_to_file())
1886         try:
1887             tmp_config = Path(black.dump_to_file())
1888             tmp_config.unlink()
1889             args = ["--config", str(tmp_config), str(tmp_file)]
1890             self.invokeBlack(args, exit_code=2, ignore_config=False)
1891         finally:
1892             tmp_file.unlink()
1893
1894     def test_parse_pyproject_toml(self) -> None:
1895         test_toml_file = THIS_DIR / "test.toml"
1896         config = black.parse_pyproject_toml(str(test_toml_file))
1897         self.assertEqual(config["verbose"], 1)
1898         self.assertEqual(config["check"], "no")
1899         self.assertEqual(config["diff"], "y")
1900         self.assertEqual(config["color"], True)
1901         self.assertEqual(config["line_length"], 79)
1902         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1903         self.assertEqual(config["exclude"], r"\.pyi?$")
1904         self.assertEqual(config["include"], r"\.py?$")
1905
1906     def test_read_pyproject_toml(self) -> None:
1907         test_toml_file = THIS_DIR / "test.toml"
1908         fake_ctx = FakeContext()
1909         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1910         config = fake_ctx.default_map
1911         self.assertEqual(config["verbose"], "1")
1912         self.assertEqual(config["check"], "no")
1913         self.assertEqual(config["diff"], "y")
1914         self.assertEqual(config["color"], "True")
1915         self.assertEqual(config["line_length"], "79")
1916         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1917         self.assertEqual(config["exclude"], r"\.pyi?$")
1918         self.assertEqual(config["include"], r"\.py?$")
1919
1920     def test_find_project_root(self) -> None:
1921         with TemporaryDirectory() as workspace:
1922             root = Path(workspace)
1923             test_dir = root / "test"
1924             test_dir.mkdir()
1925
1926             src_dir = root / "src"
1927             src_dir.mkdir()
1928
1929             root_pyproject = root / "pyproject.toml"
1930             root_pyproject.touch()
1931             src_pyproject = src_dir / "pyproject.toml"
1932             src_pyproject.touch()
1933             src_python = src_dir / "foo.py"
1934             src_python.touch()
1935
1936             self.assertEqual(
1937                 black.find_project_root((src_dir, test_dir)), root.resolve()
1938             )
1939             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
1940             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
1941
1942
# The aiohttp-based test case can only be defined when blackd's optional
# dependencies imported successfully. Otherwise AioHTTPTestCase, web and
# unittest_run_loop are unbound names, and defining the class would raise
# NameError while importing this module, taking every other test down with it.
# Inside this guard a per-method skipUnless(has_blackd_deps) would be a no-op.
if has_blackd_deps:

    class BlackDTestCase(AioHTTPTestCase):
        """HTTP-level tests against the app returned by ``blackd.make_app()``."""

        async def get_application(self) -> web.Application:
            """Return the blackd application the aiohttp test client talks to."""
            return blackd.make_app()

        # TODO: remove these decorators once the below is released
        # https://github.com/aio-libs/aiohttp/pull/3727
        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_request_needs_formatting(self) -> None:
            response = await self.client.post("/", data=b"print('hello world')")
            self.assertEqual(response.status, 200)
            self.assertEqual(response.charset, "utf8")
            self.assertEqual(await response.read(), b'print("hello world")\n')

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_request_no_change(self) -> None:
            response = await self.client.post("/", data=b'print("hello world")\n')
            self.assertEqual(response.status, 204)
            self.assertEqual(await response.read(), b"")

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_request_syntax_error(self) -> None:
            response = await self.client.post("/", data=b"what even ( is")
            self.assertEqual(response.status, 400)
            content = await response.text()
            self.assertTrue(
                content.startswith("Cannot parse"),
                msg=f"Expected error to start with 'Cannot parse', got {content!r}",
            )

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_unsupported_version(self) -> None:
            response = await self.client.post(
                "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "2"}
            )
            self.assertEqual(response.status, 501)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_supported_version(self) -> None:
            response = await self.client.post(
                "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "1"}
            )
            self.assertEqual(response.status, 200)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_invalid_python_variant(self) -> None:
            async def check(header_value: str, expected_status: int = 400) -> None:
                response = await self.client.post(
                    "/",
                    data=b"what",
                    headers={blackd.PYTHON_VARIANT_HEADER: header_value},
                )
                self.assertEqual(response.status, expected_status)

            await check("lol")
            await check("ruby3.5")
            await check("pyi3.6")
            await check("py1.5")
            await check("2.8")
            await check("py2.8")
            await check("3.0")
            await check("pypy3.0")
            await check("jython3.4")

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_pyi(self) -> None:
            source, expected = read_data("stub.pyi")
            response = await self.client.post(
                "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
            )
            self.assertEqual(response.status, 200)
            self.assertEqual(await response.text(), expected)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_diff(self) -> None:
            # Diff headers carry timestamps; normalise them before comparing.
            diff_header = re.compile(
                r"(In|Out)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
            )

            source, _ = read_data("blackd_diff.py")
            expected, _ = read_data("blackd_diff.diff")

            response = await self.client.post(
                "/", data=source, headers={blackd.DIFF_HEADER: "true"}
            )
            self.assertEqual(response.status, 200)

            actual = await response.text()
            actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
            self.assertEqual(actual, expected)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_python_variant(self) -> None:
            code = (
                "def f(\n"
                "    and_has_a_bunch_of,\n"
                "    very_long_arguments_too,\n"
                "    and_lots_of_them_as_well_lol,\n"
                "    **and_very_long_keyword_arguments\n"
                "):\n"
                "    pass\n"
            )

            async def check(header_value: str, expected_status: int) -> None:
                response = await self.client.post(
                    "/", data=code, headers={blackd.PYTHON_VARIANT_HEADER: header_value}
                )
                self.assertEqual(
                    response.status, expected_status, msg=await response.text()
                )

            await check("3.6", 200)
            await check("py3.6", 200)
            await check("3.6,3.7", 200)
            await check("3.6,py3.7", 200)
            await check("py36,py37", 200)
            await check("36", 200)
            await check("3.6.4", 200)

            await check("2", 204)
            await check("2.7", 204)
            await check("py2.7", 204)
            await check("3.4", 204)
            await check("py3.4", 204)
            await check("py34,py36", 204)
            await check("34", 204)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_line_length(self) -> None:
            response = await self.client.post(
                "/",
                data=b'print("hello")\n',
                headers={blackd.LINE_LENGTH_HEADER: "7"},
            )
            self.assertEqual(response.status, 200)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_invalid_line_length(self) -> None:
            response = await self.client.post(
                "/",
                data=b'print("hello")\n',
                headers={blackd.LINE_LENGTH_HEADER: "NaN"},
            )
            self.assertEqual(response.status, 400)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_response_black_version_header(self) -> None:
            response = await self.client.post("/")
            self.assertIsNotNone(response.headers.get(blackd.BLACK_VERSION_HEADER))
2109
2110
# Black's own source split into lines (newlines kept). tracefunc() uses it to
# print the signature line of each black/__init__.py function as it is called.
with open(black.__file__, "r", encoding="utf-8") as _black_source:
    black_source_lines = list(_black_source)
2113
2114
def tracefunc(frame: types.FrameType, event: str, arg: Any) -> Callable:
    """Show function calls `from black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging. Only "call"
    events in black/__init__.py are printed, indented by stack depth, as
    ``<lineno>:<signature line>`` with any decorator lines skipped.
    """
    if event != "call":
        return tracefunc

    filename = frame.f_code.co_filename
    # Check the file first: line numbers from other files must not be used to
    # index black_source_lines (they may be out of range), and this skips the
    # expensive inspect.stack() call for every unrelated frame. Separators are
    # normalised so the check also works with Windows paths.
    if "black/__init__.py" not in filename.replace(os.sep, "/"):
        return tracefunc

    # NOTE(review): 19 looks like an empirically chosen baseline depth of the
    # test runner's own frames; TODO confirm before relying on exact indents.
    stack = len(inspect.stack()) - 19
    stack *= 2
    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1
    funcname = black_source_lines[func_sig_lineno].strip()
    # Skip decorator lines to reach the actual `def` line.
    while funcname.startswith("@"):
        func_sig_lineno += 1
        funcname = black_source_lines[func_sig_lineno].strip()
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc
2135
2136
if __name__ == "__main__":
    # This module uses a relative import (`.test_primer`), so it can only run as
    # part of the tests package, e.g. `python -m tests.test_black`. In that case
    # no top-level module named "test_black" exists; load tests from __main__.
    unittest.main()