git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

bb1929ebb3137b959d8475bb6e4b805479c85fd0
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 import logging
4 from concurrent.futures import ThreadPoolExecutor
5 from contextlib import contextmanager
6 from dataclasses import replace
7 from functools import partial
8 import inspect
9 from io import BytesIO, TextIOWrapper
10 import os
11 from pathlib import Path
12 import regex as re
13 import sys
14 from tempfile import TemporaryDirectory
15 import types
16 from typing import (
17     Any,
18     BinaryIO,
19     Callable,
20     Dict,
21     Generator,
22     List,
23     Tuple,
24     Iterator,
25     TypeVar,
26 )
27 import unittest
28 from unittest.mock import patch, MagicMock
29
30 import click
31 from click import unstyle
32 from click.testing import CliRunner
33
34 import black
35 from black import Feature, TargetVersion
36
37 try:
38     import blackd
39     from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop
40     from aiohttp import web
41 except ImportError:
42     has_blackd_deps = False
43 else:
44     has_blackd_deps = True
45
46 from pathspec import PathSpec
47
48 # Import other test classes
49 from .test_primer import PrimerCLITests  # noqa: F401
50
51
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
PROJECT_ROOT = THIS_DIR.parent
# ``ff``/``fs`` default to this mode; experimental string processing is enabled.
DEFAULT_MODE = black.FileMode(experimental_string_processing=True)
ff = partial(black.format_file_in_place, mode=DEFAULT_MODE, fast=True)
fs = partial(black.format_str, mode=DEFAULT_MODE)
DETERMINISTIC_HEADER = "[Deterministic header]"
# Built from two pieces so the full marker never appears verbatim in this file;
# read_data() strips the marker from every line it reads, including this file's.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
PY36_ARGS = [f"--target-version={v.name.lower()}" for v in black.PY36_VERSIONS]
T = TypeVar("T")
R = TypeVar("R")
65
66
def dump_to_stderr(*output: str) -> str:
    """Join ``output`` with newlines and wrap it in a leading/trailing newline.

    Tests patch ``black.dump_to_file`` with this so debug output is returned as
    text rather than written to a temporary file.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
69
70
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """Load a test case and split it into ``(input, expected_output)``.

    ``name`` gets a ``.py`` suffix unless it already ends in .py, .pyi, .out or
    .diff. It is resolved against the ``data`` directory next to this file when
    ``data`` is true, and against the project root otherwise. Lines before a
    ``# output`` marker line form the input and lines after it the expected
    output. If the expected output comes out empty while the input is not, the
    input is reused, i.e. the file is treated as already formatted.
    ``EMPTY_LINE`` is removed from every line. Both parts are stripped and end
    with exactly one newline.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    base_dir = THIS_DIR / "data" if data else PROJECT_ROOT
    with open(base_dir / name, "r", encoding="utf8") as test:
        lines = test.readlines()
    sections: Tuple[List[str], List[str]] = ([], [])
    current = 0
    for raw_line in lines:
        line = raw_line.replace(EMPTY_LINE, "")
        if line.rstrip() == "# output":
            current = 1
            continue
        sections[current].append(line)
    source_lines, expected_lines = sections
    if source_lines and not expected_lines:
        expected_lines = source_lines
    return (
        "".join(source_lines).strip() + "\n",
        "".join(expected_lines).strip() + "\n",
    )
92
93
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point ``black.CACHE_DIR`` at a throwaway directory while the block runs.

    With ``exists=False`` the yielded path is a not-yet-created ``new``
    subdirectory of the temporary directory.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
