]> git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

f89f1403abad851f349026443e84d526d2255cb3
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 import logging
4 from concurrent.futures import ThreadPoolExecutor
5 from contextlib import contextmanager
6 from dataclasses import replace
7 from functools import partial
8 import inspect
9 from io import BytesIO, TextIOWrapper
10 import os
11 from pathlib import Path
12 import regex as re
13 import sys
14 from tempfile import TemporaryDirectory
15 import types
16 from typing import (
17     Any,
18     BinaryIO,
19     Callable,
20     Dict,
21     Generator,
22     List,
23     Tuple,
24     Iterator,
25     TypeVar,
26 )
27 import unittest
28 from unittest.mock import patch, MagicMock
29
30 import click
31 from click import unstyle
32 from click.testing import CliRunner
33
34 import black
35 from black import Feature, TargetVersion
36
37 try:
38     import blackd
39     from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop
40     from aiohttp import web
41 except ImportError:
42     has_blackd_deps = False
43 else:
44     has_blackd_deps = True
45
46 from pathspec import PathSpec
47
48 # Import other test classes
49 from .test_primer import PrimerCLITests  # noqa: F401
50
51
# Formatting mode shared by most tests in this module; it turns on
# experimental string processing.
DEFAULT_MODE = black.FileMode(experimental_string_processing=True)
# ``ff`` formats a file in place and ``fs`` formats a string, both pre-bound to
# DEFAULT_MODE (``ff`` additionally with ``fast=True``). Because these are
# ``functools.partial`` objects, callers may override the bound keywords.
ff = partial(black.format_file_in_place, mode=DEFAULT_MODE, fast=True)
fs = partial(black.format_str, mode=DEFAULT_MODE)
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
PROJECT_ROOT = THIS_DIR.parent
# Substituted for the timestamped ``--diff`` headers so diff output can be
# compared against stored expectations.
DETERMINISTIC_HEADER = "[Deterministic header]"
# Marker text that read_data() deletes from every line it reads. It is built by
# concatenation so this module's own source never contains the marker verbatim:
# checkSourceFile() reads this very file through read_data() (see test_self).
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
# One ``--target-version=...`` flag per entry of black.PY36_VERSIONS.
PY36_ARGS = [
    f"--target-version={version.name.lower()}" for version in black.PY36_VERSIONS
]
T = TypeVar("T")
R = TypeVar("R")
65
66
def dump_to_stderr(*output: str) -> str:
    """Return ``output`` joined by newlines, framed by a leading and trailing newline.

    Tests patch ``black.dump_to_file`` with this function so debugging dumps
    come back as text instead of as the path of a written file.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
69
70
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    Read a test case from the ``data`` directory next to this module or, with
    ``data=False``, from the project root. ``name`` gets a ``.py`` suffix unless
    it already ends in ``.py``, ``.pyi``, ``.out`` or ``.diff``.

    Lines before the first line whose right-stripped text is ``# output`` form
    the input; lines after it form the expected output, and marker lines are
    dropped. If there are input lines but no output lines, the input doubles as
    the expected output. Both parts are stripped and end in a single newline.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    base_dir = THIS_DIR / "data" if data else PROJECT_ROOT
    with open(base_dir / name, "r", encoding="utf8") as test:
        lines = test.readlines()
    source_lines: List[str] = []
    output_lines: List[str] = []
    seen_marker = False
    for raw_line in lines:
        line = raw_line.replace(EMPTY_LINE, "")
        if line.rstrip() == "# output":
            seen_marker = True
        elif seen_marker:
            output_lines.append(line)
        else:
            source_lines.append(line)
    if source_lines and not output_lines:
        # No expected-output section: the input is taken to be pre-formatted.
        output_lines = list(source_lines)
    source = "".join(source_lines).strip() + "\n"
    expected = "".join(output_lines).strip() + "\n"
    return source, expected
92
93
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point ``black.CACHE_DIR`` at a throwaway directory for the ``with`` body.

    With ``exists=False`` the patched path is a ``new`` subdirectory of the
    temporary directory that has not been created. Yields the patched path.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
102
103
@contextmanager
def event_loop() -> Iterator[None]:
    """Run the ``with`` body with a brand-new asyncio event loop set as current.

    On exit the loop is closed and the current event loop is reset to ``None``
    so that a closed, unusable loop is not left installed for later code.
    """
    policy = asyncio.get_event_loop_policy()
    loop = policy.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        yield
    finally:
        loop.close()
        # Mirror asyncio.run(): uninstall the loop once it is closed instead of
        # leaving a closed loop behind as the thread's current event loop.
        asyncio.set_event_loop(None)
114
115
@contextmanager
def skip_if_exception(e: str) -> Iterator[None]:
    """Skip the running test if the ``with`` body raises an exception named ``e``.

    The match is on the exception's exact class name, so subclasses with a
    different name are not matched. Other exceptions propagate unchanged.

    Raises:
        unittest.SkipTest: when the body raised an exception whose class is
            named ``e``.
    """
    try:
        yield
    except Exception as exc:
        if exc.__class__.__name__ == e:
            # ``unittest.skip(...)`` merely builds a decorator; calling it here
            # would swallow the exception without skipping. Raise SkipTest.
            raise unittest.SkipTest(
                f"Encountered expected exception {exc}, skipping"
            ) from exc
        raise
125
126
class FakeContext(click.Context):
    """Stand-in click Context for calling helpers that expect one.

    ``click.Context.__init__`` is not called, so the only attribute set on an
    instance is an empty ``default_map``.
    """

    def __init__(self) -> None:
        self.default_map: Dict[str, Any] = dict()
132
133
class FakeParameter(click.Parameter):
    """Stand-in click Parameter for calling helpers that expect one."""

    def __init__(self) -> None:
        """Bypass ``click.Parameter.__init__``; no attributes are set."""
139
140
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    This is a hack that can be removed once we depend on Click 7.x"""

    def __init__(self) -> None:
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Like ``CliRunner.isolation`` but with stderr captured separately.

        When the body finishes, ``stdout_bytes`` and ``stderr_bytes`` hold what
        was written to the replaced stdout and stderr streams.
        """
        with super().isolation(*args, **kwargs) as output:
            # Set up outside the ``try`` so a failure here never reaches the
            # cleanup code, which assumes the replacement streams are in place.
            hold_stderr = sys.stderr
            sys.stderr = TextIOWrapper(self.stderrbuf, encoding=self.charset)
            try:
                yield output
            finally:
                try:
                    # TextIOWrapper keeps pending text in its own buffer; flush
                    # so every write reaches the byte buffers read below.
                    sys.stdout.flush()
                    sys.stderr.flush()
                    self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                    self.stderr_bytes = sys.stderr.buffer.getvalue()  # type: ignore
                finally:
                    sys.stderr = hold_stderr
164
165
166 class BlackTestCase(unittest.TestCase):
    # No limit on the length of assertion-failure diffs (unittest semantics of
    # ``maxDiff = None``).
    maxDiff = None
    # Private unittest.TestCase knob: string comparisons longer than this skip
    # the detailed diff. Raised here so large formatted files still get diffs
    # (NOTE(review): relies on a private unittest attribute).
    _diffThreshold = 2 ** 20
169
170     def assertFormatEqual(self, expected: str, actual: str) -> None:
171         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
172             bdv: black.DebugVisitor[Any]
173             black.out("Expected tree:", fg="green")
174             try:
175                 exp_node = black.lib2to3_parse(expected)
176                 bdv = black.DebugVisitor()
177                 list(bdv.visit(exp_node))
178             except Exception as ve:
179                 black.err(str(ve))
180             black.out("Actual tree:", fg="red")
181             try:
182                 exp_node = black.lib2to3_parse(actual)
183                 bdv = black.DebugVisitor()
184                 list(bdv.visit(exp_node))
185             except Exception as ve:
186                 black.err(str(ve))
187         self.assertMultiLineEqual(expected, actual)
188
189     def invokeBlack(
190         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
191     ) -> None:
192         runner = BlackRunner()
193         if ignore_config:
194             args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
195         result = runner.invoke(black.main, args)
196         self.assertEqual(
197             result.exit_code,
198             exit_code,
199             msg=(
200                 f"Failed with args: {args}\n"
201                 f"stdout: {runner.stdout_bytes.decode()!r}\n"
202                 f"stderr: {runner.stderr_bytes.decode()!r}\n"
203                 f"exception: {result.exception}"
204             ),
205         )
206
207     @patch("black.dump_to_file", dump_to_stderr)
208     def checkSourceFile(self, name: str, mode: black.FileMode = DEFAULT_MODE) -> None:
209         path = THIS_DIR.parent / name
210         source, expected = read_data(str(path), data=False)
211         actual = fs(source, mode=mode)
212         self.assertFormatEqual(expected, actual)
213         black.assert_equivalent(source, actual)
214         black.assert_stable(source, actual, mode)
215         self.assertFalse(ff(path))
216
217     @patch("black.dump_to_file", dump_to_stderr)
218     def test_empty(self) -> None:
219         source = expected = ""
220         actual = fs(source)
221         self.assertFormatEqual(expected, actual)
222         black.assert_equivalent(source, actual)
223         black.assert_stable(source, actual, DEFAULT_MODE)
224
225     def test_empty_ff(self) -> None:
226         expected = ""
227         tmp_file = Path(black.dump_to_file())
228         try:
229             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
230             with open(tmp_file, encoding="utf8") as f:
231                 actual = f.read()
232         finally:
233             os.unlink(tmp_file)
234         self.assertFormatEqual(expected, actual)
235
236     def test_self(self) -> None:
237         self.checkSourceFile("tests/test_black.py")
238
239     def test_black(self) -> None:
240         self.checkSourceFile("src/black/__init__.py")
241
242     def test_pygram(self) -> None:
243         self.checkSourceFile("src/blib2to3/pygram.py")
244
245     def test_pytree(self) -> None:
246         self.checkSourceFile("src/blib2to3/pytree.py")
247
248     def test_conv(self) -> None:
249         self.checkSourceFile("src/blib2to3/pgen2/conv.py")
250
251     def test_driver(self) -> None:
252         self.checkSourceFile("src/blib2to3/pgen2/driver.py")
253
254     def test_grammar(self) -> None:
255         self.checkSourceFile("src/blib2to3/pgen2/grammar.py")
256
257     def test_literals(self) -> None:
258         self.checkSourceFile("src/blib2to3/pgen2/literals.py")
259
260     def test_parse(self) -> None:
261         self.checkSourceFile("src/blib2to3/pgen2/parse.py")
262
263     def test_pgen(self) -> None:
264         self.checkSourceFile("src/blib2to3/pgen2/pgen.py")
265
266     def test_tokenize(self) -> None:
267         self.checkSourceFile("src/blib2to3/pgen2/tokenize.py")
268
269     def test_token(self) -> None:
270         self.checkSourceFile("src/blib2to3/pgen2/token.py")
271
272     def test_setup(self) -> None:
273         self.checkSourceFile("setup.py")
274
275     def test_piping(self) -> None:
276         source, expected = read_data("src/black/__init__", data=False)
277         result = BlackRunner().invoke(
278             black.main,
279             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
280             input=BytesIO(source.encode("utf8")),
281         )
282         self.assertEqual(result.exit_code, 0)
283         self.assertFormatEqual(expected, result.output)
284         black.assert_equivalent(source, result.output)
285         black.assert_stable(source, result.output, DEFAULT_MODE)
286
287     def test_piping_diff(self) -> None:
288         diff_header = re.compile(
289             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
290             r"\+\d\d\d\d"
291         )
292         source, _ = read_data("expression.py")
293         expected, _ = read_data("expression.diff")
294         config = THIS_DIR / "data" / "empty_pyproject.toml"
295         args = [
296             "-",
297             "--fast",
298             f"--line-length={black.DEFAULT_LINE_LENGTH}",
299             "--diff",
300             f"--config={config}",
301         ]
302         result = BlackRunner().invoke(
303             black.main, args, input=BytesIO(source.encode("utf8"))
304         )
305         self.assertEqual(result.exit_code, 0)
306         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
307         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
308         self.assertEqual(expected, actual)
309
310     def test_piping_diff_with_color(self) -> None:
311         source, _ = read_data("expression.py")
312         config = THIS_DIR / "data" / "empty_pyproject.toml"
313         args = [
314             "-",
315             "--fast",
316             f"--line-length={black.DEFAULT_LINE_LENGTH}",
317             "--diff",
318             "--color",
319             f"--config={config}",
320         ]
321         result = BlackRunner().invoke(
322             black.main, args, input=BytesIO(source.encode("utf8"))
323         )
324         actual = result.output
325         # Again, the contents are checked in a different test, so only look for colors.
326         self.assertIn("\033[1;37m", actual)
327         self.assertIn("\033[36m", actual)
328         self.assertIn("\033[32m", actual)
329         self.assertIn("\033[31m", actual)
330         self.assertIn("\033[0m", actual)
331
332     @patch("black.dump_to_file", dump_to_stderr)
333     def test_function(self) -> None:
334         source, expected = read_data("function")
335         actual = fs(source)
336         self.assertFormatEqual(expected, actual)
337         black.assert_equivalent(source, actual)
338         black.assert_stable(source, actual, DEFAULT_MODE)
339
340     @patch("black.dump_to_file", dump_to_stderr)
341     def test_function2(self) -> None:
342         source, expected = read_data("function2")
343         actual = fs(source)
344         self.assertFormatEqual(expected, actual)
345         black.assert_equivalent(source, actual)
346         black.assert_stable(source, actual, DEFAULT_MODE)
347
348     @patch("black.dump_to_file", dump_to_stderr)
349     def test_function_trailing_comma_wip(self) -> None:
350         source, expected = read_data("function_trailing_comma_wip")
351         # sys.settrace(tracefunc)
352         actual = fs(source)
353         # sys.settrace(None)
354         self.assertFormatEqual(expected, actual)
355         black.assert_equivalent(source, actual)
356         black.assert_stable(source, actual, black.FileMode())
357
358     @patch("black.dump_to_file", dump_to_stderr)
359     def test_function_trailing_comma(self) -> None:
360         source, expected = read_data("function_trailing_comma")
361         actual = fs(source)
362         self.assertFormatEqual(expected, actual)
363         black.assert_equivalent(source, actual)
364         black.assert_stable(source, actual, DEFAULT_MODE)
365
366     @patch("black.dump_to_file", dump_to_stderr)
367     def test_expression(self) -> None:
368         source, expected = read_data("expression")
369         actual = fs(source)
370         self.assertFormatEqual(expected, actual)
371         black.assert_equivalent(source, actual)
372         black.assert_stable(source, actual, DEFAULT_MODE)
373
374     @patch("black.dump_to_file", dump_to_stderr)
375     def test_pep_572(self) -> None:
376         source, expected = read_data("pep_572")
377         actual = fs(source)
378         self.assertFormatEqual(expected, actual)
379         black.assert_stable(source, actual, DEFAULT_MODE)
380         if sys.version_info >= (3, 8):
381             black.assert_equivalent(source, actual)
382
383     def test_pep_572_version_detection(self) -> None:
384         source, _ = read_data("pep_572")
385         root = black.lib2to3_parse(source)
386         features = black.get_features_used(root)
387         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
388         versions = black.detect_target_versions(root)
389         self.assertIn(black.TargetVersion.PY38, versions)
390
391     def test_expression_ff(self) -> None:
392         source, expected = read_data("expression")
393         tmp_file = Path(black.dump_to_file(source))
394         try:
395             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
396             with open(tmp_file, encoding="utf8") as f:
397                 actual = f.read()
398         finally:
399             os.unlink(tmp_file)
400         self.assertFormatEqual(expected, actual)
401         with patch("black.dump_to_file", dump_to_stderr):
402             black.assert_equivalent(source, actual)
403             black.assert_stable(source, actual, DEFAULT_MODE)
404
405     def test_expression_diff(self) -> None:
406         source, _ = read_data("expression.py")
407         expected, _ = read_data("expression.diff")
408         tmp_file = Path(black.dump_to_file(source))
409         diff_header = re.compile(
410             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
411             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
412         )
413         try:
414             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
415             self.assertEqual(result.exit_code, 0)
416         finally:
417             os.unlink(tmp_file)
418         actual = result.output
419         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
420         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
421         if expected != actual:
422             dump = black.dump_to_file(actual)
423             msg = (
424                 "Expected diff isn't equal to the actual. If you made changes to"
425                 " expression.py and this is an anticipated difference, overwrite"
426                 f" tests/data/expression.diff with {dump}"
427             )
428             self.assertEqual(expected, actual, msg)
429
430     def test_expression_diff_with_color(self) -> None:
431         source, _ = read_data("expression.py")
432         expected, _ = read_data("expression.diff")
433         tmp_file = Path(black.dump_to_file(source))
434         try:
435             result = BlackRunner().invoke(
436                 black.main, ["--diff", "--color", str(tmp_file)]
437             )
438         finally:
439             os.unlink(tmp_file)
440         actual = result.output
441         # We check the contents of the diff in `test_expression_diff`. All
442         # we need to check here is that color codes exist in the result.
443         self.assertIn("\033[1;37m", actual)
444         self.assertIn("\033[36m", actual)
445         self.assertIn("\033[32m", actual)
446         self.assertIn("\033[31m", actual)
447         self.assertIn("\033[0m", actual)
448
449     @patch("black.dump_to_file", dump_to_stderr)
450     def test_fstring(self) -> None:
451         source, expected = read_data("fstring")
452         actual = fs(source)
453         self.assertFormatEqual(expected, actual)
454         black.assert_equivalent(source, actual)
455         black.assert_stable(source, actual, DEFAULT_MODE)
456
457     @patch("black.dump_to_file", dump_to_stderr)
458     def test_pep_570(self) -> None:
459         source, expected = read_data("pep_570")
460         actual = fs(source)
461         self.assertFormatEqual(expected, actual)
462         black.assert_stable(source, actual, DEFAULT_MODE)
463         if sys.version_info >= (3, 8):
464             black.assert_equivalent(source, actual)
465
466     def test_detect_pos_only_arguments(self) -> None:
467         source, _ = read_data("pep_570")
468         root = black.lib2to3_parse(source)
469         features = black.get_features_used(root)
470         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
471         versions = black.detect_target_versions(root)
472         self.assertIn(black.TargetVersion.PY38, versions)
473
474     @patch("black.dump_to_file", dump_to_stderr)
475     def test_string_quotes(self) -> None:
476         source, expected = read_data("string_quotes")
477         actual = fs(source)
478         self.assertFormatEqual(expected, actual)
479         black.assert_equivalent(source, actual)
480         black.assert_stable(source, actual, DEFAULT_MODE)
481         mode = replace(DEFAULT_MODE, string_normalization=False)
482         not_normalized = fs(source, mode=mode)
483         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
484         black.assert_equivalent(source, not_normalized)
485         black.assert_stable(source, not_normalized, mode=mode)
486
487     @patch("black.dump_to_file", dump_to_stderr)
488     def test_docstring(self) -> None:
489         source, expected = read_data("docstring")
490         actual = fs(source)
491         self.assertFormatEqual(expected, actual)
492         black.assert_equivalent(source, actual)
493         black.assert_stable(source, actual, DEFAULT_MODE)
494
495     def test_long_strings(self) -> None:
496         """Tests for splitting long strings."""
497         source, expected = read_data("long_strings")
498         actual = fs(source)
499         self.assertFormatEqual(expected, actual)
500         black.assert_equivalent(source, actual)
501         black.assert_stable(source, actual, DEFAULT_MODE)
502
503     def test_long_strings_flag_disabled(self) -> None:
504         """Tests for turning off the string processing logic."""
505         source, expected = read_data("long_strings_flag_disabled")
506         mode = replace(DEFAULT_MODE, experimental_string_processing=False)
507         actual = fs(source, mode=mode)
508         self.assertFormatEqual(expected, actual)
509         black.assert_stable(expected, actual, mode)
510
511     @patch("black.dump_to_file", dump_to_stderr)
512     def test_long_strings__edge_case(self) -> None:
513         """Edge-case tests for splitting long strings."""
514         source, expected = read_data("long_strings__edge_case")
515         actual = fs(source)
516         self.assertFormatEqual(expected, actual)
517         black.assert_equivalent(source, actual)
518         black.assert_stable(source, actual, DEFAULT_MODE)
519
520     @patch("black.dump_to_file", dump_to_stderr)
521     def test_long_strings__regression(self) -> None:
522         """Regression tests for splitting long strings."""
523         source, expected = read_data("long_strings__regression")
524         actual = fs(source)
525         self.assertFormatEqual(expected, actual)
526         black.assert_equivalent(source, actual)
527         black.assert_stable(source, actual, DEFAULT_MODE)
528
529     @patch("black.dump_to_file", dump_to_stderr)
530     def test_slices(self) -> None:
531         source, expected = read_data("slices")
532         actual = fs(source)
533         self.assertFormatEqual(expected, actual)
534         black.assert_equivalent(source, actual)
535         black.assert_stable(source, actual, DEFAULT_MODE)
536
537     @patch("black.dump_to_file", dump_to_stderr)
538     def test_percent_precedence(self) -> None:
539         source, expected = read_data("percent_precedence")
540         actual = fs(source)
541         self.assertFormatEqual(expected, actual)
542         black.assert_equivalent(source, actual)
543         black.assert_stable(source, actual, DEFAULT_MODE)
544
545     @patch("black.dump_to_file", dump_to_stderr)
546     def test_comments(self) -> None:
547         source, expected = read_data("comments")
548         actual = fs(source)
549         self.assertFormatEqual(expected, actual)
550         black.assert_equivalent(source, actual)
551         black.assert_stable(source, actual, DEFAULT_MODE)
552
553     @patch("black.dump_to_file", dump_to_stderr)
554     def test_comments2(self) -> None:
555         source, expected = read_data("comments2")
556         actual = fs(source)
557         self.assertFormatEqual(expected, actual)
558         black.assert_equivalent(source, actual)
559         black.assert_stable(source, actual, DEFAULT_MODE)
560
561     @patch("black.dump_to_file", dump_to_stderr)
562     def test_comments3(self) -> None:
563         source, expected = read_data("comments3")
564         actual = fs(source)
565         self.assertFormatEqual(expected, actual)
566         black.assert_equivalent(source, actual)
567         black.assert_stable(source, actual, DEFAULT_MODE)
568
569     @patch("black.dump_to_file", dump_to_stderr)
570     def test_comments4(self) -> None:
571         source, expected = read_data("comments4")
572         actual = fs(source)
573         self.assertFormatEqual(expected, actual)
574         black.assert_equivalent(source, actual)
575         black.assert_stable(source, actual, DEFAULT_MODE)
576
577     @patch("black.dump_to_file", dump_to_stderr)
578     def test_comments5(self) -> None:
579         source, expected = read_data("comments5")
580         actual = fs(source)
581         self.assertFormatEqual(expected, actual)
582         black.assert_equivalent(source, actual)
583         black.assert_stable(source, actual, DEFAULT_MODE)
584
585     @patch("black.dump_to_file", dump_to_stderr)
586     def test_comments6(self) -> None:
587         source, expected = read_data("comments6")
588         actual = fs(source)
589         self.assertFormatEqual(expected, actual)
590         black.assert_equivalent(source, actual)
591         black.assert_stable(source, actual, DEFAULT_MODE)
592
593     @patch("black.dump_to_file", dump_to_stderr)
594     def test_comments7(self) -> None:
595         source, expected = read_data("comments7")
596         actual = fs(source)
597         self.assertFormatEqual(expected, actual)
598         black.assert_equivalent(source, actual)
599         black.assert_stable(source, actual, DEFAULT_MODE)
600
601     @patch("black.dump_to_file", dump_to_stderr)
602     def test_comment_after_escaped_newline(self) -> None:
603         source, expected = read_data("comment_after_escaped_newline")
604         actual = fs(source)
605         self.assertFormatEqual(expected, actual)
606         black.assert_equivalent(source, actual)
607         black.assert_stable(source, actual, DEFAULT_MODE)
608
609     @patch("black.dump_to_file", dump_to_stderr)
610     def test_cantfit(self) -> None:
611         source, expected = read_data("cantfit")
612         actual = fs(source)
613         self.assertFormatEqual(expected, actual)
614         black.assert_equivalent(source, actual)
615         black.assert_stable(source, actual, DEFAULT_MODE)
616
617     @patch("black.dump_to_file", dump_to_stderr)
618     def test_import_spacing(self) -> None:
619         source, expected = read_data("import_spacing")
620         actual = fs(source)
621         self.assertFormatEqual(expected, actual)
622         black.assert_equivalent(source, actual)
623         black.assert_stable(source, actual, DEFAULT_MODE)
624
625     @patch("black.dump_to_file", dump_to_stderr)
626     def test_composition(self) -> None:
627         source, expected = read_data("composition")
628         actual = fs(source)
629         self.assertFormatEqual(expected, actual)
630         black.assert_equivalent(source, actual)
631         black.assert_stable(source, actual, DEFAULT_MODE)
632
633     @patch("black.dump_to_file", dump_to_stderr)
634     def test_empty_lines(self) -> None:
635         source, expected = read_data("empty_lines")
636         actual = fs(source)
637         self.assertFormatEqual(expected, actual)
638         black.assert_equivalent(source, actual)
639         black.assert_stable(source, actual, DEFAULT_MODE)
640
641     @patch("black.dump_to_file", dump_to_stderr)
642     def test_remove_parens(self) -> None:
643         source, expected = read_data("remove_parens")
644         actual = fs(source)
645         self.assertFormatEqual(expected, actual)
646         black.assert_equivalent(source, actual)
647         black.assert_stable(source, actual, DEFAULT_MODE)
648
649     @patch("black.dump_to_file", dump_to_stderr)
650     def test_string_prefixes(self) -> None:
651         source, expected = read_data("string_prefixes")
652         actual = fs(source)
653         self.assertFormatEqual(expected, actual)
654         black.assert_equivalent(source, actual)
655         black.assert_stable(source, actual, DEFAULT_MODE)
656
657     @patch("black.dump_to_file", dump_to_stderr)
658     def test_numeric_literals(self) -> None:
659         source, expected = read_data("numeric_literals")
660         mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
661         actual = fs(source, mode=mode)
662         self.assertFormatEqual(expected, actual)
663         black.assert_equivalent(source, actual)
664         black.assert_stable(source, actual, mode)
665
666     @patch("black.dump_to_file", dump_to_stderr)
667     def test_numeric_literals_ignoring_underscores(self) -> None:
668         source, expected = read_data("numeric_literals_skip_underscores")
669         mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
670         actual = fs(source, mode=mode)
671         self.assertFormatEqual(expected, actual)
672         black.assert_equivalent(source, actual)
673         black.assert_stable(source, actual, mode)
674
675     @patch("black.dump_to_file", dump_to_stderr)
676     def test_numeric_literals_py2(self) -> None:
677         source, expected = read_data("numeric_literals_py2")
678         actual = fs(source)
679         self.assertFormatEqual(expected, actual)
680         black.assert_stable(source, actual, DEFAULT_MODE)
681
682     @patch("black.dump_to_file", dump_to_stderr)
683     def test_python2(self) -> None:
684         source, expected = read_data("python2")
685         actual = fs(source)
686         self.assertFormatEqual(expected, actual)
687         black.assert_equivalent(source, actual)
688         black.assert_stable(source, actual, DEFAULT_MODE)
689
690     @patch("black.dump_to_file", dump_to_stderr)
691     def test_python2_print_function(self) -> None:
692         source, expected = read_data("python2_print_function")
693         mode = replace(DEFAULT_MODE, target_versions={TargetVersion.PY27})
694         actual = fs(source, mode=mode)
695         self.assertFormatEqual(expected, actual)
696         black.assert_equivalent(source, actual)
697         black.assert_stable(source, actual, mode)
698
699     @patch("black.dump_to_file", dump_to_stderr)
700     def test_python2_unicode_literals(self) -> None:
701         source, expected = read_data("python2_unicode_literals")
702         actual = fs(source)
703         self.assertFormatEqual(expected, actual)
704         black.assert_equivalent(source, actual)
705         black.assert_stable(source, actual, DEFAULT_MODE)
706
707     @patch("black.dump_to_file", dump_to_stderr)
708     def test_stub(self) -> None:
709         mode = replace(DEFAULT_MODE, is_pyi=True)
710         source, expected = read_data("stub.pyi")
711         actual = fs(source, mode=mode)
712         self.assertFormatEqual(expected, actual)
713         black.assert_stable(source, actual, mode)
714
715     @patch("black.dump_to_file", dump_to_stderr)
716     def test_async_as_identifier(self) -> None:
717         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
718         source, expected = read_data("async_as_identifier")
719         actual = fs(source)
720         self.assertFormatEqual(expected, actual)
721         major, minor = sys.version_info[:2]
722         if major < 3 or (major <= 3 and minor < 7):
723             black.assert_equivalent(source, actual)
724         black.assert_stable(source, actual, DEFAULT_MODE)
725         # ensure black can parse this when the target is 3.6
726         self.invokeBlack([str(source_path), "--target-version", "py36"])
727         # but not on 3.7, because async/await is no longer an identifier
728         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
729
730     @patch("black.dump_to_file", dump_to_stderr)
731     def test_python37(self) -> None:
732         source_path = (THIS_DIR / "data" / "python37.py").resolve()
733         source, expected = read_data("python37")
734         actual = fs(source)
735         self.assertFormatEqual(expected, actual)
736         major, minor = sys.version_info[:2]
737         if major > 3 or (major == 3 and minor >= 7):
738             black.assert_equivalent(source, actual)
739         black.assert_stable(source, actual, DEFAULT_MODE)
740         # ensure black can parse this when the target is 3.7
741         self.invokeBlack([str(source_path), "--target-version", "py37"])
742         # but not on 3.6, because we use async as a reserved keyword
743         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
744
745     @patch("black.dump_to_file", dump_to_stderr)
746     def test_python38(self) -> None:
747         source, expected = read_data("python38")
748         actual = fs(source)
749         self.assertFormatEqual(expected, actual)
750         major, minor = sys.version_info[:2]
751         if major > 3 or (major == 3 and minor >= 8):
752             black.assert_equivalent(source, actual)
753         black.assert_stable(source, actual, DEFAULT_MODE)
754
755     @patch("black.dump_to_file", dump_to_stderr)
756     def test_fmtonoff(self) -> None:
757         source, expected = read_data("fmtonoff")
758         actual = fs(source)
759         self.assertFormatEqual(expected, actual)
760         black.assert_equivalent(source, actual)
761         black.assert_stable(source, actual, DEFAULT_MODE)
762
763     @patch("black.dump_to_file", dump_to_stderr)
764     def test_fmtonoff2(self) -> None:
765         source, expected = read_data("fmtonoff2")
766         actual = fs(source)
767         self.assertFormatEqual(expected, actual)
768         black.assert_equivalent(source, actual)
769         black.assert_stable(source, actual, DEFAULT_MODE)
770
771     @patch("black.dump_to_file", dump_to_stderr)
772     def test_fmtonoff3(self) -> None:
773         source, expected = read_data("fmtonoff3")
774         actual = fs(source)
775         self.assertFormatEqual(expected, actual)
776         black.assert_equivalent(source, actual)
777         black.assert_stable(source, actual, DEFAULT_MODE)
778
779     @patch("black.dump_to_file", dump_to_stderr)
780     def test_fmtonoff4(self) -> None:
781         source, expected = read_data("fmtonoff4")
782         actual = fs(source)
783         self.assertFormatEqual(expected, actual)
784         black.assert_equivalent(source, actual)
785         black.assert_stable(source, actual, DEFAULT_MODE)
786
787     @patch("black.dump_to_file", dump_to_stderr)
788     def test_remove_empty_parentheses_after_class(self) -> None:
789         source, expected = read_data("class_blank_parentheses")
790         actual = fs(source)
791         self.assertFormatEqual(expected, actual)
792         black.assert_equivalent(source, actual)
793         black.assert_stable(source, actual, DEFAULT_MODE)
794
795     @patch("black.dump_to_file", dump_to_stderr)
796     def test_new_line_between_class_and_code(self) -> None:
797         source, expected = read_data("class_methods_new_line")
798         actual = fs(source)
799         self.assertFormatEqual(expected, actual)
800         black.assert_equivalent(source, actual)
801         black.assert_stable(source, actual, DEFAULT_MODE)
802
803     @patch("black.dump_to_file", dump_to_stderr)
804     def test_bracket_match(self) -> None:
805         source, expected = read_data("bracketmatch")
806         actual = fs(source)
807         self.assertFormatEqual(expected, actual)
808         black.assert_equivalent(source, actual)
809         black.assert_stable(source, actual, DEFAULT_MODE)
810
811     @patch("black.dump_to_file", dump_to_stderr)
812     def test_tuple_assign(self) -> None:
813         source, expected = read_data("tupleassign")
814         actual = fs(source)
815         self.assertFormatEqual(expected, actual)
816         black.assert_equivalent(source, actual)
817         black.assert_stable(source, actual, DEFAULT_MODE)
818
819     @patch("black.dump_to_file", dump_to_stderr)
820     def test_beginning_backslash(self) -> None:
821         source, expected = read_data("beginning_backslash")
822         actual = fs(source)
823         self.assertFormatEqual(expected, actual)
824         black.assert_equivalent(source, actual)
825         black.assert_stable(source, actual, DEFAULT_MODE)
826
827     def test_tab_comment_indentation(self) -> None:
828         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
829         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
830         self.assertFormatEqual(contents_spc, fs(contents_spc))
831         self.assertFormatEqual(contents_spc, fs(contents_tab))
832
833         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
834         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
835         self.assertFormatEqual(contents_spc, fs(contents_spc))
836         self.assertFormatEqual(contents_spc, fs(contents_tab))
837
838         # mixed tabs and spaces (valid Python 2 code)
839         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
840         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
841         self.assertFormatEqual(contents_spc, fs(contents_spc))
842         self.assertFormatEqual(contents_spc, fs(contents_tab))
843
844         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
845         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
846         self.assertFormatEqual(contents_spc, fs(contents_spc))
847         self.assertFormatEqual(contents_spc, fs(contents_tab))
848
    def test_report_verbose(self) -> None:
        """Drive a verbose ``black.Report`` through every kind of outcome.

        ``black.out`` and ``black.err`` are patched to capture messages. After
        each call the test checks what was emitted, the summary (``str(report)``
        with styling removed) and ``report.return_code``.
        """
        report = black.Report(verbose=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            # Verbose mode emits a line for unchanged, reformatted and cached files.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With ``check`` set, the same state (one file reformatted) returns 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures are emitted via black.err and make the return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Counts are per call, not per path: f3 (already counted as unchanged)
            # is now also counted as reformatted.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are reported in verbose mode but leave the summary as is.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both ``check`` and ``diff`` switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
950
    def test_report_quiet(self) -> None:
        """Drive a quiet ``black.Report``: only failures produce any output.

        ``black.out`` and ``black.err`` are patched to capture messages. After
        each call the test checks that nothing reaches ``black.out``, what
        reaches ``black.err``, the summary (``str(report)`` with styling removed)
        and ``report.return_code``.
        """
        report = black.Report(quiet=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            # Unchanged, reformatted and cached files emit nothing in quiet mode.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With ``check`` set, the same state (one file reformatted) returns 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures still go through black.err and make the return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Counts are per call, not per path: f3 is counted again as reformatted.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths emit nothing and leave the summary unchanged.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both ``check`` and ``diff`` switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
1044
    def test_report_normal(self) -> None:
        """Drive a default ``black.Report``: only reformats and failures are emitted.

        ``black.out`` and ``black.err`` are patched to capture messages. After
        each call the test checks what was emitted, the summary (``str(report)``
        with styling removed) and ``report.return_code``.
        """
        report = black.Report()
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            # Unchanged files are counted but not reported.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # Reformatted files are reported.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # Cached files emit nothing: the last output line is still f2's.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With ``check`` set, the same state (one file reformatted) returns 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures are emitted via black.err and make the return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Counts are per call, not per path: f3 is counted again as reformatted.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths emit nothing outside verbose mode; summary unchanged.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both ``check`` and ``diff`` switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
1141
1142     def test_lib2to3_parse(self) -> None:
1143         with self.assertRaises(black.InvalidInput):
1144             black.lib2to3_parse("invalid syntax")
1145
1146         straddling = "x + y"
1147         black.lib2to3_parse(straddling)
1148         black.lib2to3_parse(straddling, {TargetVersion.PY27})
1149         black.lib2to3_parse(straddling, {TargetVersion.PY36})
1150         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
1151
1152         py2_only = "print x"
1153         black.lib2to3_parse(py2_only)
1154         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
1155         with self.assertRaises(black.InvalidInput):
1156             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
1157         with self.assertRaises(black.InvalidInput):
1158             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
1159
1160         py3_only = "exec(x, end=y)"
1161         black.lib2to3_parse(py3_only)
1162         with self.assertRaises(black.InvalidInput):
1163             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
1164         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
1165         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
1166
1167     def test_get_features_used(self) -> None:
1168         node = black.lib2to3_parse("def f(*, arg): ...\n")
1169         self.assertEqual(black.get_features_used(node), set())
1170         node = black.lib2to3_parse("def f(*, arg,): ...\n")
1171         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
1172         node = black.lib2to3_parse("f(*arg,)\n")
1173         self.assertEqual(
1174             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
1175         )
1176         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
1177         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
1178         node = black.lib2to3_parse("123_456\n")
1179         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
1180         node = black.lib2to3_parse("123456\n")
1181         self.assertEqual(black.get_features_used(node), set())
1182         source, expected = read_data("function")
1183         node = black.lib2to3_parse(source)
1184         expected_features = {
1185             Feature.TRAILING_COMMA_IN_CALL,
1186             Feature.TRAILING_COMMA_IN_DEF,
1187             Feature.F_STRINGS,
1188         }
1189         self.assertEqual(black.get_features_used(node), expected_features)
1190         node = black.lib2to3_parse(expected)
1191         self.assertEqual(black.get_features_used(node), expected_features)
1192         source, expected = read_data("expression")
1193         node = black.lib2to3_parse(source)
1194         self.assertEqual(black.get_features_used(node), set())
1195         node = black.lib2to3_parse(expected)
1196         self.assertEqual(black.get_features_used(node), set())
1197
1198     def test_get_future_imports(self) -> None:
1199         node = black.lib2to3_parse("\n")
1200         self.assertEqual(set(), black.get_future_imports(node))
1201         node = black.lib2to3_parse("from __future__ import black\n")
1202         self.assertEqual({"black"}, black.get_future_imports(node))
1203         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
1204         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
1205         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
1206         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
1207         node = black.lib2to3_parse(
1208             "from __future__ import multiple\nfrom __future__ import imports\n"
1209         )
1210         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
1211         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
1212         self.assertEqual({"black"}, black.get_future_imports(node))
1213         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
1214         self.assertEqual({"black"}, black.get_future_imports(node))
1215         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
1216         self.assertEqual(set(), black.get_future_imports(node))
1217         node = black.lib2to3_parse("from some.module import black\n")
1218         self.assertEqual(set(), black.get_future_imports(node))
1219         node = black.lib2to3_parse(
1220             "from __future__ import unicode_literals as _unicode_literals"
1221         )
1222         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
1223         node = black.lib2to3_parse(
1224             "from __future__ import unicode_literals as _lol, print"
1225         )
1226         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
1227
1228     def test_debug_visitor(self) -> None:
1229         source, _ = read_data("debug_visitor.py")
1230         expected, _ = read_data("debug_visitor.out")
1231         out_lines = []
1232         err_lines = []
1233
1234         def out(msg: str, **kwargs: Any) -> None:
1235             out_lines.append(msg)
1236
1237         def err(msg: str, **kwargs: Any) -> None:
1238             err_lines.append(msg)
1239
1240         with patch("black.out", out), patch("black.err", err):
1241             black.DebugVisitor.show(source)
1242         actual = "\n".join(out_lines) + "\n"
1243         log_name = ""
1244         if expected != actual:
1245             log_name = black.dump_to_file(*out_lines)
1246         self.assertEqual(
1247             expected,
1248             actual,
1249             f"AST print out is different. Actual version dumped to {log_name}",
1250         )
1251
1252     def test_format_file_contents(self) -> None:
1253         empty = ""
1254         mode = DEFAULT_MODE
1255         with self.assertRaises(black.NothingChanged):
1256             black.format_file_contents(empty, mode=mode, fast=False)
1257         just_nl = "\n"
1258         with self.assertRaises(black.NothingChanged):
1259             black.format_file_contents(just_nl, mode=mode, fast=False)
1260         same = "j = [1, 2, 3]\n"
1261         with self.assertRaises(black.NothingChanged):
1262             black.format_file_contents(same, mode=mode, fast=False)
1263         different = "j = [1,2,3]"
1264         expected = same
1265         actual = black.format_file_contents(different, mode=mode, fast=False)
1266         self.assertEqual(expected, actual)
1267         invalid = "return if you can"
1268         with self.assertRaises(black.InvalidInput) as e:
1269             black.format_file_contents(invalid, mode=mode, fast=False)
1270         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1271
1272     def test_endmarker(self) -> None:
1273         n = black.lib2to3_parse("\n")
1274         self.assertEqual(n.type, black.syms.file_input)
1275         self.assertEqual(len(n.children), 1)
1276         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1277
1278     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1279     def test_assertFormatEqual(self) -> None:
1280         out_lines = []
1281         err_lines = []
1282
1283         def out(msg: str, **kwargs: Any) -> None:
1284             out_lines.append(msg)
1285
1286         def err(msg: str, **kwargs: Any) -> None:
1287             err_lines.append(msg)
1288
1289         with patch("black.out", out), patch("black.err", err):
1290             with self.assertRaises(AssertionError):
1291                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1292
1293         out_str = "".join(out_lines)
1294         self.assertTrue("Expected tree:" in out_str)
1295         self.assertTrue("Actual tree:" in out_str)
1296         self.assertEqual("".join(err_lines), "")
1297
1298     def test_cache_broken_file(self) -> None:
1299         mode = DEFAULT_MODE
1300         with cache_dir() as workspace:
1301             cache_file = black.get_cache_file(mode)
1302             with cache_file.open("w") as fobj:
1303                 fobj.write("this is not a pickle")
1304             self.assertEqual(black.read_cache(mode), {})
1305             src = (workspace / "test.py").resolve()
1306             with src.open("w") as fobj:
1307                 fobj.write("print('hello')")
1308             self.invokeBlack([str(src)])
1309             cache = black.read_cache(mode)
1310             self.assertIn(src, cache)
1311
1312     def test_cache_single_file_already_cached(self) -> None:
1313         mode = DEFAULT_MODE
1314         with cache_dir() as workspace:
1315             src = (workspace / "test.py").resolve()
1316             with src.open("w") as fobj:
1317                 fobj.write("print('hello')")
1318             black.write_cache({}, [src], mode)
1319             self.invokeBlack([str(src)])
1320             with src.open("r") as fobj:
1321                 self.assertEqual(fobj.read(), "print('hello')")
1322
1323     @event_loop()
1324     def test_cache_multiple_files(self) -> None:
1325         mode = DEFAULT_MODE
1326         with cache_dir() as workspace, patch(
1327             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1328         ):
1329             one = (workspace / "one.py").resolve()
1330             with one.open("w") as fobj:
1331                 fobj.write("print('hello')")
1332             two = (workspace / "two.py").resolve()
1333             with two.open("w") as fobj:
1334                 fobj.write("print('hello')")
1335             black.write_cache({}, [one], mode)
1336             self.invokeBlack([str(workspace)])
1337             with one.open("r") as fobj:
1338                 self.assertEqual(fobj.read(), "print('hello')")
1339             with two.open("r") as fobj:
1340                 self.assertEqual(fobj.read(), 'print("hello")\n')
1341             cache = black.read_cache(mode)
1342             self.assertIn(one, cache)
1343             self.assertIn(two, cache)
1344
1345     def test_no_cache_when_writeback_diff(self) -> None:
1346         mode = DEFAULT_MODE
1347         with cache_dir() as workspace:
1348             src = (workspace / "test.py").resolve()
1349             with src.open("w") as fobj:
1350                 fobj.write("print('hello')")
1351             self.invokeBlack([str(src), "--diff"])
1352             cache_file = black.get_cache_file(mode)
1353             self.assertFalse(cache_file.exists())
1354
1355     def test_no_cache_when_stdin(self) -> None:
1356         mode = DEFAULT_MODE
1357         with cache_dir():
1358             result = CliRunner().invoke(
1359                 black.main, ["-"], input=BytesIO(b"print('hello')")
1360             )
1361             self.assertEqual(result.exit_code, 0)
1362             cache_file = black.get_cache_file(mode)
1363             self.assertFalse(cache_file.exists())
1364
1365     def test_read_cache_no_cachefile(self) -> None:
1366         mode = DEFAULT_MODE
1367         with cache_dir():
1368             self.assertEqual(black.read_cache(mode), {})
1369
1370     def test_write_cache_read_cache(self) -> None:
1371         mode = DEFAULT_MODE
1372         with cache_dir() as workspace:
1373             src = (workspace / "test.py").resolve()
1374             src.touch()
1375             black.write_cache({}, [src], mode)
1376             cache = black.read_cache(mode)
1377             self.assertIn(src, cache)
1378             self.assertEqual(cache[src], black.get_cache_info(src))
1379
1380     def test_filter_cached(self) -> None:
1381         with TemporaryDirectory() as workspace:
1382             path = Path(workspace)
1383             uncached = (path / "uncached").resolve()
1384             cached = (path / "cached").resolve()
1385             cached_but_changed = (path / "changed").resolve()
1386             uncached.touch()
1387             cached.touch()
1388             cached_but_changed.touch()
1389             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
1390             todo, done = black.filter_cached(
1391                 cache, {uncached, cached, cached_but_changed}
1392             )
1393             self.assertEqual(todo, {uncached, cached_but_changed})
1394             self.assertEqual(done, {cached})
1395
1396     def test_write_cache_creates_directory_if_needed(self) -> None:
1397         mode = DEFAULT_MODE
1398         with cache_dir(exists=False) as workspace:
1399             self.assertFalse(workspace.exists())
1400             black.write_cache({}, [], mode)
1401             self.assertTrue(workspace.exists())
1402
1403     @event_loop()
1404     def test_failed_formatting_does_not_get_cached(self) -> None:
1405         mode = DEFAULT_MODE
1406         with cache_dir() as workspace, patch(
1407             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1408         ):
1409             failing = (workspace / "failing.py").resolve()
1410             with failing.open("w") as fobj:
1411                 fobj.write("not actually python")
1412             clean = (workspace / "clean.py").resolve()
1413             with clean.open("w") as fobj:
1414                 fobj.write('print("hello")\n')
1415             self.invokeBlack([str(workspace)], exit_code=123)
1416             cache = black.read_cache(mode)
1417             self.assertNotIn(failing, cache)
1418             self.assertIn(clean, cache)
1419
1420     def test_write_cache_write_fail(self) -> None:
1421         mode = DEFAULT_MODE
1422         with cache_dir(), patch.object(Path, "open") as mock:
1423             mock.side_effect = OSError
1424             black.write_cache({}, [], mode)
1425
1426     @event_loop()
1427     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
1428     def test_works_in_mono_process_only_environment(self) -> None:
1429         with cache_dir() as workspace:
1430             for f in [
1431                 (workspace / "one.py").resolve(),
1432                 (workspace / "two.py").resolve(),
1433             ]:
1434                 f.write_text('print("hello")\n')
1435             self.invokeBlack([str(workspace)])
1436
1437     @event_loop()
1438     def test_check_diff_use_together(self) -> None:
1439         with cache_dir():
1440             # Files which will be reformatted.
1441             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1442             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1443             # Files which will not be reformatted.
1444             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1445             self.invokeBlack([str(src2), "--diff", "--check"])
1446             # Multi file command.
1447             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1448
1449     def test_no_files(self) -> None:
1450         with cache_dir():
1451             # Without an argument, black exits with error code 0.
1452             self.invokeBlack([])
1453
1454     def test_broken_symlink(self) -> None:
1455         with cache_dir() as workspace:
1456             symlink = workspace / "broken_link.py"
1457             try:
1458                 symlink.symlink_to("nonexistent.py")
1459             except OSError as e:
1460                 self.skipTest(f"Can't create symlinks: {e}")
1461             self.invokeBlack([str(workspace.resolve())])
1462
1463     def test_read_cache_line_lengths(self) -> None:
1464         mode = DEFAULT_MODE
1465         short_mode = replace(DEFAULT_MODE, line_length=1)
1466         with cache_dir() as workspace:
1467             path = (workspace / "file.py").resolve()
1468             path.touch()
1469             black.write_cache({}, [path], mode)
1470             one = black.read_cache(mode)
1471             self.assertIn(path, one)
1472             two = black.read_cache(short_mode)
1473             self.assertNotIn(path, two)
1474
1475     def test_tricky_unicode_symbols(self) -> None:
1476         source, expected = read_data("tricky_unicode_symbols")
1477         actual = fs(source)
1478         self.assertFormatEqual(expected, actual)
1479         black.assert_equivalent(source, actual)
1480         black.assert_stable(source, actual, DEFAULT_MODE)
1481
1482     def test_single_file_force_pyi(self) -> None:
1483         reg_mode = DEFAULT_MODE
1484         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1485         contents, expected = read_data("force_pyi")
1486         with cache_dir() as workspace:
1487             path = (workspace / "file.py").resolve()
1488             with open(path, "w") as fh:
1489                 fh.write(contents)
1490             self.invokeBlack([str(path), "--pyi"])
1491             with open(path, "r") as fh:
1492                 actual = fh.read()
1493             # verify cache with --pyi is separate
1494             pyi_cache = black.read_cache(pyi_mode)
1495             self.assertIn(path, pyi_cache)
1496             normal_cache = black.read_cache(reg_mode)
1497             self.assertNotIn(path, normal_cache)
1498         self.assertEqual(actual, expected)
1499
1500     @event_loop()
1501     def test_multi_file_force_pyi(self) -> None:
1502         reg_mode = DEFAULT_MODE
1503         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1504         contents, expected = read_data("force_pyi")
1505         with cache_dir() as workspace:
1506             paths = [
1507                 (workspace / "file1.py").resolve(),
1508                 (workspace / "file2.py").resolve(),
1509             ]
1510             for path in paths:
1511                 with open(path, "w") as fh:
1512                     fh.write(contents)
1513             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1514             for path in paths:
1515                 with open(path, "r") as fh:
1516                     actual = fh.read()
1517                 self.assertEqual(actual, expected)
1518             # verify cache with --pyi is separate
1519             pyi_cache = black.read_cache(pyi_mode)
1520             normal_cache = black.read_cache(reg_mode)
1521             for path in paths:
1522                 self.assertIn(path, pyi_cache)
1523                 self.assertNotIn(path, normal_cache)
1524
1525     def test_pipe_force_pyi(self) -> None:
1526         source, expected = read_data("force_pyi")
1527         result = CliRunner().invoke(
1528             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1529         )
1530         self.assertEqual(result.exit_code, 0)
1531         actual = result.output
1532         self.assertFormatEqual(actual, expected)
1533
1534     def test_single_file_force_py36(self) -> None:
1535         reg_mode = DEFAULT_MODE
1536         py36_mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
1537         source, expected = read_data("force_py36")
1538         with cache_dir() as workspace:
1539             path = (workspace / "file.py").resolve()
1540             with open(path, "w") as fh:
1541                 fh.write(source)
1542             self.invokeBlack([str(path), *PY36_ARGS])
1543             with open(path, "r") as fh:
1544                 actual = fh.read()
1545             # verify cache with --target-version is separate
1546             py36_cache = black.read_cache(py36_mode)
1547             self.assertIn(path, py36_cache)
1548             normal_cache = black.read_cache(reg_mode)
1549             self.assertNotIn(path, normal_cache)
1550         self.assertEqual(actual, expected)
1551
1552     @event_loop()
1553     def test_multi_file_force_py36(self) -> None:
1554         reg_mode = DEFAULT_MODE
1555         py36_mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
1556         source, expected = read_data("force_py36")
1557         with cache_dir() as workspace:
1558             paths = [
1559                 (workspace / "file1.py").resolve(),
1560                 (workspace / "file2.py").resolve(),
1561             ]
1562             for path in paths:
1563                 with open(path, "w") as fh:
1564                     fh.write(source)
1565             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1566             for path in paths:
1567                 with open(path, "r") as fh:
1568                     actual = fh.read()
1569                 self.assertEqual(actual, expected)
1570             # verify cache with --target-version is separate
1571             pyi_cache = black.read_cache(py36_mode)
1572             normal_cache = black.read_cache(reg_mode)
1573             for path in paths:
1574                 self.assertIn(path, pyi_cache)
1575                 self.assertNotIn(path, normal_cache)
1576
1577     def test_collections(self) -> None:
1578         source, expected = read_data("collections")
1579         actual = fs(source)
1580         self.assertFormatEqual(expected, actual)
1581         black.assert_equivalent(source, actual)
1582         black.assert_stable(source, actual, DEFAULT_MODE)
1583
1584     def test_pipe_force_py36(self) -> None:
1585         source, expected = read_data("force_py36")
1586         result = CliRunner().invoke(
1587             black.main,
1588             ["-", "-q", "--target-version=py36"],
1589             input=BytesIO(source.encode("utf8")),
1590         )
1591         self.assertEqual(result.exit_code, 0)
1592         actual = result.output
1593         self.assertFormatEqual(actual, expected)
1594
1595     def test_include_exclude(self) -> None:
1596         path = THIS_DIR / "data" / "include_exclude_tests"
1597         include = re.compile(r"\.pyi?$")
1598         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1599         report = black.Report()
1600         gitignore = PathSpec.from_lines("gitwildmatch", [])
1601         sources: List[Path] = []
1602         expected = [
1603             Path(path / "b/dont_exclude/a.py"),
1604             Path(path / "b/dont_exclude/a.pyi"),
1605         ]
1606         this_abs = THIS_DIR.resolve()
1607         sources.extend(
1608             black.gen_python_files(
1609                 path.iterdir(), this_abs, include, exclude, None, report, gitignore
1610             )
1611         )
1612         self.assertEqual(sorted(expected), sorted(sources))
1613
1614     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1615     def test_exclude_for_issue_1572(self) -> None:
1616         # Exclude shouldn't touch files that were explicitly given to Black through the
1617         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
1618         # https://github.com/psf/black/issues/1572
1619         path = THIS_DIR / "data" / "include_exclude_tests"
1620         include = ""
1621         exclude = r"/exclude/|a\.py"
1622         src = str(path / "b/exclude/a.py")
1623         report = black.Report()
1624         expected = [Path(path / "b/exclude/a.py")]
1625         sources = list(
1626             black.get_sources(
1627                 ctx=FakeContext(),
1628                 src=(src,),
1629                 quiet=True,
1630                 verbose=False,
1631                 include=include,
1632                 exclude=exclude,
1633                 force_exclude=None,
1634                 report=report,
1635             )
1636         )
1637         self.assertEqual(sorted(expected), sorted(sources))
1638
1639     def test_gitignore_exclude(self) -> None:
1640         path = THIS_DIR / "data" / "include_exclude_tests"
1641         include = re.compile(r"\.pyi?$")
1642         exclude = re.compile(r"")
1643         report = black.Report()
1644         gitignore = PathSpec.from_lines(
1645             "gitwildmatch", ["exclude/", ".definitely_exclude"]
1646         )
1647         sources: List[Path] = []
1648         expected = [
1649             Path(path / "b/dont_exclude/a.py"),
1650             Path(path / "b/dont_exclude/a.pyi"),
1651         ]
1652         this_abs = THIS_DIR.resolve()
1653         sources.extend(
1654             black.gen_python_files(
1655                 path.iterdir(), this_abs, include, exclude, None, report, gitignore
1656             )
1657         )
1658         self.assertEqual(sorted(expected), sorted(sources))
1659
1660     def test_empty_include(self) -> None:
1661         path = THIS_DIR / "data" / "include_exclude_tests"
1662         report = black.Report()
1663         gitignore = PathSpec.from_lines("gitwildmatch", [])
1664         empty = re.compile(r"")
1665         sources: List[Path] = []
1666         expected = [
1667             Path(path / "b/exclude/a.pie"),
1668             Path(path / "b/exclude/a.py"),
1669             Path(path / "b/exclude/a.pyi"),
1670             Path(path / "b/dont_exclude/a.pie"),
1671             Path(path / "b/dont_exclude/a.py"),
1672             Path(path / "b/dont_exclude/a.pyi"),
1673             Path(path / "b/.definitely_exclude/a.pie"),
1674             Path(path / "b/.definitely_exclude/a.py"),
1675             Path(path / "b/.definitely_exclude/a.pyi"),
1676         ]
1677         this_abs = THIS_DIR.resolve()
1678         sources.extend(
1679             black.gen_python_files(
1680                 path.iterdir(),
1681                 this_abs,
1682                 empty,
1683                 re.compile(black.DEFAULT_EXCLUDES),
1684                 None,
1685                 report,
1686                 gitignore,
1687             )
1688         )
1689         self.assertEqual(sorted(expected), sorted(sources))
1690
1691     def test_empty_exclude(self) -> None:
1692         path = THIS_DIR / "data" / "include_exclude_tests"
1693         report = black.Report()
1694         gitignore = PathSpec.from_lines("gitwildmatch", [])
1695         empty = re.compile(r"")
1696         sources: List[Path] = []
1697         expected = [
1698             Path(path / "b/dont_exclude/a.py"),
1699             Path(path / "b/dont_exclude/a.pyi"),
1700             Path(path / "b/exclude/a.py"),
1701             Path(path / "b/exclude/a.pyi"),
1702             Path(path / "b/.definitely_exclude/a.py"),
1703             Path(path / "b/.definitely_exclude/a.pyi"),
1704         ]
1705         this_abs = THIS_DIR.resolve()
1706         sources.extend(
1707             black.gen_python_files(
1708                 path.iterdir(),
1709                 this_abs,
1710                 re.compile(black.DEFAULT_INCLUDES),
1711                 empty,
1712                 None,
1713                 report,
1714                 gitignore,
1715             )
1716         )
1717         self.assertEqual(sorted(expected), sorted(sources))
1718
1719     def test_invalid_include_exclude(self) -> None:
1720         for option in ["--include", "--exclude"]:
1721             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1722
1723     def test_preserves_line_endings(self) -> None:
1724         with TemporaryDirectory() as workspace:
1725             test_file = Path(workspace) / "test.py"
1726             for nl in ["\n", "\r\n"]:
1727                 contents = nl.join(["def f(  ):", "    pass"])
1728                 test_file.write_bytes(contents.encode())
1729                 ff(test_file, write_back=black.WriteBack.YES)
1730                 updated_contents: bytes = test_file.read_bytes()
1731                 self.assertIn(nl.encode(), updated_contents)
1732                 if nl == "\n":
1733                     self.assertNotIn(b"\r\n", updated_contents)
1734
1735     def test_preserves_line_endings_via_stdin(self) -> None:
1736         for nl in ["\n", "\r\n"]:
1737             contents = nl.join(["def f(  ):", "    pass"])
1738             runner = BlackRunner()
1739             result = runner.invoke(
1740                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1741             )
1742             self.assertEqual(result.exit_code, 0)
1743             output = runner.stdout_bytes
1744             self.assertIn(nl.encode("utf8"), output)
1745             if nl == "\n":
1746                 self.assertNotIn(b"\r\n", output)
1747
1748     def test_assert_equivalent_different_asts(self) -> None:
1749         with self.assertRaises(AssertionError):
1750             black.assert_equivalent("{}", "None")
1751
1752     def test_symlink_out_of_root_directory(self) -> None:
1753         path = MagicMock()
1754         root = THIS_DIR.resolve()
1755         child = MagicMock()
1756         include = re.compile(black.DEFAULT_INCLUDES)
1757         exclude = re.compile(black.DEFAULT_EXCLUDES)
1758         report = black.Report()
1759         gitignore = PathSpec.from_lines("gitwildmatch", [])
1760         # `child` should behave like a symlink which resolved path is clearly
1761         # outside of the `root` directory.
1762         path.iterdir.return_value = [child]
1763         child.resolve.return_value = Path("/a/b/c")
1764         child.as_posix.return_value = "/a/b/c"
1765         child.is_symlink.return_value = True
1766         try:
1767             list(
1768                 black.gen_python_files(
1769                     path.iterdir(), root, include, exclude, None, report, gitignore
1770                 )
1771             )
1772         except ValueError as ve:
1773             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1774         path.iterdir.assert_called_once()
1775         child.resolve.assert_called_once()
1776         child.is_symlink.assert_called_once()
1777         # `child` should behave like a strange file which resolved path is clearly
1778         # outside of the `root` directory.
1779         child.is_symlink.return_value = False
1780         with self.assertRaises(ValueError):
1781             list(
1782                 black.gen_python_files(
1783                     path.iterdir(), root, include, exclude, None, report, gitignore
1784                 )
1785             )
1786         path.iterdir.assert_called()
1787         self.assertEqual(path.iterdir.call_count, 2)
1788         child.resolve.assert_called()
1789         self.assertEqual(child.resolve.call_count, 2)
1790         child.is_symlink.assert_called()
1791         self.assertEqual(child.is_symlink.call_count, 2)
1792
1793     def test_shhh_click(self) -> None:
1794         try:
1795             from click import _unicodefun  # type: ignore
1796         except ModuleNotFoundError:
1797             self.skipTest("Incompatible Click version")
1798         if not hasattr(_unicodefun, "_verify_python3_env"):
1799             self.skipTest("Incompatible Click version")
1800         # First, let's see if Click is crashing with a preferred ASCII charset.
1801         with patch("locale.getpreferredencoding") as gpe:
1802             gpe.return_value = "ASCII"
1803             with self.assertRaises(RuntimeError):
1804                 _unicodefun._verify_python3_env()
1805         # Now, let's silence Click...
1806         black.patch_click()
1807         # ...and confirm it's silent.
1808         with patch("locale.getpreferredencoding") as gpe:
1809             gpe.return_value = "ASCII"
1810             try:
1811                 _unicodefun._verify_python3_env()
1812             except RuntimeError as re:
1813                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1814
1815     def test_root_logger_not_used_directly(self) -> None:
1816         def fail(*args: Any, **kwargs: Any) -> None:
1817             self.fail("Record created with root logger")
1818
1819         with patch.multiple(
1820             logging.root,
1821             debug=fail,
1822             info=fail,
1823             warning=fail,
1824             error=fail,
1825             critical=fail,
1826             log=fail,
1827         ):
1828             ff(THIS_FILE)
1829
1830     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1831     def test_blackd_main(self) -> None:
1832         with patch("blackd.web.run_app"):
1833             result = CliRunner().invoke(blackd.main, [])
1834             if result.exception is not None:
1835                 raise result.exception
1836             self.assertEqual(result.exit_code, 0)
1837
1838     def test_invalid_config_return_code(self) -> None:
1839         tmp_file = Path(black.dump_to_file())
1840         try:
1841             tmp_config = Path(black.dump_to_file())
1842             tmp_config.unlink()
1843             args = ["--config", str(tmp_config), str(tmp_file)]
1844             self.invokeBlack(args, exit_code=2, ignore_config=False)
1845         finally:
1846             tmp_file.unlink()
1847
1848     def test_parse_pyproject_toml(self) -> None:
1849         test_toml_file = THIS_DIR / "test.toml"
1850         config = black.parse_pyproject_toml(str(test_toml_file))
1851         self.assertEqual(config["verbose"], 1)
1852         self.assertEqual(config["check"], "no")
1853         self.assertEqual(config["diff"], "y")
1854         self.assertEqual(config["color"], True)
1855         self.assertEqual(config["line_length"], 79)
1856         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1857         self.assertEqual(config["exclude"], r"\.pyi?$")
1858         self.assertEqual(config["include"], r"\.py?$")
1859
1860     def test_read_pyproject_toml(self) -> None:
1861         test_toml_file = THIS_DIR / "test.toml"
1862         fake_ctx = FakeContext()
1863         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1864         config = fake_ctx.default_map
1865         self.assertEqual(config["verbose"], "1")
1866         self.assertEqual(config["check"], "no")
1867         self.assertEqual(config["diff"], "y")
1868         self.assertEqual(config["color"], "True")
1869         self.assertEqual(config["line_length"], "79")
1870         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1871         self.assertEqual(config["exclude"], r"\.pyi?$")
1872         self.assertEqual(config["include"], r"\.py?$")
1873
1874     def test_find_project_root(self) -> None:
1875         with TemporaryDirectory() as workspace:
1876             root = Path(workspace)
1877             test_dir = root / "test"
1878             test_dir.mkdir()
1879
1880             src_dir = root / "src"
1881             src_dir.mkdir()
1882
1883             root_pyproject = root / "pyproject.toml"
1884             root_pyproject.touch()
1885             src_pyproject = src_dir / "pyproject.toml"
1886             src_pyproject.touch()
1887             src_python = src_dir / "foo.py"
1888             src_python.touch()
1889
1890             self.assertEqual(
1891                 black.find_project_root((src_dir, test_dir)), root.resolve()
1892             )
1893             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
1894             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
1895
1896
if has_blackd_deps:
    # Several names are evaluated while this class body executes:
    #   - the AioHTTPTestCase base class;
    #   - the `unittest_run_loop` decorators;
    #   - the `web.Application` return annotation.
    # Defining the class unconditionally therefore raised NameError at import
    # time when blackd/aiohttp were missing, which broke every test in this
    # module. Per-method `skipUnless` decorators cannot prevent that, so the
    # whole class is guarded instead.

    class BlackDTestCase(AioHTTPTestCase):
        """Exercise blackd's HTTP API through aiohttp's test client."""

        async def get_application(self) -> web.Application:
            return blackd.make_app()

        # TODO: remove these decorators once the below is released
        # https://github.com/aio-libs/aiohttp/pull/3727
        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_request_needs_formatting(self) -> None:
            response = await self.client.post("/", data=b"print('hello world')")
            self.assertEqual(response.status, 200)
            self.assertEqual(response.charset, "utf8")
            self.assertEqual(await response.read(), b'print("hello world")\n')

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_request_no_change(self) -> None:
            response = await self.client.post("/", data=b'print("hello world")\n')
            self.assertEqual(response.status, 204)
            self.assertEqual(await response.read(), b"")

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_request_syntax_error(self) -> None:
            response = await self.client.post("/", data=b"what even ( is")
            self.assertEqual(response.status, 400)
            content = await response.text()
            self.assertTrue(
                content.startswith("Cannot parse"),
                msg=f"Expected error to start with 'Cannot parse', got {repr(content)}",
            )

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_unsupported_version(self) -> None:
            response = await self.client.post(
                "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "2"}
            )
            self.assertEqual(response.status, 501)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_supported_version(self) -> None:
            response = await self.client.post(
                "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "1"}
            )
            self.assertEqual(response.status, 200)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_invalid_python_variant(self) -> None:
            async def check(header_value: str, expected_status: int = 400) -> None:
                response = await self.client.post(
                    "/",
                    data=b"what",
                    headers={blackd.PYTHON_VARIANT_HEADER: header_value},
                )
                self.assertEqual(response.status, expected_status)

            await check("lol")
            await check("ruby3.5")
            await check("pyi3.6")
            await check("py1.5")
            await check("2.8")
            await check("py2.8")
            await check("3.0")
            await check("pypy3.0")
            await check("jython3.4")

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_pyi(self) -> None:
            source, expected = read_data("stub.pyi")
            response = await self.client.post(
                "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
            )
            self.assertEqual(response.status, 200)
            self.assertEqual(await response.text(), expected)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_diff(self) -> None:
            diff_header = re.compile(
                r"(In|Out)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
            )

            source, _ = read_data("blackd_diff.py")
            expected, _ = read_data("blackd_diff.diff")

            response = await self.client.post(
                "/", data=source, headers={blackd.DIFF_HEADER: "true"}
            )
            self.assertEqual(response.status, 200)

            actual = await response.text()
            # Replace the timestamped diff headers so the comparison is stable.
            actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
            self.assertEqual(actual, expected)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_python_variant(self) -> None:
            code = (
                "def f(\n"
                "    and_has_a_bunch_of,\n"
                "    very_long_arguments_too,\n"
                "    and_lots_of_them_as_well_lol,\n"
                "    **and_very_long_keyword_arguments\n"
                "):\n"
                "    pass\n"
            )

            async def check(header_value: str, expected_status: int) -> None:
                response = await self.client.post(
                    "/",
                    data=code,
                    headers={blackd.PYTHON_VARIANT_HEADER: header_value},
                )
                self.assertEqual(
                    response.status, expected_status, msg=await response.text()
                )

            await check("3.6", 200)
            await check("py3.6", 200)
            await check("3.6,3.7", 200)
            await check("3.6,py3.7", 200)
            await check("py36,py37", 200)
            await check("36", 200)
            await check("3.6.4", 200)

            await check("2", 204)
            await check("2.7", 204)
            await check("py2.7", 204)
            await check("3.4", 204)
            await check("py3.4", 204)
            await check("py34,py36", 204)
            await check("34", 204)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_line_length(self) -> None:
            response = await self.client.post(
                "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "7"}
            )
            self.assertEqual(response.status, 200)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_invalid_line_length(self) -> None:
            response = await self.client.post(
                "/",
                data=b'print("hello")\n',
                headers={blackd.LINE_LENGTH_HEADER: "NaN"},
            )
            self.assertEqual(response.status, 400)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_response_black_version_header(self) -> None:
            response = await self.client.post("/")
            self.assertIsNotNone(response.headers.get(blackd.BLACK_VERSION_HEADER))
2063
2064
# Source lines of the imported `black` module (line endings kept), read once so
# `tracefunc` can print the source line each traced call starts on.
with open(black.__file__, encoding="utf-8") as _bf:
    black_source_lines = list(_bf)
2067
2068
def tracefunc(frame: types.FrameType, event: str, arg: Any) -> Callable:
    """Show function calls `from black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    Only "call" events whose frame comes from black/__init__.py are printed.
    Each line shows the traced line number, then the text of the ``def`` line;
    decorator lines above it are skipped.

    The filename is checked *before* ``black_source_lines`` is indexed.
    Otherwise another file's line number would be looked up in black's source,
    printing unrelated text or raising IndexError.

    Always returns itself, which `sys.settrace()` uses as the frame's local
    trace function.
    """
    if event != "call":
        return tracefunc

    filename = frame.f_code.co_filename
    # Compare using forward slashes so the check also matches Windows paths.
    if "black/__init__.py" not in Path(filename).as_posix():
        return tracefunc

    # NOTE(review): 19 looks like a hand-tuned count of harness frames below the
    # code under test; it only affects indentation of the printed line.
    stack = len(inspect.stack()) - 19
    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1  # f_lineno is 1-based; the list is 0-based.
    funcname = ""
    # Walk past decorator lines to the `def` line, never past the end of file.
    while func_sig_lineno < len(black_source_lines):
        funcname = black_source_lines[func_sig_lineno].strip()
        if not funcname.startswith("@"):
            break
        func_sig_lineno += 1
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc
2088
2089
if __name__ == "__main__":
    # Run the tests defined in (or imported into) this very module. The old
    # `module="test_black"` re-imported the file by top-level name. That only
    # works with `tests/` on sys.path, which is incompatible with the
    # package-relative `.test_primer` import above (e.g. `python -m tests.test_black`).
    unittest.main()