102
103
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a fresh asyncio event loop as the current one; close it on exit."""
    loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        yield
    finally:
        loop.close()
114
115
@contextmanager
def skip_if_exception(e: str) -> Iterator[None]:
    """Turn an exception whose class is named ``e`` into a skipped test.

    Any other exception propagates unchanged.

    Raises:
        unittest.SkipTest: when an exception of the named class is raised.
    """
    try:
        yield
    except Exception as exc:
        if exc.__class__.__name__ != e:
            raise
        # unittest.skip() only builds a decorator and would silently swallow the
        # error; raising SkipTest is what actually marks the test as skipped.
        message = f"Encountered expected exception {exc}, skipping"
        raise unittest.SkipTest(message) from exc
125
126
class FakeContext(click.Context):
    """Minimal click Context stand-in that only carries an empty ``default_map``."""

    def __init__(self) -> None:
        # click.Context.__init__ is intentionally not called; presumably the
        # code under test only reads ``default_map`` — confirm at call sites.
        self.default_map: Dict[str, Any] = {}
132
133
class FakeParameter(click.Parameter):
    """Bare click Parameter stand-in for helpers that require one."""

    def __init__(self) -> None:
        """Skip click.Parameter.__init__ so no constructor arguments are needed."""
139
140
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    This is a hack that can be removed once we depend on Click 7.x"""

    def __init__(self) -> None:
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Run click's isolation while capturing stderr into ``self.stderrbuf``.

        After the block exits, ``stdout_bytes``/``stderr_bytes`` hold everything
        written to stdout/stderr and the original ``sys.stderr`` is restored.
        """
        with super().isolation(*args, **kwargs) as output:
            hold_stderr = sys.stderr
            stderr_wrapper = TextIOWrapper(self.stderrbuf, encoding=self.charset)
            sys.stderr = stderr_wrapper
            try:
                yield output
            finally:
                # TextIOWrapper buffers text; flush it so every write made during
                # the invocation reaches ``self.stderrbuf`` before it is read.
                stderr_wrapper.flush()
                self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                self.stderr_bytes = self.stderrbuf.getvalue()
                sys.stderr = hold_stderr
164
165
166 class BlackTestCase(unittest.TestCase):
167     maxDiff = None
168     _diffThreshold = 2 ** 20
169
170     def assertFormatEqual(self, expected: str, actual: str) -> None:
171         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
172             bdv: black.DebugVisitor[Any]
173             black.out("Expected tree:", fg="green")
174             try:
175                 exp_node = black.lib2to3_parse(expected)
176                 bdv = black.DebugVisitor()
177                 list(bdv.visit(exp_node))
178             except Exception as ve:
179                 black.err(str(ve))
180             black.out("Actual tree:", fg="red")
181             try:
182                 exp_node = black.lib2to3_parse(actual)
183                 bdv = black.DebugVisitor()
184                 list(bdv.visit(exp_node))
185             except Exception as ve:
186                 black.err(str(ve))
187         self.assertMultiLineEqual(expected, actual)
188
189     def invokeBlack(
190         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
191     ) -> None:
192         runner = BlackRunner()
193         if ignore_config:
194             args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
195         result = runner.invoke(black.main, args)
196         self.assertEqual(
197             result.exit_code,
198             exit_code,
199             msg=(
200                 f"Failed with args: {args}\n"
201                 f"stdout: {runner.stdout_bytes.decode()!r}\n"
202                 f"stderr: {runner.stderr_bytes.decode()!r}\n"
203                 f"exception: {result.exception}"
204             ),
205         )
206
207     @patch("black.dump_to_file", dump_to_stderr)
208     def checkSourceFile(self, name: str, mode: black.FileMode = DEFAULT_MODE) -> None:
209         path = THIS_DIR.parent / name
210         source, expected = read_data(str(path), data=False)
211         actual = fs(source, mode=mode)
212         self.assertFormatEqual(expected, actual)
213         black.assert_equivalent(source, actual)
214         black.assert_stable(source, actual, mode)
215         self.assertFalse(ff(path))
216
217     @patch("black.dump_to_file", dump_to_stderr)
218     def test_empty(self) -> None:
219         source = expected = ""
220         actual = fs(source)
221         self.assertFormatEqual(expected, actual)
222         black.assert_equivalent(source, actual)
223         black.assert_stable(source, actual, DEFAULT_MODE)
224
225     def test_empty_ff(self) -> None:
226         expected = ""
227         tmp_file = Path(black.dump_to_file())
228         try:
229             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
230             with open(tmp_file, encoding="utf8") as f:
231                 actual = f.read()
232         finally:
233             os.unlink(tmp_file)
234         self.assertFormatEqual(expected, actual)
235
236     def test_self(self) -> None:
237         self.checkSourceFile("tests/test_black.py")
238
239     def test_black(self) -> None:
240         self.checkSourceFile("src/black/__init__.py")
241
242     def test_pygram(self) -> None:
243         self.checkSourceFile("src/blib2to3/pygram.py")
244
245     def test_pytree(self) -> None:
246         self.checkSourceFile("src/blib2to3/pytree.py")
247
248     def test_conv(self) -> None:
249         self.checkSourceFile("src/blib2to3/pgen2/conv.py")
250
251     def test_driver(self) -> None:
252         self.checkSourceFile("src/blib2to3/pgen2/driver.py")
253
254     def test_grammar(self) -> None:
255         self.checkSourceFile("src/blib2to3/pgen2/grammar.py")
256
257     def test_literals(self) -> None:
258         self.checkSourceFile("src/blib2to3/pgen2/literals.py")
259
260     def test_parse(self) -> None:
261         self.checkSourceFile("src/blib2to3/pgen2/parse.py")
262
263     def test_pgen(self) -> None:
264         self.checkSourceFile("src/blib2to3/pgen2/pgen.py")
265
266     def test_tokenize(self) -> None:
267         self.checkSourceFile("src/blib2to3/pgen2/tokenize.py")
268
269     def test_token(self) -> None:
270         self.checkSourceFile("src/blib2to3/pgen2/token.py")
271
272     def test_setup(self) -> None:
273         self.checkSourceFile("setup.py")
274
275     def test_piping(self) -> None:
276         source, expected = read_data("src/black/__init__", data=False)
277         result = BlackRunner().invoke(
278             black.main,
279             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
280             input=BytesIO(source.encode("utf8")),
281         )
282         self.assertEqual(result.exit_code, 0)
283         self.assertFormatEqual(expected, result.output)
284         black.assert_equivalent(source, result.output)
285         black.assert_stable(source, result.output, DEFAULT_MODE)
286
287     def test_piping_diff(self) -> None:
288         diff_header = re.compile(
289             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
290             r"\+\d\d\d\d"
291         )
292         source, _ = read_data("expression.py")
293         expected, _ = read_data("expression.diff")
294         config = THIS_DIR / "data" / "empty_pyproject.toml"
295         args = [
296             "-",
297             "--fast",
298             f"--line-length={black.DEFAULT_LINE_LENGTH}",
299             "--diff",
300             f"--config={config}",
301         ]
302         result = BlackRunner().invoke(
303             black.main, args, input=BytesIO(source.encode("utf8"))
304         )
305         self.assertEqual(result.exit_code, 0)
306         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
307         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
308         self.assertEqual(expected, actual)
309
310     def test_piping_diff_with_color(self) -> None:
311         source, _ = read_data("expression.py")
312         config = THIS_DIR / "data" / "empty_pyproject.toml"
313         args = [
314             "-",
315             "--fast",
316             f"--line-length={black.DEFAULT_LINE_LENGTH}",
317             "--diff",
318             "--color",
319             f"--config={config}",
320         ]
321         result = BlackRunner().invoke(
322             black.main, args, input=BytesIO(source.encode("utf8"))
323         )
324         actual = result.output
325         # Again, the contents are checked in a different test, so only look for colors.
326         self.assertIn("\033[1;37m", actual)
327         self.assertIn("\033[36m", actual)
328         self.assertIn("\033[32m", actual)
329         self.assertIn("\033[31m", actual)
330         self.assertIn("\033[0m", actual)
331
332     @patch("black.dump_to_file", dump_to_stderr)
333     def test_function(self) -> None:
334         source, expected = read_data("function")
335         actual = fs(source)
336         self.assertFormatEqual(expected, actual)
337         black.assert_equivalent(source, actual)
338         black.assert_stable(source, actual, DEFAULT_MODE)
339
340     @patch("black.dump_to_file", dump_to_stderr)
341     def test_function2(self) -> None:
342         source, expected = read_data("function2")
343         actual = fs(source)
344         self.assertFormatEqual(expected, actual)
345         black.assert_equivalent(source, actual)
346         black.assert_stable(source, actual, DEFAULT_MODE)
347
348     @patch("black.dump_to_file", dump_to_stderr)
349     def test_function_trailing_comma_wip(self) -> None:
350         source, expected = read_data("function_trailing_comma_wip")
351         # sys.settrace(tracefunc)
352         actual = fs(source)
353         # sys.settrace(None)
354         self.assertFormatEqual(expected, actual)
355         black.assert_equivalent(source, actual)
356         black.assert_stable(source, actual, black.FileMode())
357
358     @patch("black.dump_to_file", dump_to_stderr)
359     def test_function_trailing_comma(self) -> None:
360         source, expected = read_data("function_trailing_comma")
361         actual = fs(source)
362         self.assertFormatEqual(expected, actual)
363         black.assert_equivalent(source, actual)
364         black.assert_stable(source, actual, DEFAULT_MODE)
365
366     @patch("black.dump_to_file", dump_to_stderr)
367     def test_expression(self) -> None:
368         source, expected = read_data("expression")
369         actual = fs(source)
370         self.assertFormatEqual(expected, actual)
371         black.assert_equivalent(source, actual)
372         black.assert_stable(source, actual, DEFAULT_MODE)
373
374     @patch("black.dump_to_file", dump_to_stderr)
375     def test_pep_572(self) -> None:
376         source, expected = read_data("pep_572")
377         actual = fs(source)
378         self.assertFormatEqual(expected, actual)
379         black.assert_stable(source, actual, DEFAULT_MODE)
380         if sys.version_info >= (3, 8):
381             black.assert_equivalent(source, actual)
382
383     def test_pep_572_version_detection(self) -> None:
384         source, _ = read_data("pep_572")
385         root = black.lib2to3_parse(source)
386         features = black.get_features_used(root)
387         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
388         versions = black.detect_target_versions(root)
389         self.assertIn(black.TargetVersion.PY38, versions)
390
391     def test_expression_ff(self) -> None:
392         source, expected = read_data("expression")
393         tmp_file = Path(black.dump_to_file(source))
394         try:
395             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
396             with open(tmp_file, encoding="utf8") as f:
397                 actual = f.read()
398         finally:
399             os.unlink(tmp_file)
400         self.assertFormatEqual(expected, actual)
401         with patch("black.dump_to_file", dump_to_stderr):
402             black.assert_equivalent(source, actual)
403             black.assert_stable(source, actual, DEFAULT_MODE)
404
405     def test_expression_diff(self) -> None:
406         source, _ = read_data("expression.py")
407         expected, _ = read_data("expression.diff")
408         tmp_file = Path(black.dump_to_file(source))
409         diff_header = re.compile(
410             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
411             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
412         )
413         try:
414             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
415             self.assertEqual(result.exit_code, 0)
416         finally:
417             os.unlink(tmp_file)
418         actual = result.output
419         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
420         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
421         if expected != actual:
422             dump = black.dump_to_file(actual)
423             msg = (
424                 "Expected diff isn't equal to the actual. If you made changes to"
425                 " expression.py and this is an anticipated difference, overwrite"
426                 f" tests/data/expression.diff with {dump}"
427             )
428             self.assertEqual(expected, actual, msg)
429
430     def test_expression_diff_with_color(self) -> None:
431         source, _ = read_data("expression.py")
432         expected, _ = read_data("expression.diff")
433         tmp_file = Path(black.dump_to_file(source))
434         try:
435             result = BlackRunner().invoke(
436                 black.main, ["--diff", "--color", str(tmp_file)]
437             )
438         finally:
439             os.unlink(tmp_file)
440         actual = result.output
441         # We check the contents of the diff in `test_expression_diff`. All
442         # we need to check here is that color codes exist in the result.
443         self.assertIn("\033[1;37m", actual)
444         self.assertIn("\033[36m", actual)
445         self.assertIn("\033[32m", actual)
446         self.assertIn("\033[31m", actual)
447         self.assertIn("\033[0m", actual)
448
449     @patch("black.dump_to_file", dump_to_stderr)
450     def test_fstring(self) -> None:
451         source, expected = read_data("fstring")
452         actual = fs(source)
453         self.assertFormatEqual(expected, actual)
454         black.assert_equivalent(source, actual)
455         black.assert_stable(source, actual, DEFAULT_MODE)
456
457     @patch("black.dump_to_file", dump_to_stderr)
458     def test_pep_570(self) -> None:
459         source, expected = read_data("pep_570")
460         actual = fs(source)
461         self.assertFormatEqual(expected, actual)
462         black.assert_stable(source, actual, DEFAULT_MODE)
463         if sys.version_info >= (3, 8):
464             black.assert_equivalent(source, actual)
465
466     def test_detect_pos_only_arguments(self) -> None:
467         source, _ = read_data("pep_570")
468         root = black.lib2to3_parse(source)
469         features = black.get_features_used(root)
470         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
471         versions = black.detect_target_versions(root)
472         self.assertIn(black.TargetVersion.PY38, versions)
473
474     @patch("black.dump_to_file", dump_to_stderr)
475     def test_string_quotes(self) -> None:
476         source, expected = read_data("string_quotes")
477         actual = fs(source)
478         self.assertFormatEqual(expected, actual)
479         black.assert_equivalent(source, actual)
480         black.assert_stable(source, actual, DEFAULT_MODE)
481         mode = replace(DEFAULT_MODE, string_normalization=False)
482         not_normalized = fs(source, mode=mode)
483         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
484         black.assert_equivalent(source, not_normalized)
485         black.assert_stable(source, not_normalized, mode=mode)
486
487     @patch("black.dump_to_file", dump_to_stderr)
488     def test_docstring(self) -> None:
489         source, expected = read_data("docstring")
490         actual = fs(source)
491         self.assertFormatEqual(expected, actual)
492         black.assert_equivalent(source, actual)
493         black.assert_stable(source, actual, DEFAULT_MODE)
494
495     def test_long_strings(self) -> None:
496         """Tests for splitting long strings."""
497         source, expected = read_data("long_strings")
498         actual = fs(source)
499         self.assertFormatEqual(expected, actual)
500         black.assert_equivalent(source, actual)
501         black.assert_stable(source, actual, DEFAULT_MODE)
502
503     def test_long_strings_flag_disabled(self) -> None:
504         """Tests for turning off the string processing logic."""
505         source, expected = read_data("long_strings_flag_disabled")
506         mode = replace(DEFAULT_MODE, experimental_string_processing=False)
507         actual = fs(source, mode=mode)
508         self.assertFormatEqual(expected, actual)
509         black.assert_stable(expected, actual, mode)
510
511     @patch("black.dump_to_file", dump_to_stderr)
512     def test_long_strings__edge_case(self) -> None:
513         """Edge-case tests for splitting long strings."""
514         source, expected = read_data("long_strings__edge_case")
515         actual = fs(source)
516         self.assertFormatEqual(expected, actual)
517         black.assert_equivalent(source, actual)
518         black.assert_stable(source, actual, DEFAULT_MODE)
519
520     @patch("black.dump_to_file", dump_to_stderr)
521     def test_long_strings__regression(self) -> None:
522         """Regression tests for splitting long strings."""
523         source, expected = read_data("long_strings__regression")
524         actual = fs(source)
525         self.assertFormatEqual(expected, actual)
526         black.assert_equivalent(source, actual)
527         black.assert_stable(source, actual, DEFAULT_MODE)
528
529     @patch("black.dump_to_file", dump_to_stderr)
530     def test_slices(self) -> None:
531         source, expected = read_data("slices")
532         actual = fs(source)
533         self.assertFormatEqual(expected, actual)
534         black.assert_equivalent(source, actual)
535         black.assert_stable(source, actual, DEFAULT_MODE)
536
537     @patch("black.dump_to_file", dump_to_stderr)
538     def test_percent_precedence(self) -> None:
539         source, expected = read_data("percent_precedence")
540         actual = fs(source)
541         self.assertFormatEqual(expected, actual)
542         black.assert_equivalent(source, actual)
543         black.assert_stable(source, actual, DEFAULT_MODE)
544
545     @patch("black.dump_to_file", dump_to_stderr)
546     def test_comments(self) -> None:
547         source, expected = read_data("comments")
548         actual = fs(source)
549         self.assertFormatEqual(expected, actual)
550         black.assert_equivalent(source, actual)
551         black.assert_stable(source, actual, DEFAULT_MODE)
552
553     @patch("black.dump_to_file", dump_to_stderr)
554     def test_comments2(self) -> None:
555         source, expected = read_data("comments2")
556         actual = fs(source)
557         self.assertFormatEqual(expected, actual)
558         black.assert_equivalent(source, actual)
559         black.assert_stable(source, actual, DEFAULT_MODE)
560
561     @patch("black.dump_to_file", dump_to_stderr)
562     def test_comments3(self) -> None:
563         source, expected = read_data("comments3")
564         actual = fs(source)
565         self.assertFormatEqual(expected, actual)
566         black.assert_equivalent(source, actual)
567         black.assert_stable(source, actual, DEFAULT_MODE)
568
569     @patch("black.dump_to_file", dump_to_stderr)
570     def test_comments4(self) -> None:
571         source, expected = read_data("comments4")
572         actual = fs(source)
573         self.assertFormatEqual(expected, actual)
574         black.assert_equivalent(source, actual)
575         black.assert_stable(source, actual, DEFAULT_MODE)
576
577     @patch("black.dump_to_file", dump_to_stderr)
578     def test_comments5(self) -> None:
579         source, expected = read_data("comments5")
580         actual = fs(source)
581         self.assertFormatEqual(expected, actual)
582         black.assert_equivalent(source, actual)
583         black.assert_stable(source, actual, DEFAULT_MODE)
584
585     @patch("black.dump_to_file", dump_to_stderr)
586     def test_comments6(self) -> None:
587         source, expected = read_data("comments6")
588         actual = fs(source)
589         self.assertFormatEqual(expected, actual)
590         black.assert_equivalent(source, actual)
591         black.assert_stable(source, actual, DEFAULT_MODE)
592
593     @patch("black.dump_to_file", dump_to_stderr)
594     def test_comments7(self) -> None:
595         source, expected = read_data("comments7")
596         actual = fs(source)
597         self.assertFormatEqual(expected, actual)
598         black.assert_equivalent(source, actual)
599         black.assert_stable(source, actual, DEFAULT_MODE)
600
601     @patch("black.dump_to_file", dump_to_stderr)
602     def test_comment_after_escaped_newline(self) -> None:
603         source, expected = read_data("comment_after_escaped_newline")
604         actual = fs(source)
605         self.assertFormatEqual(expected, actual)
606         black.assert_equivalent(source, actual)
607         black.assert_stable(source, actual, DEFAULT_MODE)
608
609     @patch("black.dump_to_file", dump_to_stderr)
610     def test_cantfit(self) -> None:
611         source, expected = read_data("cantfit")
612         actual = fs(source)
613         self.assertFormatEqual(expected, actual)
614         black.assert_equivalent(source, actual)
615         black.assert_stable(source, actual, DEFAULT_MODE)
616
617     @patch("black.dump_to_file", dump_to_stderr)
618     def test_import_spacing(self) -> None:
619         source, expected = read_data("import_spacing")
620         actual = fs(source)
621         self.assertFormatEqual(expected, actual)
622         black.assert_equivalent(source, actual)
623         black.assert_stable(source, actual, DEFAULT_MODE)
624
625     @patch("black.dump_to_file", dump_to_stderr)
626     def test_composition(self) -> None:
627         source, expected = read_data("composition")
628         actual = fs(source)
629         self.assertFormatEqual(expected, actual)
630         black.assert_equivalent(source, actual)
631         black.assert_stable(source, actual, DEFAULT_MODE)
632
633     @patch("black.dump_to_file", dump_to_stderr)
634     def test_composition_no_trailing_comma(self) -> None:
635         source, expected = read_data("composition_no_trailing_comma")
636         actual = fs(source)
637         self.assertFormatEqual(expected, actual)
638         black.assert_equivalent(source, actual)
639         black.assert_stable(source, actual, DEFAULT_MODE)
640
641     @patch("black.dump_to_file", dump_to_stderr)
642     def test_empty_lines(self) -> None:
643         source, expected = read_data("empty_lines")
644         actual = fs(source)
645         self.assertFormatEqual(expected, actual)
646         black.assert_equivalent(source, actual)
647         black.assert_stable(source, actual, DEFAULT_MODE)
648
649     @patch("black.dump_to_file", dump_to_stderr)
650     def test_remove_parens(self) -> None:
651         source, expected = read_data("remove_parens")
652         actual = fs(source)
653         self.assertFormatEqual(expected, actual)
654         black.assert_equivalent(source, actual)
655         black.assert_stable(source, actual, DEFAULT_MODE)
656
657     @patch("black.dump_to_file", dump_to_stderr)
658     def test_string_prefixes(self) -> None:
659         source, expected = read_data("string_prefixes")
660         actual = fs(source)
661         self.assertFormatEqual(expected, actual)
662         black.assert_equivalent(source, actual)
663         black.assert_stable(source, actual, DEFAULT_MODE)
664
665     @patch("black.dump_to_file", dump_to_stderr)
666     def test_numeric_literals(self) -> None:
667         source, expected = read_data("numeric_literals")
668         mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
669         actual = fs(source, mode=mode)
670         self.assertFormatEqual(expected, actual)
671         black.assert_equivalent(source, actual)
672         black.assert_stable(source, actual, mode)
673
674     @patch("black.dump_to_file", dump_to_stderr)
675     def test_numeric_literals_ignoring_underscores(self) -> None:
676         source, expected = read_data("numeric_literals_skip_underscores")
677         mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
678         actual = fs(source, mode=mode)
679         self.assertFormatEqual(expected, actual)
680         black.assert_equivalent(source, actual)
681         black.assert_stable(source, actual, mode)
682
683     @patch("black.dump_to_file", dump_to_stderr)
684     def test_numeric_literals_py2(self) -> None:
685         source, expected = read_data("numeric_literals_py2")
686         actual = fs(source)
687         self.assertFormatEqual(expected, actual)
688         black.assert_stable(source, actual, DEFAULT_MODE)
689
690     @patch("black.dump_to_file", dump_to_stderr)
691     def test_python2(self) -> None:
692         source, expected = read_data("python2")
693         actual = fs(source)
694         self.assertFormatEqual(expected, actual)
695         black.assert_equivalent(source, actual)
696         black.assert_stable(source, actual, DEFAULT_MODE)
697
698     @patch("black.dump_to_file", dump_to_stderr)
699     def test_python2_print_function(self) -> None:
700         source, expected = read_data("python2_print_function")
701         mode = replace(DEFAULT_MODE, target_versions={TargetVersion.PY27})
702         actual = fs(source, mode=mode)
703         self.assertFormatEqual(expected, actual)
704         black.assert_equivalent(source, actual)
705         black.assert_stable(source, actual, mode)
706
707     @patch("black.dump_to_file", dump_to_stderr)
708     def test_python2_unicode_literals(self) -> None:
709         source, expected = read_data("python2_unicode_literals")
710         actual = fs(source)
711         self.assertFormatEqual(expected, actual)
712         black.assert_equivalent(source, actual)
713         black.assert_stable(source, actual, DEFAULT_MODE)
714
715     @patch("black.dump_to_file", dump_to_stderr)
716     def test_stub(self) -> None:
717         mode = replace(DEFAULT_MODE, is_pyi=True)
718         source, expected = read_data("stub.pyi")
719         actual = fs(source, mode=mode)
720         self.assertFormatEqual(expected, actual)
721         black.assert_stable(source, actual, mode)
722
723     @patch("black.dump_to_file", dump_to_stderr)
724     def test_async_as_identifier(self) -> None:
725         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
726         source, expected = read_data("async_as_identifier")
727         actual = fs(source)
728         self.assertFormatEqual(expected, actual)
729         major, minor = sys.version_info[:2]
730         if major < 3 or (major <= 3 and minor < 7):
731             black.assert_equivalent(source, actual)
732         black.assert_stable(source, actual, DEFAULT_MODE)
733         # ensure black can parse this when the target is 3.6
734         self.invokeBlack([str(source_path), "--target-version", "py36"])
735         # but not on 3.7, because async/await is no longer an identifier
736         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
737
738     @patch("black.dump_to_file", dump_to_stderr)
739     def test_python37(self) -> None:
740         source_path = (THIS_DIR / "data" / "python37.py").resolve()
741         source, expected = read_data("python37")
742         actual = fs(source)
743         self.assertFormatEqual(expected, actual)
744         major, minor = sys.version_info[:2]
745         if major > 3 or (major == 3 and minor >= 7):
746             black.assert_equivalent(source, actual)
747         black.assert_stable(source, actual, DEFAULT_MODE)
748         # ensure black can parse this when the target is 3.7
749         self.invokeBlack([str(source_path), "--target-version", "py37"])
750         # but not on 3.6, because we use async as a reserved keyword
751         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
752
753     @patch("black.dump_to_file", dump_to_stderr)
754     def test_python38(self) -> None:
755         source, expected = read_data("python38")
756         actual = fs(source)
757         self.assertFormatEqual(expected, actual)
758         major, minor = sys.version_info[:2]
759         if major > 3 or (major == 3 and minor >= 8):
760             black.assert_equivalent(source, actual)
761         black.assert_stable(source, actual, DEFAULT_MODE)
762
763     @patch("black.dump_to_file", dump_to_stderr)
764     def test_fmtonoff(self) -> None:
765         source, expected = read_data("fmtonoff")
766         actual = fs(source)
767         self.assertFormatEqual(expected, actual)
768         black.assert_equivalent(source, actual)
769         black.assert_stable(source, actual, DEFAULT_MODE)
770
771     @patch("black.dump_to_file", dump_to_stderr)
772     def test_fmtonoff2(self) -> None:
773         source, expected = read_data("fmtonoff2")
774         actual = fs(source)
775         self.assertFormatEqual(expected, actual)
776         black.assert_equivalent(source, actual)
777         black.assert_stable(source, actual, DEFAULT_MODE)
778
779     @patch("black.dump_to_file", dump_to_stderr)
780     def test_fmtonoff3(self) -> None:
781         source, expected = read_data("fmtonoff3")
782         actual = fs(source)
783         self.assertFormatEqual(expected, actual)
784         black.assert_equivalent(source, actual)
785         black.assert_stable(source, actual, DEFAULT_MODE)
786
787     @patch("black.dump_to_file", dump_to_stderr)
788     def test_fmtonoff4(self) -> None:
789         source, expected = read_data("fmtonoff4")
790         actual = fs(source)
791         self.assertFormatEqual(expected, actual)
792         black.assert_equivalent(source, actual)
793         black.assert_stable(source, actual, DEFAULT_MODE)
794
795     @patch("black.dump_to_file", dump_to_stderr)
796     def test_remove_empty_parentheses_after_class(self) -> None:
797         source, expected = read_data("class_blank_parentheses")
798         actual = fs(source)
799         self.assertFormatEqual(expected, actual)
800         black.assert_equivalent(source, actual)
801         black.assert_stable(source, actual, DEFAULT_MODE)
802
803     @patch("black.dump_to_file", dump_to_stderr)
804     def test_new_line_between_class_and_code(self) -> None:
805         source, expected = read_data("class_methods_new_line")
806         actual = fs(source)
807         self.assertFormatEqual(expected, actual)
808         black.assert_equivalent(source, actual)
809         black.assert_stable(source, actual, DEFAULT_MODE)
810
811     @patch("black.dump_to_file", dump_to_stderr)
812     def test_bracket_match(self) -> None:
813         source, expected = read_data("bracketmatch")
814         actual = fs(source)
815         self.assertFormatEqual(expected, actual)
816         black.assert_equivalent(source, actual)
817         black.assert_stable(source, actual, DEFAULT_MODE)
818
819     @patch("black.dump_to_file", dump_to_stderr)
820     def test_tuple_assign(self) -> None:
821         source, expected = read_data("tupleassign")
822         actual = fs(source)
823         self.assertFormatEqual(expected, actual)
824         black.assert_equivalent(source, actual)
825         black.assert_stable(source, actual, DEFAULT_MODE)
826
827     @patch("black.dump_to_file", dump_to_stderr)
828     def test_beginning_backslash(self) -> None:
829         source, expected = read_data("beginning_backslash")
830         actual = fs(source)
831         self.assertFormatEqual(expected, actual)
832         black.assert_equivalent(source, actual)
833         black.assert_stable(source, actual, DEFAULT_MODE)
834
835     def test_tab_comment_indentation(self) -> None:
836         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
837         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
838         self.assertFormatEqual(contents_spc, fs(contents_spc))
839         self.assertFormatEqual(contents_spc, fs(contents_tab))
840
841         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
842         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
843         self.assertFormatEqual(contents_spc, fs(contents_spc))
844         self.assertFormatEqual(contents_spc, fs(contents_tab))
845
846         # mixed tabs and spaces (valid Python 2 code)
847         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
848         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
849         self.assertFormatEqual(contents_spc, fs(contents_spc))
850         self.assertFormatEqual(contents_spc, fs(contents_tab))
851
852         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
853         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
854         self.assertFormatEqual(contents_spc, fs(contents_spc))
855         self.assertFormatEqual(contents_spc, fs(contents_tab))
856
    def test_report_verbose(self) -> None:
        """Exercise black.Report with verbose=True.

        ``black.out`` and ``black.err`` are patched with collectors. After every
        call on the report, the test checks the emitted messages, the summary
        text from ``str(report)`` and ``report.return_code``. In verbose mode,
        unchanged, cached and ignored files are echoed on ``out`` too.
        """
        report = black.Report(verbose=True)
        # Messages captured from the patched black.out / black.err.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file is counted as "left unchanged" in the summary.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled, the same counts yield return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure is written to err and the return code becomes 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are echoed in verbose mode but leave the counts as-is.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
958
    def test_report_quiet(self) -> None:
        """Exercise black.Report with quiet=True.

        In quiet mode nothing reaches ``black.out``. Failures are still written
        to ``black.err``. The summary text and return codes follow the same
        rules as in the other modes.
        """
        report = black.Report(quiet=True)
        # Messages captured from the patched black.out / black.err.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file is counted as "left unchanged" in the summary.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled, the same counts yield return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures are still reported on err even in quiet mode.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are silent and leave the counts as-is.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
1052
    def test_report_normal(self) -> None:
        """Exercise black.Report with default verbosity.

        Only reformatted files are echoed on ``black.out``. Unchanged, cached
        and ignored files are silent. Failures go to ``black.err``.
        """
        report = black.Report()
        # Messages captured from the patched black.out / black.err.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # Cached files are silent, so the last message is still f2's.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled, the same counts yield return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure is written to err and the return code becomes 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are silent and leave the counts as-is.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
1149
1150     def test_lib2to3_parse(self) -> None:
1151         with self.assertRaises(black.InvalidInput):
1152             black.lib2to3_parse("invalid syntax")
1153
1154         straddling = "x + y"
1155         black.lib2to3_parse(straddling)
1156         black.lib2to3_parse(straddling, {TargetVersion.PY27})
1157         black.lib2to3_parse(straddling, {TargetVersion.PY36})
1158         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
1159
1160         py2_only = "print x"
1161         black.lib2to3_parse(py2_only)
1162         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
1163         with self.assertRaises(black.InvalidInput):
1164             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
1165         with self.assertRaises(black.InvalidInput):
1166             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
1167
1168         py3_only = "exec(x, end=y)"
1169         black.lib2to3_parse(py3_only)
1170         with self.assertRaises(black.InvalidInput):
1171             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
1172         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
1173         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
1174
1175     def test_get_features_used(self) -> None:
1176         node = black.lib2to3_parse("def f(*, arg): ...\n")
1177         self.assertEqual(black.get_features_used(node), set())
1178         node = black.lib2to3_parse("def f(*, arg,): ...\n")
1179         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
1180         node = black.lib2to3_parse("f(*arg,)\n")
1181         self.assertEqual(
1182             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
1183         )
1184         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
1185         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
1186         node = black.lib2to3_parse("123_456\n")
1187         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
1188         node = black.lib2to3_parse("123456\n")
1189         self.assertEqual(black.get_features_used(node), set())
1190         source, expected = read_data("function")
1191         node = black.lib2to3_parse(source)
1192         expected_features = {
1193             Feature.TRAILING_COMMA_IN_CALL,
1194             Feature.TRAILING_COMMA_IN_DEF,
1195             Feature.F_STRINGS,
1196         }
1197         self.assertEqual(black.get_features_used(node), expected_features)
1198         node = black.lib2to3_parse(expected)
1199         self.assertEqual(black.get_features_used(node), expected_features)
1200         source, expected = read_data("expression")
1201         node = black.lib2to3_parse(source)
1202         self.assertEqual(black.get_features_used(node), set())
1203         node = black.lib2to3_parse(expected)
1204         self.assertEqual(black.get_features_used(node), set())
1205
1206     def test_get_future_imports(self) -> None:
1207         node = black.lib2to3_parse("\n")
1208         self.assertEqual(set(), black.get_future_imports(node))
1209         node = black.lib2to3_parse("from __future__ import black\n")
1210         self.assertEqual({"black"}, black.get_future_imports(node))
1211         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
1212         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
1213         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
1214         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
1215         node = black.lib2to3_parse(
1216             "from __future__ import multiple\nfrom __future__ import imports\n"
1217         )
1218         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
1219         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
1220         self.assertEqual({"black"}, black.get_future_imports(node))
1221         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
1222         self.assertEqual({"black"}, black.get_future_imports(node))
1223         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
1224         self.assertEqual(set(), black.get_future_imports(node))
1225         node = black.lib2to3_parse("from some.module import black\n")
1226         self.assertEqual(set(), black.get_future_imports(node))
1227         node = black.lib2to3_parse(
1228             "from __future__ import unicode_literals as _unicode_literals"
1229         )
1230         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
1231         node = black.lib2to3_parse(
1232             "from __future__ import unicode_literals as _lol, print"
1233         )
1234         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
1235
1236     def test_debug_visitor(self) -> None:
1237         source, _ = read_data("debug_visitor.py")
1238         expected, _ = read_data("debug_visitor.out")
1239         out_lines = []
1240         err_lines = []
1241
1242         def out(msg: str, **kwargs: Any) -> None:
1243             out_lines.append(msg)
1244
1245         def err(msg: str, **kwargs: Any) -> None:
1246             err_lines.append(msg)
1247
1248         with patch("black.out", out), patch("black.err", err):
1249             black.DebugVisitor.show(source)
1250         actual = "\n".join(out_lines) + "\n"
1251         log_name = ""
1252         if expected != actual:
1253             log_name = black.dump_to_file(*out_lines)
1254         self.assertEqual(
1255             expected,
1256             actual,
1257             f"AST print out is different. Actual version dumped to {log_name}",
1258         )
1259
1260     def test_format_file_contents(self) -> None:
1261         empty = ""
1262         mode = DEFAULT_MODE
1263         with self.assertRaises(black.NothingChanged):
1264             black.format_file_contents(empty, mode=mode, fast=False)
1265         just_nl = "\n"
1266         with self.assertRaises(black.NothingChanged):
1267             black.format_file_contents(just_nl, mode=mode, fast=False)
1268         same = "j = [1, 2, 3]\n"
1269         with self.assertRaises(black.NothingChanged):
1270             black.format_file_contents(same, mode=mode, fast=False)
1271         different = "j = [1,2,3]"
1272         expected = same
1273         actual = black.format_file_contents(different, mode=mode, fast=False)
1274         self.assertEqual(expected, actual)
1275         invalid = "return if you can"
1276         with self.assertRaises(black.InvalidInput) as e:
1277             black.format_file_contents(invalid, mode=mode, fast=False)
1278         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1279
1280     def test_endmarker(self) -> None:
1281         n = black.lib2to3_parse("\n")
1282         self.assertEqual(n.type, black.syms.file_input)
1283         self.assertEqual(len(n.children), 1)
1284         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1285
1286     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1287     def test_assertFormatEqual(self) -> None:
1288         out_lines = []
1289         err_lines = []
1290
1291         def out(msg: str, **kwargs: Any) -> None:
1292             out_lines.append(msg)
1293
1294         def err(msg: str, **kwargs: Any) -> None:
1295             err_lines.append(msg)
1296
1297         with patch("black.out", out), patch("black.err", err):
1298             with self.assertRaises(AssertionError):
1299                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1300
1301         out_str = "".join(out_lines)
1302         self.assertTrue("Expected tree:" in out_str)
1303         self.assertTrue("Actual tree:" in out_str)
1304         self.assertEqual("".join(err_lines), "")
1305
1306     def test_cache_broken_file(self) -> None:
1307         mode = DEFAULT_MODE
1308         with cache_dir() as workspace:
1309             cache_file = black.get_cache_file(mode)
1310             with cache_file.open("w") as fobj:
1311                 fobj.write("this is not a pickle")
1312             self.assertEqual(black.read_cache(mode), {})
1313             src = (workspace / "test.py").resolve()
1314             with src.open("w") as fobj:
1315                 fobj.write("print('hello')")
1316             self.invokeBlack([str(src)])
1317             cache = black.read_cache(mode)
1318             self.assertIn(src, cache)
1319
1320     def test_cache_single_file_already_cached(self) -> None:
1321         mode = DEFAULT_MODE
1322         with cache_dir() as workspace:
1323             src = (workspace / "test.py").resolve()
1324             with src.open("w") as fobj:
1325                 fobj.write("print('hello')")
1326             black.write_cache({}, [src], mode)
1327             self.invokeBlack([str(src)])
1328             with src.open("r") as fobj:
1329                 self.assertEqual(fobj.read(), "print('hello')")
1330
1331     @event_loop()
1332     def test_cache_multiple_files(self) -> None:
1333         mode = DEFAULT_MODE
1334         with cache_dir() as workspace, patch(
1335             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1336         ):
1337             one = (workspace / "one.py").resolve()
1338             with one.open("w") as fobj:
1339                 fobj.write("print('hello')")
1340             two = (workspace / "two.py").resolve()
1341             with two.open("w") as fobj:
1342                 fobj.write("print('hello')")
1343             black.write_cache({}, [one], mode)
1344             self.invokeBlack([str(workspace)])
1345             with one.open("r") as fobj:
1346                 self.assertEqual(fobj.read(), "print('hello')")
1347             with two.open("r") as fobj:
1348                 self.assertEqual(fobj.read(), 'print("hello")\n')
1349             cache = black.read_cache(mode)
1350             self.assertIn(one, cache)
1351             self.assertIn(two, cache)
1352
1353     def test_no_cache_when_writeback_diff(self) -> None:
1354         mode = DEFAULT_MODE
1355         with cache_dir() as workspace:
1356             src = (workspace / "test.py").resolve()
1357             with src.open("w") as fobj:
1358                 fobj.write("print('hello')")
1359             self.invokeBlack([str(src), "--diff"])
1360             cache_file = black.get_cache_file(mode)
1361             self.assertFalse(cache_file.exists())
1362
1363     def test_no_cache_when_stdin(self) -> None:
1364         mode = DEFAULT_MODE
1365         with cache_dir():
1366             result = CliRunner().invoke(
1367                 black.main, ["-"], input=BytesIO(b"print('hello')")
1368             )
1369             self.assertEqual(result.exit_code, 0)
1370             cache_file = black.get_cache_file(mode)
1371             self.assertFalse(cache_file.exists())
1372
1373     def test_read_cache_no_cachefile(self) -> None:
1374         mode = DEFAULT_MODE
1375         with cache_dir():
1376             self.assertEqual(black.read_cache(mode), {})
1377
1378     def test_write_cache_read_cache(self) -> None:
1379         mode = DEFAULT_MODE
1380         with cache_dir() as workspace:
1381             src = (workspace / "test.py").resolve()
1382             src.touch()
1383             black.write_cache({}, [src], mode)
1384             cache = black.read_cache(mode)
1385             self.assertIn(src, cache)
1386             self.assertEqual(cache[src], black.get_cache_info(src))
1387
1388     def test_filter_cached(self) -> None:
1389         with TemporaryDirectory() as workspace:
1390             path = Path(workspace)
1391             uncached = (path / "uncached").resolve()
1392             cached = (path / "cached").resolve()
1393             cached_but_changed = (path / "changed").resolve()
1394             uncached.touch()
1395             cached.touch()
1396             cached_but_changed.touch()
1397             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
1398             todo, done = black.filter_cached(
1399                 cache, {uncached, cached, cached_but_changed}
1400             )
1401             self.assertEqual(todo, {uncached, cached_but_changed})
1402             self.assertEqual(done, {cached})
1403
1404     def test_write_cache_creates_directory_if_needed(self) -> None:
1405         mode = DEFAULT_MODE
1406         with cache_dir(exists=False) as workspace:
1407             self.assertFalse(workspace.exists())
1408             black.write_cache({}, [], mode)
1409             self.assertTrue(workspace.exists())
1410
1411     @event_loop()
1412     def test_failed_formatting_does_not_get_cached(self) -> None:
1413         mode = DEFAULT_MODE
1414         with cache_dir() as workspace, patch(
1415             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1416         ):
1417             failing = (workspace / "failing.py").resolve()
1418             with failing.open("w") as fobj:
1419                 fobj.write("not actually python")
1420             clean = (workspace / "clean.py").resolve()
1421             with clean.open("w") as fobj:
1422                 fobj.write('print("hello")\n')
1423             self.invokeBlack([str(workspace)], exit_code=123)
1424             cache = black.read_cache(mode)
1425             self.assertNotIn(failing, cache)
1426             self.assertIn(clean, cache)
1427
1428     def test_write_cache_write_fail(self) -> None:
1429         mode = DEFAULT_MODE
1430         with cache_dir(), patch.object(Path, "open") as mock:
1431             mock.side_effect = OSError
1432             black.write_cache({}, [], mode)
1433
1434     @event_loop()
1435     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
1436     def test_works_in_mono_process_only_environment(self) -> None:
1437         with cache_dir() as workspace:
1438             for f in [
1439                 (workspace / "one.py").resolve(),
1440                 (workspace / "two.py").resolve(),
1441             ]:
1442                 f.write_text('print("hello")\n')
1443             self.invokeBlack([str(workspace)])
1444
1445     @event_loop()
1446     def test_check_diff_use_together(self) -> None:
1447         with cache_dir():
1448             # Files which will be reformatted.
1449             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1450             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1451             # Files which will not be reformatted.
1452             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1453             self.invokeBlack([str(src2), "--diff", "--check"])
1454             # Multi file command.
1455             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1456
1457     def test_no_files(self) -> None:
1458         with cache_dir():
1459             # Without an argument, black exits with error code 0.
1460             self.invokeBlack([])
1461
1462     def test_broken_symlink(self) -> None:
1463         with cache_dir() as workspace:
1464             symlink = workspace / "broken_link.py"
1465             try:
1466                 symlink.symlink_to("nonexistent.py")
1467             except OSError as e:
1468                 self.skipTest(f"Can't create symlinks: {e}")
1469             self.invokeBlack([str(workspace.resolve())])
1470
1471     def test_read_cache_line_lengths(self) -> None:
1472         mode = DEFAULT_MODE
1473         short_mode = replace(DEFAULT_MODE, line_length=1)
1474         with cache_dir() as workspace:
1475             path = (workspace / "file.py").resolve()
1476             path.touch()
1477             black.write_cache({}, [path], mode)
1478             one = black.read_cache(mode)
1479             self.assertIn(path, one)
1480             two = black.read_cache(short_mode)
1481             self.assertNotIn(path, two)
1482
1483     def test_tricky_unicode_symbols(self) -> None:
1484         source, expected = read_data("tricky_unicode_symbols")
1485         actual = fs(source)
1486         self.assertFormatEqual(expected, actual)
1487         black.assert_equivalent(source, actual)
1488         black.assert_stable(source, actual, DEFAULT_MODE)
1489
1490     def test_single_file_force_pyi(self) -> None:
1491         reg_mode = DEFAULT_MODE
1492         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1493         contents, expected = read_data("force_pyi")
1494         with cache_dir() as workspace:
1495             path = (workspace / "file.py").resolve()
1496             with open(path, "w") as fh:
1497                 fh.write(contents)
1498             self.invokeBlack([str(path), "--pyi"])
1499             with open(path, "r") as fh:
1500                 actual = fh.read()
1501             # verify cache with --pyi is separate
1502             pyi_cache = black.read_cache(pyi_mode)
1503             self.assertIn(path, pyi_cache)
1504             normal_cache = black.read_cache(reg_mode)
1505             self.assertNotIn(path, normal_cache)
1506         self.assertEqual(actual, expected)
1507
1508     @event_loop()
1509     def test_multi_file_force_pyi(self) -> None:
1510         reg_mode = DEFAULT_MODE
1511         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1512         contents, expected = read_data("force_pyi")
1513         with cache_dir() as workspace:
1514             paths = [
1515                 (workspace / "file1.py").resolve(),
1516                 (workspace / "file2.py").resolve(),
1517             ]
1518             for path in paths:
1519                 with open(path, "w") as fh:
1520                     fh.write(contents)
1521             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1522             for path in paths:
1523                 with open(path, "r") as fh:
1524                     actual = fh.read()
1525                 self.assertEqual(actual, expected)
1526             # verify cache with --pyi is separate
1527             pyi_cache = black.read_cache(pyi_mode)
1528             normal_cache = black.read_cache(reg_mode)
1529             for path in paths:
1530                 self.assertIn(path, pyi_cache)
1531                 self.assertNotIn(path, normal_cache)
1532
1533     def test_pipe_force_pyi(self) -> None:
1534         source, expected = read_data("force_pyi")
1535         result = CliRunner().invoke(
1536             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1537         )
1538         self.assertEqual(result.exit_code, 0)
1539         actual = result.output
1540         self.assertFormatEqual(actual, expected)
1541
1542     def test_single_file_force_py36(self) -> None:
1543         reg_mode = DEFAULT_MODE
1544         py36_mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
1545         source, expected = read_data("force_py36")
1546         with cache_dir() as workspace:
1547             path = (workspace / "file.py").resolve()
1548             with open(path, "w") as fh:
1549                 fh.write(source)
1550             self.invokeBlack([str(path), *PY36_ARGS])
1551             with open(path, "r") as fh:
1552                 actual = fh.read()
1553             # verify cache with --target-version is separate
1554             py36_cache = black.read_cache(py36_mode)
1555             self.assertIn(path, py36_cache)
1556             normal_cache = black.read_cache(reg_mode)
1557             self.assertNotIn(path, normal_cache)
1558         self.assertEqual(actual, expected)
1559
1560     @event_loop()
1561     def test_multi_file_force_py36(self) -> None:
1562         reg_mode = DEFAULT_MODE
1563         py36_mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
1564         source, expected = read_data("force_py36")
1565         with cache_dir() as workspace:
1566             paths = [
1567                 (workspace / "file1.py").resolve(),
1568                 (workspace / "file2.py").resolve(),
1569             ]
1570             for path in paths:
1571                 with open(path, "w") as fh:
1572                     fh.write(source)
1573             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1574             for path in paths:
1575                 with open(path, "r") as fh:
1576                     actual = fh.read()
1577                 self.assertEqual(actual, expected)
1578             # verify cache with --target-version is separate
1579             pyi_cache = black.read_cache(py36_mode)
1580             normal_cache = black.read_cache(reg_mode)
1581             for path in paths:
1582                 self.assertIn(path, pyi_cache)
1583                 self.assertNotIn(path, normal_cache)
1584
1585     def test_collections(self) -> None:
1586         source, expected = read_data("collections")
1587         actual = fs(source)
1588         self.assertFormatEqual(expected, actual)
1589         black.assert_equivalent(source, actual)
1590         black.assert_stable(source, actual, DEFAULT_MODE)
1591
1592     def test_pipe_force_py36(self) -> None:
1593         source, expected = read_data("force_py36")
1594         result = CliRunner().invoke(
1595             black.main,
1596             ["-", "-q", "--target-version=py36"],
1597             input=BytesIO(source.encode("utf8")),
1598         )
1599         self.assertEqual(result.exit_code, 0)
1600         actual = result.output
1601         self.assertFormatEqual(actual, expected)
1602
1603     def test_include_exclude(self) -> None:
1604         path = THIS_DIR / "data" / "include_exclude_tests"
1605         include = re.compile(r"\.pyi?$")
1606         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1607         report = black.Report()
1608         gitignore = PathSpec.from_lines("gitwildmatch", [])
1609         sources: List[Path] = []
1610         expected = [
1611             Path(path / "b/dont_exclude/a.py"),
1612             Path(path / "b/dont_exclude/a.pyi"),
1613         ]
1614         this_abs = THIS_DIR.resolve()
1615         sources.extend(
1616             black.gen_python_files(
1617                 path.iterdir(), this_abs, include, exclude, None, report, gitignore
1618             )
1619         )
1620         self.assertEqual(sorted(expected), sorted(sources))
1621
1622     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1623     def test_exclude_for_issue_1572(self) -> None:
1624         # Exclude shouldn't touch files that were explicitly given to Black through the
1625         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
1626         # https://github.com/psf/black/issues/1572
1627         path = THIS_DIR / "data" / "include_exclude_tests"
1628         include = ""
1629         exclude = r"/exclude/|a\.py"
1630         src = str(path / "b/exclude/a.py")
1631         report = black.Report()
1632         expected = [Path(path / "b/exclude/a.py")]
1633         sources = list(
1634             black.get_sources(
1635                 ctx=FakeContext(),
1636                 src=(src,),
1637                 quiet=True,
1638                 verbose=False,
1639                 include=include,
1640                 exclude=exclude,
1641                 force_exclude=None,
1642                 report=report,
1643             )
1644         )
1645         self.assertEqual(sorted(expected), sorted(sources))
1646
1647     def test_gitignore_exclude(self) -> None:
1648         path = THIS_DIR / "data" / "include_exclude_tests"
1649         include = re.compile(r"\.pyi?$")
1650         exclude = re.compile(r"")
1651         report = black.Report()
1652         gitignore = PathSpec.from_lines(
1653             "gitwildmatch", ["exclude/", ".definitely_exclude"]
1654         )
1655         sources: List[Path] = []
1656         expected = [
1657             Path(path / "b/dont_exclude/a.py"),
1658             Path(path / "b/dont_exclude/a.pyi"),
1659         ]
1660         this_abs = THIS_DIR.resolve()
1661         sources.extend(
1662             black.gen_python_files(
1663                 path.iterdir(), this_abs, include, exclude, None, report, gitignore
1664             )
1665         )
1666         self.assertEqual(sorted(expected), sorted(sources))
1667
1668     def test_empty_include(self) -> None:
1669         path = THIS_DIR / "data" / "include_exclude_tests"
1670         report = black.Report()
1671         gitignore = PathSpec.from_lines("gitwildmatch", [])
1672         empty = re.compile(r"")
1673         sources: List[Path] = []
1674         expected = [
1675             Path(path / "b/exclude/a.pie"),
1676             Path(path / "b/exclude/a.py"),
1677             Path(path / "b/exclude/a.pyi"),
1678             Path(path / "b/dont_exclude/a.pie"),
1679             Path(path / "b/dont_exclude/a.py"),
1680             Path(path / "b/dont_exclude/a.pyi"),
1681             Path(path / "b/.definitely_exclude/a.pie"),
1682             Path(path / "b/.definitely_exclude/a.py"),
1683             Path(path / "b/.definitely_exclude/a.pyi"),
1684         ]
1685         this_abs = THIS_DIR.resolve()
1686         sources.extend(
1687             black.gen_python_files(
1688                 path.iterdir(),
1689                 this_abs,
1690                 empty,
1691                 re.compile(black.DEFAULT_EXCLUDES),
1692                 None,
1693                 report,
1694                 gitignore,
1695             )
1696         )
1697         self.assertEqual(sorted(expected), sorted(sources))
1698
1699     def test_empty_exclude(self) -> None:
1700         path = THIS_DIR / "data" / "include_exclude_tests"
1701         report = black.Report()
1702         gitignore = PathSpec.from_lines("gitwildmatch", [])
1703         empty = re.compile(r"")
1704         sources: List[Path] = []
1705         expected = [
1706             Path(path / "b/dont_exclude/a.py"),
1707             Path(path / "b/dont_exclude/a.pyi"),
1708             Path(path / "b/exclude/a.py"),
1709             Path(path / "b/exclude/a.pyi"),
1710             Path(path / "b/.definitely_exclude/a.py"),
1711             Path(path / "b/.definitely_exclude/a.pyi"),
1712         ]
1713         this_abs = THIS_DIR.resolve()
1714         sources.extend(
1715             black.gen_python_files(
1716                 path.iterdir(),
1717                 this_abs,
1718                 re.compile(black.DEFAULT_INCLUDES),
1719                 empty,
1720                 None,
1721                 report,
1722                 gitignore,
1723             )
1724         )
1725         self.assertEqual(sorted(expected), sorted(sources))
1726
1727     def test_invalid_include_exclude(self) -> None:
1728         for option in ["--include", "--exclude"]:
1729             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1730
1731     def test_preserves_line_endings(self) -> None:
1732         with TemporaryDirectory() as workspace:
1733             test_file = Path(workspace) / "test.py"
1734             for nl in ["\n", "\r\n"]:
1735                 contents = nl.join(["def f(  ):", "    pass"])
1736                 test_file.write_bytes(contents.encode())
1737                 ff(test_file, write_back=black.WriteBack.YES)
1738                 updated_contents: bytes = test_file.read_bytes()
1739                 self.assertIn(nl.encode(), updated_contents)
1740                 if nl == "\n":
1741                     self.assertNotIn(b"\r\n", updated_contents)
1742
1743     def test_preserves_line_endings_via_stdin(self) -> None:
1744         for nl in ["\n", "\r\n"]:
1745             contents = nl.join(["def f(  ):", "    pass"])
1746             runner = BlackRunner()
1747             result = runner.invoke(
1748                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1749             )
1750             self.assertEqual(result.exit_code, 0)
1751             output = runner.stdout_bytes
1752             self.assertIn(nl.encode("utf8"), output)
1753             if nl == "\n":
1754                 self.assertNotIn(b"\r\n", output)
1755
1756     def test_assert_equivalent_different_asts(self) -> None:
1757         with self.assertRaises(AssertionError):
1758             black.assert_equivalent("{}", "None")
1759
1760     def test_symlink_out_of_root_directory(self) -> None:
1761         path = MagicMock()
1762         root = THIS_DIR.resolve()
1763         child = MagicMock()
1764         include = re.compile(black.DEFAULT_INCLUDES)
1765         exclude = re.compile(black.DEFAULT_EXCLUDES)
1766         report = black.Report()
1767         gitignore = PathSpec.from_lines("gitwildmatch", [])
1768         # `child` should behave like a symlink which resolved path is clearly
1769         # outside of the `root` directory.
1770         path.iterdir.return_value = [child]
1771         child.resolve.return_value = Path("/a/b/c")
1772         child.as_posix.return_value = "/a/b/c"
1773         child.is_symlink.return_value = True
1774         try:
1775             list(
1776                 black.gen_python_files(
1777                     path.iterdir(), root, include, exclude, None, report, gitignore
1778                 )
1779             )
1780         except ValueError as ve:
1781             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1782         path.iterdir.assert_called_once()
1783         child.resolve.assert_called_once()
1784         child.is_symlink.assert_called_once()
1785         # `child` should behave like a strange file which resolved path is clearly
1786         # outside of the `root` directory.
1787         child.is_symlink.return_value = False
1788         with self.assertRaises(ValueError):
1789             list(
1790                 black.gen_python_files(
1791                     path.iterdir(), root, include, exclude, None, report, gitignore
1792                 )
1793             )
1794         path.iterdir.assert_called()
1795         self.assertEqual(path.iterdir.call_count, 2)
1796         child.resolve.assert_called()
1797         self.assertEqual(child.resolve.call_count, 2)
1798         child.is_symlink.assert_called()
1799         self.assertEqual(child.is_symlink.call_count, 2)
1800
1801     def test_shhh_click(self) -> None:
1802         try:
1803             from click import _unicodefun  # type: ignore
1804         except ModuleNotFoundError:
1805             self.skipTest("Incompatible Click version")
1806         if not hasattr(_unicodefun, "_verify_python3_env"):
1807             self.skipTest("Incompatible Click version")
1808         # First, let's see if Click is crashing with a preferred ASCII charset.
1809         with patch("locale.getpreferredencoding") as gpe:
1810             gpe.return_value = "ASCII"
1811             with self.assertRaises(RuntimeError):
1812                 _unicodefun._verify_python3_env()
1813         # Now, let's silence Click...
1814         black.patch_click()
1815         # ...and confirm it's silent.
1816         with patch("locale.getpreferredencoding") as gpe:
1817             gpe.return_value = "ASCII"
1818             try:
1819                 _unicodefun._verify_python3_env()
1820             except RuntimeError as re:
1821                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1822
1823     def test_root_logger_not_used_directly(self) -> None:
1824         def fail(*args: Any, **kwargs: Any) -> None:
1825             self.fail("Record created with root logger")
1826
1827         with patch.multiple(
1828             logging.root,
1829             debug=fail,
1830             info=fail,
1831             warning=fail,
1832             error=fail,
1833             critical=fail,
1834             log=fail,
1835         ):
1836             ff(THIS_FILE)
1837
1838     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1839     def test_blackd_main(self) -> None:
1840         with patch("blackd.web.run_app"):
1841             result = CliRunner().invoke(blackd.main, [])
1842             if result.exception is not None:
1843                 raise result.exception
1844             self.assertEqual(result.exit_code, 0)
1845
1846     def test_invalid_config_return_code(self) -> None:
1847         tmp_file = Path(black.dump_to_file())
1848         try:
1849             tmp_config = Path(black.dump_to_file())
1850             tmp_config.unlink()
1851             args = ["--config", str(tmp_config), str(tmp_file)]
1852             self.invokeBlack(args, exit_code=2, ignore_config=False)
1853         finally:
1854             tmp_file.unlink()
1855
1856     def test_parse_pyproject_toml(self) -> None:
1857         test_toml_file = THIS_DIR / "test.toml"
1858         config = black.parse_pyproject_toml(str(test_toml_file))
1859         self.assertEqual(config["verbose"], 1)
1860         self.assertEqual(config["check"], "no")
1861         self.assertEqual(config["diff"], "y")
1862         self.assertEqual(config["color"], True)
1863         self.assertEqual(config["line_length"], 79)
1864         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1865         self.assertEqual(config["exclude"], r"\.pyi?$")
1866         self.assertEqual(config["include"], r"\.py?$")
1867
1868     def test_read_pyproject_toml(self) -> None:
1869         test_toml_file = THIS_DIR / "test.toml"
1870         fake_ctx = FakeContext()
1871         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1872         config = fake_ctx.default_map
1873         self.assertEqual(config["verbose"], "1")
1874         self.assertEqual(config["check"], "no")
1875         self.assertEqual(config["diff"], "y")
1876         self.assertEqual(config["color"], "True")
1877         self.assertEqual(config["line_length"], "79")
1878         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1879         self.assertEqual(config["exclude"], r"\.pyi?$")
1880         self.assertEqual(config["include"], r"\.py?$")
1881
1882     def test_find_project_root(self) -> None:
1883         with TemporaryDirectory() as workspace:
1884             root = Path(workspace)
1885             test_dir = root / "test"
1886             test_dir.mkdir()
1887
1888             src_dir = root / "src"
1889             src_dir.mkdir()
1890
1891             root_pyproject = root / "pyproject.toml"
1892             root_pyproject.touch()
1893             src_pyproject = src_dir / "pyproject.toml"
1894             src_pyproject.touch()
1895             src_python = src_dir / "foo.py"
1896             src_python.touch()
1897
1898             self.assertEqual(
1899                 black.find_project_root((src_dir, test_dir)), root.resolve()
1900             )
1901             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
1902             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
1903
1904
if has_blackd_deps:
    # `AioHTTPTestCase`, `unittest_run_loop`, `web` and `blackd` are only bound
    # when the optional blackd imports at the top of this module succeeded.
    # Defining the class unconditionally made the whole test module fail to
    # import (NameError) without them, instead of just skipping these tests.
    # Because the class now only exists when the deps are present, the former
    # per-test `skipUnless(has_blackd_deps, ...)` decorators were redundant.
    # NOTE(review): assumes no fallback definitions of these names exist in the
    # part of the file not shown here — confirm.

    class BlackDTestCase(AioHTTPTestCase):
        """End-to-end tests of the blackd HTTP app via aiohttp's test client."""

        async def get_application(self) -> web.Application:
            return blackd.make_app()

        # TODO: remove the skip_if_exception decorators once the below is released
        # https://github.com/aio-libs/aiohttp/pull/3727
        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_request_needs_formatting(self) -> None:
            response = await self.client.post("/", data=b"print('hello world')")
            self.assertEqual(response.status, 200)
            self.assertEqual(response.charset, "utf8")
            self.assertEqual(await response.read(), b'print("hello world")\n')

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_request_no_change(self) -> None:
            response = await self.client.post("/", data=b'print("hello world")\n')
            self.assertEqual(response.status, 204)
            self.assertEqual(await response.read(), b"")

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_request_syntax_error(self) -> None:
            response = await self.client.post("/", data=b"what even ( is")
            self.assertEqual(response.status, 400)
            content = await response.text()
            self.assertTrue(
                content.startswith("Cannot parse"),
                msg=f"Expected error to start with 'Cannot parse', got {repr(content)}",
            )

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_unsupported_version(self) -> None:
            response = await self.client.post(
                "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "2"}
            )
            self.assertEqual(response.status, 501)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_supported_version(self) -> None:
            response = await self.client.post(
                "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "1"}
            )
            self.assertEqual(response.status, 200)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_invalid_python_variant(self) -> None:
            async def check(header_value: str, expected_status: int = 400) -> None:
                response = await self.client.post(
                    "/",
                    data=b"what",
                    headers={blackd.PYTHON_VARIANT_HEADER: header_value},
                )
                self.assertEqual(response.status, expected_status)

            await check("lol")
            await check("ruby3.5")
            await check("pyi3.6")
            await check("py1.5")
            await check("2.8")
            await check("py2.8")
            await check("3.0")
            await check("pypy3.0")
            await check("jython3.4")

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_pyi(self) -> None:
            source, expected = read_data("stub.pyi")
            response = await self.client.post(
                "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
            )
            self.assertEqual(response.status, 200)
            self.assertEqual(await response.text(), expected)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_diff(self) -> None:
            diff_header = re.compile(
                r"(In|Out)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
            )

            source, _ = read_data("blackd_diff.py")
            expected, _ = read_data("blackd_diff.diff")

            response = await self.client.post(
                "/", data=source, headers={blackd.DIFF_HEADER: "true"}
            )
            self.assertEqual(response.status, 200)

            actual = await response.text()
            # Timestamps in the diff header vary per run; normalise them.
            actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
            self.assertEqual(actual, expected)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_python_variant(self) -> None:
            code = (
                "def f(\n"
                "    and_has_a_bunch_of,\n"
                "    very_long_arguments_too,\n"
                "    and_lots_of_them_as_well_lol,\n"
                "    **and_very_long_keyword_arguments\n"
                "):\n"
                "    pass\n"
            )

            async def check(header_value: str, expected_status: int) -> None:
                response = await self.client.post(
                    "/",
                    data=code,
                    headers={blackd.PYTHON_VARIANT_HEADER: header_value},
                )
                self.assertEqual(
                    response.status, expected_status, msg=await response.text()
                )

            await check("3.6", 200)
            await check("py3.6", 200)
            await check("3.6,3.7", 200)
            await check("3.6,py3.7", 200)
            await check("py36,py37", 200)
            await check("36", 200)
            await check("3.6.4", 200)

            await check("2", 204)
            await check("2.7", 204)
            await check("py2.7", 204)
            await check("3.4", 204)
            await check("py3.4", 204)
            await check("py34,py36", 204)
            await check("34", 204)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_line_length(self) -> None:
            response = await self.client.post(
                "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "7"}
            )
            self.assertEqual(response.status, 200)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_invalid_line_length(self) -> None:
            response = await self.client.post(
                "/",
                data=b'print("hello")\n',
                headers={blackd.LINE_LENGTH_HEADER: "NaN"},
            )
            self.assertEqual(response.status, 400)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_response_black_version_header(self) -> None:
            response = await self.client.post("/")
            self.assertIsNotNone(response.headers.get(blackd.BLACK_VERSION_HEADER))
2071
2072
# black's main module source, read once at import time so that `tracefunc`
# can look up the source line of each traced function.
with Path(black.__file__).open("r", encoding="utf-8") as _bf:
    black_source_lines = list(_bf)
2075
2076
def tracefunc(frame: types.FrameType, event: str, arg: Any) -> Callable:
    """Show function calls from `black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    For each "call" event whose code lives in `black/__init__.py`, print
    ``<lineno>:<signature line>`` indented by the current stack depth.
    Returns itself so it is also installed as the local trace function.
    """
    if event != "call":
        return tracefunc

    # Filter by file *before* touching `black_source_lines`: those are black's
    # own source lines, so indexing them with another module's line number was
    # wasted work at best and an IndexError (aborting the trace) at worst.
    # It also avoids the costly inspect.stack() for every unrelated call.
    # Separators are normalised so the check also matches Windows paths.
    filename = frame.f_code.co_filename
    if "black/__init__.py" not in filename.replace(os.sep, "/"):
        return tracefunc

    # NOTE(review): 19 appears to discount the frames of the test harness so
    # black's outermost calls print near column 0 — adjust if output drifts.
    stack = len(inspect.stack()) - 19
    lineno = frame.f_lineno
    # `lineno` is 1-based; skip decorator lines to reach the `def` line.
    func_sig_lineno = lineno - 1
    funcname = black_source_lines[func_sig_lineno].strip()
    while funcname.startswith("@"):
        func_sig_lineno += 1
        funcname = black_source_lines[func_sig_lineno].strip()
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc
2096
2097
if __name__ == "__main__":
    # Allow running this file directly; unittest loads the "test_black" module.
    unittest.main(module="test_black")