git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Fix unstable subscript assignment string wrapping (#1678)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import multiprocessing
3 import asyncio
4 import logging
5 from concurrent.futures import ThreadPoolExecutor
6 from contextlib import contextmanager
7 from dataclasses import replace
8 from functools import partial
9 import inspect
10 from io import BytesIO, TextIOWrapper
11 import os
12 from pathlib import Path
13 from platform import system
14 import regex as re
15 import sys
16 from tempfile import TemporaryDirectory
17 import types
18 from typing import (
19     Any,
20     BinaryIO,
21     Callable,
22     Dict,
23     Generator,
24     List,
25     Tuple,
26     Iterator,
27     TypeVar,
28 )
29 import unittest
30 from unittest.mock import patch, MagicMock
31
32 import click
33 from click import unstyle
34 from click.testing import CliRunner
35
36 import black
37 from black import Feature, TargetVersion
38
39 try:
40     import blackd
41     from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop
42     from aiohttp import web
43 except ImportError:
44     has_blackd_deps = False
45 else:
46     has_blackd_deps = True
47
48 from pathspec import PathSpec
49
50 # Import other test classes
51 from .test_primer import PrimerCLITests  # noqa: F401
52
53
# Mode shared by most tests; experimental string processing is switched on.
DEFAULT_MODE = black.FileMode(experimental_string_processing=True)
# Short aliases: "ff" formats a file in place, "fs" formats a string.
ff = partial(black.format_file_in_place, mode=DEFAULT_MODE, fast=True)
fs = partial(black.format_str, mode=DEFAULT_MODE)
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
PROJECT_ROOT = THIS_DIR.parent
DETERMINISTIC_HEADER = "[Deterministic header]"
# Kept as two pieces so this file never spells the marker out verbatim:
# read_data() strips the marker from every line, including this file's own.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
PY36_ARGS = [f"--target-version={ver.name.lower()}" for ver in black.PY36_VERSIONS]
T = TypeVar("T")
R = TypeVar("R")
67
68
def dump_to_stderr(*output: str) -> str:
    """Stand-in for ``black.dump_to_file`` under ``patch``: return the text instead."""
    body = "\n".join(output)
    return f"\n{body}\n"
71
72
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """Load a fixture and split it into ``(input, expected_output)``.

    ``name`` gets a ``.py`` suffix unless it already ends in a known extension. It
    is resolved under THIS_DIR / "data" when ``data`` is true, otherwise under
    PROJECT_ROOT. The EMPTY_LINE placeholder is removed from every line. Lines
    before the first ``# output`` marker form the input; later lines (minus any
    further marker lines) form the output. If the input is non-empty but the
    output is empty, the output is a copy of the input. Both results are
    stripped and end with exactly one newline.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    base_dir = THIS_DIR / "data" if data else PROJECT_ROOT
    with open(base_dir / name, "r", encoding="utf8") as fixture:
        cleaned = [raw.replace(EMPTY_LINE, "") for raw in fixture]
    markers = [i for i, line in enumerate(cleaned) if line.rstrip() == "# output"]
    cut = markers[0] if markers else len(cleaned)
    _input = cleaned[:cut]
    _output = [line for line in cleaned[cut:] if line.rstrip() != "# output"]
    if _input and not _output:
        _output = list(_input)
    return "".join(_input).strip() + "\n", "".join(_output).strip() + "\n"
94
95
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch ``black.CACHE_DIR`` to a temporary directory while the block runs.

    With ``exists=False`` the patched path is a ``new`` child of the temporary
    directory that this helper does not create.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
104
105
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a fresh asyncio event loop as current and close it on exit.

    The closed loop is left registered as the current loop afterwards.
    """
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        fresh_loop.close()
116
117
@contextmanager
def skip_if_exception(e: str) -> Iterator[None]:
    """Skip the running test if the wrapped code raises an exception named ``e``.

    Exceptions whose class name differs from ``e`` propagate unchanged.
    """
    try:
        yield
    except Exception as exc:
        if exc.__class__.__name__ == e:
            # unittest.skip() merely builds a decorator; raising SkipTest is what
            # actually marks the current test as skipped.
            raise unittest.SkipTest(f"Encountered expected exception {exc}, skipping")
        raise
127
128
class FakeContext(click.Context):
    """A fake click Context for when calling functions that need it.

    ``click.Context.__init__`` is not called; only ``default_map`` is set, to an
    empty dict.
    """

    def __init__(self) -> None:
        self.default_map: Dict[str, Any] = dict()
134
135
class FakeParameter(click.Parameter):
    """A fake click Parameter for when calling functions that need it.

    Its ``__init__`` sets no attributes and does not call click's initializer.
    """

    def __init__(self) -> None:
        return None
141
142
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    This is a hack that can be removed once we depend on Click 7.x"""

    def __init__(self) -> None:
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Run click's isolation while routing stderr into ``self.stderrbuf``.

        When the block exits, ``stdout_bytes`` and ``stderr_bytes`` hold the bytes
        written to the isolated stdout and to the private stderr buffer.
        """
        with super().isolation(*args, **kwargs) as output:
            # Save the real stream before swapping it, so ``finally`` always has
            # something valid to restore.
            hold_stderr = sys.stderr
            sys.stderr = TextIOWrapper(self.stderrbuf, encoding=self.charset)
            try:
                yield output
            finally:
                # TextIOWrapper buffers text; flush so writes made without an
                # explicit flush reach the byte buffers before they are read.
                sys.stdout.flush()
                sys.stderr.flush()
                self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                self.stderr_bytes = sys.stderr.buffer.getvalue()  # type: ignore
                sys.stderr = hold_stderr
166
167
168 class BlackTestCase(unittest.TestCase):
169     maxDiff = None
170     _diffThreshold = 2 ** 20
171
172     def assertFormatEqual(self, expected: str, actual: str) -> None:
173         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
174             bdv: black.DebugVisitor[Any]
175             black.out("Expected tree:", fg="green")
176             try:
177                 exp_node = black.lib2to3_parse(expected)
178                 bdv = black.DebugVisitor()
179                 list(bdv.visit(exp_node))
180             except Exception as ve:
181                 black.err(str(ve))
182             black.out("Actual tree:", fg="red")
183             try:
184                 exp_node = black.lib2to3_parse(actual)
185                 bdv = black.DebugVisitor()
186                 list(bdv.visit(exp_node))
187             except Exception as ve:
188                 black.err(str(ve))
189         self.assertMultiLineEqual(expected, actual)
190
191     def invokeBlack(
192         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
193     ) -> None:
194         runner = BlackRunner()
195         if ignore_config:
196             args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
197         result = runner.invoke(black.main, args)
198         self.assertEqual(
199             result.exit_code,
200             exit_code,
201             msg=(
202                 f"Failed with args: {args}\n"
203                 f"stdout: {runner.stdout_bytes.decode()!r}\n"
204                 f"stderr: {runner.stderr_bytes.decode()!r}\n"
205                 f"exception: {result.exception}"
206             ),
207         )
208
209     @patch("black.dump_to_file", dump_to_stderr)
210     def checkSourceFile(self, name: str, mode: black.FileMode = DEFAULT_MODE) -> None:
211         path = THIS_DIR.parent / name
212         source, expected = read_data(str(path), data=False)
213         actual = fs(source, mode=mode)
214         self.assertFormatEqual(expected, actual)
215         black.assert_equivalent(source, actual)
216         black.assert_stable(source, actual, mode)
217         self.assertFalse(ff(path))
218
219     @patch("black.dump_to_file", dump_to_stderr)
220     def test_empty(self) -> None:
221         source = expected = ""
222         actual = fs(source)
223         self.assertFormatEqual(expected, actual)
224         black.assert_equivalent(source, actual)
225         black.assert_stable(source, actual, DEFAULT_MODE)
226
227     def test_empty_ff(self) -> None:
228         expected = ""
229         tmp_file = Path(black.dump_to_file())
230         try:
231             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
232             with open(tmp_file, encoding="utf8") as f:
233                 actual = f.read()
234         finally:
235             os.unlink(tmp_file)
236         self.assertFormatEqual(expected, actual)
237
238     def test_self(self) -> None:
239         self.checkSourceFile("tests/test_black.py")
240
241     def test_black(self) -> None:
242         self.checkSourceFile("src/black/__init__.py")
243
244     def test_pygram(self) -> None:
245         self.checkSourceFile("src/blib2to3/pygram.py")
246
247     def test_pytree(self) -> None:
248         self.checkSourceFile("src/blib2to3/pytree.py")
249
250     def test_conv(self) -> None:
251         self.checkSourceFile("src/blib2to3/pgen2/conv.py")
252
253     def test_driver(self) -> None:
254         self.checkSourceFile("src/blib2to3/pgen2/driver.py")
255
256     def test_grammar(self) -> None:
257         self.checkSourceFile("src/blib2to3/pgen2/grammar.py")
258
259     def test_literals(self) -> None:
260         self.checkSourceFile("src/blib2to3/pgen2/literals.py")
261
262     def test_parse(self) -> None:
263         self.checkSourceFile("src/blib2to3/pgen2/parse.py")
264
265     def test_pgen(self) -> None:
266         self.checkSourceFile("src/blib2to3/pgen2/pgen.py")
267
268     def test_tokenize(self) -> None:
269         self.checkSourceFile("src/blib2to3/pgen2/tokenize.py")
270
271     def test_token(self) -> None:
272         self.checkSourceFile("src/blib2to3/pgen2/token.py")
273
274     def test_setup(self) -> None:
275         self.checkSourceFile("setup.py")
276
277     def test_piping(self) -> None:
278         source, expected = read_data("src/black/__init__", data=False)
279         result = BlackRunner().invoke(
280             black.main,
281             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
282             input=BytesIO(source.encode("utf8")),
283         )
284         self.assertEqual(result.exit_code, 0)
285         self.assertFormatEqual(expected, result.output)
286         black.assert_equivalent(source, result.output)
287         black.assert_stable(source, result.output, DEFAULT_MODE)
288
289     def test_piping_diff(self) -> None:
290         diff_header = re.compile(
291             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
292             r"\+\d\d\d\d"
293         )
294         source, _ = read_data("expression.py")
295         expected, _ = read_data("expression.diff")
296         config = THIS_DIR / "data" / "empty_pyproject.toml"
297         args = [
298             "-",
299             "--fast",
300             f"--line-length={black.DEFAULT_LINE_LENGTH}",
301             "--diff",
302             f"--config={config}",
303         ]
304         result = BlackRunner().invoke(
305             black.main, args, input=BytesIO(source.encode("utf8"))
306         )
307         self.assertEqual(result.exit_code, 0)
308         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
309         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
310         self.assertEqual(expected, actual)
311
312     def test_piping_diff_with_color(self) -> None:
313         source, _ = read_data("expression.py")
314         config = THIS_DIR / "data" / "empty_pyproject.toml"
315         args = [
316             "-",
317             "--fast",
318             f"--line-length={black.DEFAULT_LINE_LENGTH}",
319             "--diff",
320             "--color",
321             f"--config={config}",
322         ]
323         result = BlackRunner().invoke(
324             black.main, args, input=BytesIO(source.encode("utf8"))
325         )
326         actual = result.output
327         # Again, the contents are checked in a different test, so only look for colors.
328         self.assertIn("\033[1;37m", actual)
329         self.assertIn("\033[36m", actual)
330         self.assertIn("\033[32m", actual)
331         self.assertIn("\033[31m", actual)
332         self.assertIn("\033[0m", actual)
333
334     @patch("black.dump_to_file", dump_to_stderr)
335     def test_function(self) -> None:
336         source, expected = read_data("function")
337         actual = fs(source)
338         self.assertFormatEqual(expected, actual)
339         black.assert_equivalent(source, actual)
340         black.assert_stable(source, actual, DEFAULT_MODE)
341
342     @patch("black.dump_to_file", dump_to_stderr)
343     def test_function2(self) -> None:
344         source, expected = read_data("function2")
345         actual = fs(source)
346         self.assertFormatEqual(expected, actual)
347         black.assert_equivalent(source, actual)
348         black.assert_stable(source, actual, DEFAULT_MODE)
349
350     @patch("black.dump_to_file", dump_to_stderr)
351     def _test_wip(self) -> None:
352         source, expected = read_data("wip")
353         sys.settrace(tracefunc)
354         mode = replace(
355             DEFAULT_MODE,
356             experimental_string_processing=False,
357             target_versions={black.TargetVersion.PY38},
358         )
359         actual = fs(source, mode=mode)
360         sys.settrace(None)
361         self.assertFormatEqual(expected, actual)
362         black.assert_equivalent(source, actual)
363         black.assert_stable(source, actual, black.FileMode())
364
365     @patch("black.dump_to_file", dump_to_stderr)
366     def test_function_trailing_comma(self) -> None:
367         source, expected = read_data("function_trailing_comma")
368         actual = fs(source)
369         self.assertFormatEqual(expected, actual)
370         black.assert_equivalent(source, actual)
371         black.assert_stable(source, actual, DEFAULT_MODE)
372
373     @unittest.expectedFailure
374     @patch("black.dump_to_file", dump_to_stderr)
375     def test_trailing_comma_optional_parens_stability1(self) -> None:
376         source, _expected = read_data("trailing_comma_optional_parens1")
377         actual = fs(source)
378         black.assert_stable(source, actual, DEFAULT_MODE)
379
380     @unittest.expectedFailure
381     @patch("black.dump_to_file", dump_to_stderr)
382     def test_trailing_comma_optional_parens_stability2(self) -> None:
383         source, _expected = read_data("trailing_comma_optional_parens2")
384         actual = fs(source)
385         black.assert_stable(source, actual, DEFAULT_MODE)
386
387     @unittest.expectedFailure
388     @patch("black.dump_to_file", dump_to_stderr)
389     def test_trailing_comma_optional_parens_stability3(self) -> None:
390         source, _expected = read_data("trailing_comma_optional_parens3")
391         actual = fs(source)
392         black.assert_stable(source, actual, DEFAULT_MODE)
393
394     @patch("black.dump_to_file", dump_to_stderr)
395     def test_expression(self) -> None:
396         source, expected = read_data("expression")
397         actual = fs(source)
398         self.assertFormatEqual(expected, actual)
399         black.assert_equivalent(source, actual)
400         black.assert_stable(source, actual, DEFAULT_MODE)
401
402     @patch("black.dump_to_file", dump_to_stderr)
403     def test_pep_572(self) -> None:
404         source, expected = read_data("pep_572")
405         actual = fs(source)
406         self.assertFormatEqual(expected, actual)
407         black.assert_stable(source, actual, DEFAULT_MODE)
408         if sys.version_info >= (3, 8):
409             black.assert_equivalent(source, actual)
410
411     def test_pep_572_version_detection(self) -> None:
412         source, _ = read_data("pep_572")
413         root = black.lib2to3_parse(source)
414         features = black.get_features_used(root)
415         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
416         versions = black.detect_target_versions(root)
417         self.assertIn(black.TargetVersion.PY38, versions)
418
419     def test_expression_ff(self) -> None:
420         source, expected = read_data("expression")
421         tmp_file = Path(black.dump_to_file(source))
422         try:
423             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
424             with open(tmp_file, encoding="utf8") as f:
425                 actual = f.read()
426         finally:
427             os.unlink(tmp_file)
428         self.assertFormatEqual(expected, actual)
429         with patch("black.dump_to_file", dump_to_stderr):
430             black.assert_equivalent(source, actual)
431             black.assert_stable(source, actual, DEFAULT_MODE)
432
433     def test_expression_diff(self) -> None:
434         source, _ = read_data("expression.py")
435         expected, _ = read_data("expression.diff")
436         tmp_file = Path(black.dump_to_file(source))
437         diff_header = re.compile(
438             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
439             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
440         )
441         try:
442             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
443             self.assertEqual(result.exit_code, 0)
444         finally:
445             os.unlink(tmp_file)
446         actual = result.output
447         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
448         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
449         if expected != actual:
450             dump = black.dump_to_file(actual)
451             msg = (
452                 "Expected diff isn't equal to the actual. If you made changes to"
453                 " expression.py and this is an anticipated difference, overwrite"
454                 f" tests/data/expression.diff with {dump}"
455             )
456             self.assertEqual(expected, actual, msg)
457
458     def test_expression_diff_with_color(self) -> None:
459         source, _ = read_data("expression.py")
460         expected, _ = read_data("expression.diff")
461         tmp_file = Path(black.dump_to_file(source))
462         try:
463             result = BlackRunner().invoke(
464                 black.main, ["--diff", "--color", str(tmp_file)]
465             )
466         finally:
467             os.unlink(tmp_file)
468         actual = result.output
469         # We check the contents of the diff in `test_expression_diff`. All
470         # we need to check here is that color codes exist in the result.
471         self.assertIn("\033[1;37m", actual)
472         self.assertIn("\033[36m", actual)
473         self.assertIn("\033[32m", actual)
474         self.assertIn("\033[31m", actual)
475         self.assertIn("\033[0m", actual)
476
477     @patch("black.dump_to_file", dump_to_stderr)
478     def test_fstring(self) -> None:
479         source, expected = read_data("fstring")
480         actual = fs(source)
481         self.assertFormatEqual(expected, actual)
482         black.assert_equivalent(source, actual)
483         black.assert_stable(source, actual, DEFAULT_MODE)
484
485     @patch("black.dump_to_file", dump_to_stderr)
486     def test_pep_570(self) -> None:
487         source, expected = read_data("pep_570")
488         actual = fs(source)
489         self.assertFormatEqual(expected, actual)
490         black.assert_stable(source, actual, DEFAULT_MODE)
491         if sys.version_info >= (3, 8):
492             black.assert_equivalent(source, actual)
493
494     def test_detect_pos_only_arguments(self) -> None:
495         source, _ = read_data("pep_570")
496         root = black.lib2to3_parse(source)
497         features = black.get_features_used(root)
498         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
499         versions = black.detect_target_versions(root)
500         self.assertIn(black.TargetVersion.PY38, versions)
501
502     @patch("black.dump_to_file", dump_to_stderr)
503     def test_string_quotes(self) -> None:
504         source, expected = read_data("string_quotes")
505         actual = fs(source)
506         self.assertFormatEqual(expected, actual)
507         black.assert_equivalent(source, actual)
508         black.assert_stable(source, actual, DEFAULT_MODE)
509         mode = replace(DEFAULT_MODE, string_normalization=False)
510         not_normalized = fs(source, mode=mode)
511         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
512         black.assert_equivalent(source, not_normalized)
513         black.assert_stable(source, not_normalized, mode=mode)
514
515     @patch("black.dump_to_file", dump_to_stderr)
516     def test_docstring(self) -> None:
517         source, expected = read_data("docstring")
518         actual = fs(source)
519         self.assertFormatEqual(expected, actual)
520         black.assert_equivalent(source, actual)
521         black.assert_stable(source, actual, DEFAULT_MODE)
522
523     @patch("black.dump_to_file", dump_to_stderr)
524     def test_docstring_no_string_normalization(self) -> None:
525         """Like test_docstring but with string normalization off."""
526         source, expected = read_data("docstring_no_string_normalization")
527         mode = replace(DEFAULT_MODE, string_normalization=False)
528         actual = fs(source, mode=mode)
529         self.assertFormatEqual(expected, actual)
530         black.assert_equivalent(source, actual)
531         black.assert_stable(source, actual, mode)
532
533     def test_long_strings(self) -> None:
534         """Tests for splitting long strings."""
535         source, expected = read_data("long_strings")
536         actual = fs(source)
537         self.assertFormatEqual(expected, actual)
538         black.assert_equivalent(source, actual)
539         black.assert_stable(source, actual, DEFAULT_MODE)
540
541     def test_long_strings_flag_disabled(self) -> None:
542         """Tests for turning off the string processing logic."""
543         source, expected = read_data("long_strings_flag_disabled")
544         mode = replace(DEFAULT_MODE, experimental_string_processing=False)
545         actual = fs(source, mode=mode)
546         self.assertFormatEqual(expected, actual)
547         black.assert_stable(expected, actual, mode)
548
549     @patch("black.dump_to_file", dump_to_stderr)
550     def test_long_strings__edge_case(self) -> None:
551         """Edge-case tests for splitting long strings."""
552         source, expected = read_data("long_strings__edge_case")
553         actual = fs(source)
554         self.assertFormatEqual(expected, actual)
555         black.assert_equivalent(source, actual)
556         black.assert_stable(source, actual, DEFAULT_MODE)
557
558     @patch("black.dump_to_file", dump_to_stderr)
559     def test_long_strings__regression(self) -> None:
560         """Regression tests for splitting long strings."""
561         source, expected = read_data("long_strings__regression")
562         actual = fs(source)
563         self.assertFormatEqual(expected, actual)
564         black.assert_equivalent(source, actual)
565         black.assert_stable(source, actual, DEFAULT_MODE)
566
567     @patch("black.dump_to_file", dump_to_stderr)
568     def test_slices(self) -> None:
569         source, expected = read_data("slices")
570         actual = fs(source)
571         self.assertFormatEqual(expected, actual)
572         black.assert_equivalent(source, actual)
573         black.assert_stable(source, actual, DEFAULT_MODE)
574
575     @patch("black.dump_to_file", dump_to_stderr)
576     def test_percent_precedence(self) -> None:
577         source, expected = read_data("percent_precedence")
578         actual = fs(source)
579         self.assertFormatEqual(expected, actual)
580         black.assert_equivalent(source, actual)
581         black.assert_stable(source, actual, DEFAULT_MODE)
582
583     @patch("black.dump_to_file", dump_to_stderr)
584     def test_comments(self) -> None:
585         source, expected = read_data("comments")
586         actual = fs(source)
587         self.assertFormatEqual(expected, actual)
588         black.assert_equivalent(source, actual)
589         black.assert_stable(source, actual, DEFAULT_MODE)
590
591     @patch("black.dump_to_file", dump_to_stderr)
592     def test_comments2(self) -> None:
593         source, expected = read_data("comments2")
594         actual = fs(source)
595         self.assertFormatEqual(expected, actual)
596         black.assert_equivalent(source, actual)
597         black.assert_stable(source, actual, DEFAULT_MODE)
598
599     @patch("black.dump_to_file", dump_to_stderr)
600     def test_comments3(self) -> None:
601         source, expected = read_data("comments3")
602         actual = fs(source)
603         self.assertFormatEqual(expected, actual)
604         black.assert_equivalent(source, actual)
605         black.assert_stable(source, actual, DEFAULT_MODE)
606
607     @patch("black.dump_to_file", dump_to_stderr)
608     def test_comments4(self) -> None:
609         source, expected = read_data("comments4")
610         actual = fs(source)
611         self.assertFormatEqual(expected, actual)
612         black.assert_equivalent(source, actual)
613         black.assert_stable(source, actual, DEFAULT_MODE)
614
615     @patch("black.dump_to_file", dump_to_stderr)
616     def test_comments5(self) -> None:
617         source, expected = read_data("comments5")
618         actual = fs(source)
619         self.assertFormatEqual(expected, actual)
620         black.assert_equivalent(source, actual)
621         black.assert_stable(source, actual, DEFAULT_MODE)
622
623     @patch("black.dump_to_file", dump_to_stderr)
624     def test_comments6(self) -> None:
625         source, expected = read_data("comments6")
626         actual = fs(source)
627         self.assertFormatEqual(expected, actual)
628         black.assert_equivalent(source, actual)
629         black.assert_stable(source, actual, DEFAULT_MODE)
630
631     @patch("black.dump_to_file", dump_to_stderr)
632     def test_comments7(self) -> None:
633         source, expected = read_data("comments7")
634         mode = replace(DEFAULT_MODE, target_versions={black.TargetVersion.PY38})
635         actual = fs(source, mode=mode)
636         self.assertFormatEqual(expected, actual)
637         black.assert_equivalent(source, actual)
638         black.assert_stable(source, actual, DEFAULT_MODE)
639
640     @patch("black.dump_to_file", dump_to_stderr)
641     def test_comment_after_escaped_newline(self) -> None:
642         source, expected = read_data("comment_after_escaped_newline")
643         actual = fs(source)
644         self.assertFormatEqual(expected, actual)
645         black.assert_equivalent(source, actual)
646         black.assert_stable(source, actual, DEFAULT_MODE)
647
648     @patch("black.dump_to_file", dump_to_stderr)
649     def test_cantfit(self) -> None:
650         source, expected = read_data("cantfit")
651         actual = fs(source)
652         self.assertFormatEqual(expected, actual)
653         black.assert_equivalent(source, actual)
654         black.assert_stable(source, actual, DEFAULT_MODE)
655
656     @patch("black.dump_to_file", dump_to_stderr)
657     def test_import_spacing(self) -> None:
658         source, expected = read_data("import_spacing")
659         actual = fs(source)
660         self.assertFormatEqual(expected, actual)
661         black.assert_equivalent(source, actual)
662         black.assert_stable(source, actual, DEFAULT_MODE)
663
664     @patch("black.dump_to_file", dump_to_stderr)
665     def test_composition(self) -> None:
666         source, expected = read_data("composition")
667         actual = fs(source)
668         self.assertFormatEqual(expected, actual)
669         black.assert_equivalent(source, actual)
670         black.assert_stable(source, actual, DEFAULT_MODE)
671
672     @patch("black.dump_to_file", dump_to_stderr)
673     def test_composition_no_trailing_comma(self) -> None:
674         source, expected = read_data("composition_no_trailing_comma")
675         mode = replace(DEFAULT_MODE, target_versions={black.TargetVersion.PY38})
676         actual = fs(source, mode=mode)
677         self.assertFormatEqual(expected, actual)
678         black.assert_equivalent(source, actual)
679         black.assert_stable(source, actual, DEFAULT_MODE)
680
681     @patch("black.dump_to_file", dump_to_stderr)
682     def test_empty_lines(self) -> None:
683         source, expected = read_data("empty_lines")
684         actual = fs(source)
685         self.assertFormatEqual(expected, actual)
686         black.assert_equivalent(source, actual)
687         black.assert_stable(source, actual, DEFAULT_MODE)
688
689     @patch("black.dump_to_file", dump_to_stderr)
690     def test_remove_parens(self) -> None:
691         source, expected = read_data("remove_parens")
692         actual = fs(source)
693         self.assertFormatEqual(expected, actual)
694         black.assert_equivalent(source, actual)
695         black.assert_stable(source, actual, DEFAULT_MODE)
696
697     @patch("black.dump_to_file", dump_to_stderr)
698     def test_string_prefixes(self) -> None:
699         source, expected = read_data("string_prefixes")
700         actual = fs(source)
701         self.assertFormatEqual(expected, actual)
702         black.assert_equivalent(source, actual)
703         black.assert_stable(source, actual, DEFAULT_MODE)
704
705     @patch("black.dump_to_file", dump_to_stderr)
706     def test_numeric_literals(self) -> None:
707         source, expected = read_data("numeric_literals")
708         mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
709         actual = fs(source, mode=mode)
710         self.assertFormatEqual(expected, actual)
711         black.assert_equivalent(source, actual)
712         black.assert_stable(source, actual, mode)
713
714     @patch("black.dump_to_file", dump_to_stderr)
715     def test_numeric_literals_ignoring_underscores(self) -> None:
716         source, expected = read_data("numeric_literals_skip_underscores")
717         mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
718         actual = fs(source, mode=mode)
719         self.assertFormatEqual(expected, actual)
720         black.assert_equivalent(source, actual)
721         black.assert_stable(source, actual, mode)
722
723     @patch("black.dump_to_file", dump_to_stderr)
724     def test_numeric_literals_py2(self) -> None:
725         source, expected = read_data("numeric_literals_py2")
726         actual = fs(source)
727         self.assertFormatEqual(expected, actual)
728         black.assert_stable(source, actual, DEFAULT_MODE)
729
730     @patch("black.dump_to_file", dump_to_stderr)
731     def test_python2(self) -> None:
732         source, expected = read_data("python2")
733         actual = fs(source)
734         self.assertFormatEqual(expected, actual)
735         black.assert_equivalent(source, actual)
736         black.assert_stable(source, actual, DEFAULT_MODE)
737
738     @patch("black.dump_to_file", dump_to_stderr)
739     def test_python2_print_function(self) -> None:
740         source, expected = read_data("python2_print_function")
741         mode = replace(DEFAULT_MODE, target_versions={TargetVersion.PY27})
742         actual = fs(source, mode=mode)
743         self.assertFormatEqual(expected, actual)
744         black.assert_equivalent(source, actual)
745         black.assert_stable(source, actual, mode)
746
747     @patch("black.dump_to_file", dump_to_stderr)
748     def test_python2_unicode_literals(self) -> None:
749         source, expected = read_data("python2_unicode_literals")
750         actual = fs(source)
751         self.assertFormatEqual(expected, actual)
752         black.assert_equivalent(source, actual)
753         black.assert_stable(source, actual, DEFAULT_MODE)
754
755     @patch("black.dump_to_file", dump_to_stderr)
756     def test_stub(self) -> None:
757         mode = replace(DEFAULT_MODE, is_pyi=True)
758         source, expected = read_data("stub.pyi")
759         actual = fs(source, mode=mode)
760         self.assertFormatEqual(expected, actual)
761         black.assert_stable(source, actual, mode)
762
763     @patch("black.dump_to_file", dump_to_stderr)
764     def test_async_as_identifier(self) -> None:
765         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
766         source, expected = read_data("async_as_identifier")
767         actual = fs(source)
768         self.assertFormatEqual(expected, actual)
769         major, minor = sys.version_info[:2]
770         if major < 3 or (major <= 3 and minor < 7):
771             black.assert_equivalent(source, actual)
772         black.assert_stable(source, actual, DEFAULT_MODE)
773         # ensure black can parse this when the target is 3.6
774         self.invokeBlack([str(source_path), "--target-version", "py36"])
775         # but not on 3.7, because async/await is no longer an identifier
776         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
777
778     @patch("black.dump_to_file", dump_to_stderr)
779     def test_python37(self) -> None:
780         source_path = (THIS_DIR / "data" / "python37.py").resolve()
781         source, expected = read_data("python37")
782         actual = fs(source)
783         self.assertFormatEqual(expected, actual)
784         major, minor = sys.version_info[:2]
785         if major > 3 or (major == 3 and minor >= 7):
786             black.assert_equivalent(source, actual)
787         black.assert_stable(source, actual, DEFAULT_MODE)
788         # ensure black can parse this when the target is 3.7
789         self.invokeBlack([str(source_path), "--target-version", "py37"])
790         # but not on 3.6, because we use async as a reserved keyword
791         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
792
793     @patch("black.dump_to_file", dump_to_stderr)
794     def test_python38(self) -> None:
795         source, expected = read_data("python38")
796         actual = fs(source)
797         self.assertFormatEqual(expected, actual)
798         major, minor = sys.version_info[:2]
799         if major > 3 or (major == 3 and minor >= 8):
800             black.assert_equivalent(source, actual)
801         black.assert_stable(source, actual, DEFAULT_MODE)
802
803     @patch("black.dump_to_file", dump_to_stderr)
804     def test_fmtonoff(self) -> None:
805         source, expected = read_data("fmtonoff")
806         actual = fs(source)
807         self.assertFormatEqual(expected, actual)
808         black.assert_equivalent(source, actual)
809         black.assert_stable(source, actual, DEFAULT_MODE)
810
811     @patch("black.dump_to_file", dump_to_stderr)
812     def test_fmtonoff2(self) -> None:
813         source, expected = read_data("fmtonoff2")
814         actual = fs(source)
815         self.assertFormatEqual(expected, actual)
816         black.assert_equivalent(source, actual)
817         black.assert_stable(source, actual, DEFAULT_MODE)
818
819     @patch("black.dump_to_file", dump_to_stderr)
820     def test_fmtonoff3(self) -> None:
821         source, expected = read_data("fmtonoff3")
822         actual = fs(source)
823         self.assertFormatEqual(expected, actual)
824         black.assert_equivalent(source, actual)
825         black.assert_stable(source, actual, DEFAULT_MODE)
826
827     @patch("black.dump_to_file", dump_to_stderr)
828     def test_fmtonoff4(self) -> None:
829         source, expected = read_data("fmtonoff4")
830         actual = fs(source)
831         self.assertFormatEqual(expected, actual)
832         black.assert_equivalent(source, actual)
833         black.assert_stable(source, actual, DEFAULT_MODE)
834
835     @patch("black.dump_to_file", dump_to_stderr)
836     def test_remove_empty_parentheses_after_class(self) -> None:
837         source, expected = read_data("class_blank_parentheses")
838         actual = fs(source)
839         self.assertFormatEqual(expected, actual)
840         black.assert_equivalent(source, actual)
841         black.assert_stable(source, actual, DEFAULT_MODE)
842
843     @patch("black.dump_to_file", dump_to_stderr)
844     def test_new_line_between_class_and_code(self) -> None:
845         source, expected = read_data("class_methods_new_line")
846         actual = fs(source)
847         self.assertFormatEqual(expected, actual)
848         black.assert_equivalent(source, actual)
849         black.assert_stable(source, actual, DEFAULT_MODE)
850
851     @patch("black.dump_to_file", dump_to_stderr)
852     def test_bracket_match(self) -> None:
853         source, expected = read_data("bracketmatch")
854         actual = fs(source)
855         self.assertFormatEqual(expected, actual)
856         black.assert_equivalent(source, actual)
857         black.assert_stable(source, actual, DEFAULT_MODE)
858
859     @patch("black.dump_to_file", dump_to_stderr)
860     def test_tuple_assign(self) -> None:
861         source, expected = read_data("tupleassign")
862         actual = fs(source)
863         self.assertFormatEqual(expected, actual)
864         black.assert_equivalent(source, actual)
865         black.assert_stable(source, actual, DEFAULT_MODE)
866
867     @patch("black.dump_to_file", dump_to_stderr)
868     def test_beginning_backslash(self) -> None:
869         source, expected = read_data("beginning_backslash")
870         actual = fs(source)
871         self.assertFormatEqual(expected, actual)
872         black.assert_equivalent(source, actual)
873         black.assert_stable(source, actual, DEFAULT_MODE)
874
875     def test_tab_comment_indentation(self) -> None:
876         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
877         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
878         self.assertFormatEqual(contents_spc, fs(contents_spc))
879         self.assertFormatEqual(contents_spc, fs(contents_tab))
880
881         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
882         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
883         self.assertFormatEqual(contents_spc, fs(contents_spc))
884         self.assertFormatEqual(contents_spc, fs(contents_tab))
885
886         # mixed tabs and spaces (valid Python 2 code)
887         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
888         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
889         self.assertFormatEqual(contents_spc, fs(contents_spc))
890         self.assertFormatEqual(contents_spc, fs(contents_tab))
891
892         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
893         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
894         self.assertFormatEqual(contents_spc, fs(contents_spc))
895         self.assertFormatEqual(contents_spc, fs(contents_tab))
896
897     def test_report_verbose(self) -> None:
898         report = black.Report(verbose=True)
899         out_lines = []
900         err_lines = []
901
902         def out(msg: str, **kwargs: Any) -> None:
903             out_lines.append(msg)
904
905         def err(msg: str, **kwargs: Any) -> None:
906             err_lines.append(msg)
907
908         with patch("black.out", out), patch("black.err", err):
909             report.done(Path("f1"), black.Changed.NO)
910             self.assertEqual(len(out_lines), 1)
911             self.assertEqual(len(err_lines), 0)
912             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
913             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
914             self.assertEqual(report.return_code, 0)
915             report.done(Path("f2"), black.Changed.YES)
916             self.assertEqual(len(out_lines), 2)
917             self.assertEqual(len(err_lines), 0)
918             self.assertEqual(out_lines[-1], "reformatted f2")
919             self.assertEqual(
920                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
921             )
922             report.done(Path("f3"), black.Changed.CACHED)
923             self.assertEqual(len(out_lines), 3)
924             self.assertEqual(len(err_lines), 0)
925             self.assertEqual(
926                 out_lines[-1], "f3 wasn't modified on disk since last run."
927             )
928             self.assertEqual(
929                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
930             )
931             self.assertEqual(report.return_code, 0)
932             report.check = True
933             self.assertEqual(report.return_code, 1)
934             report.check = False
935             report.failed(Path("e1"), "boom")
936             self.assertEqual(len(out_lines), 3)
937             self.assertEqual(len(err_lines), 1)
938             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
939             self.assertEqual(
940                 unstyle(str(report)),
941                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
942                 " reformat.",
943             )
944             self.assertEqual(report.return_code, 123)
945             report.done(Path("f3"), black.Changed.YES)
946             self.assertEqual(len(out_lines), 4)
947             self.assertEqual(len(err_lines), 1)
948             self.assertEqual(out_lines[-1], "reformatted f3")
949             self.assertEqual(
950                 unstyle(str(report)),
951                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
952                 " reformat.",
953             )
954             self.assertEqual(report.return_code, 123)
955             report.failed(Path("e2"), "boom")
956             self.assertEqual(len(out_lines), 4)
957             self.assertEqual(len(err_lines), 2)
958             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
959             self.assertEqual(
960                 unstyle(str(report)),
961                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
962                 " reformat.",
963             )
964             self.assertEqual(report.return_code, 123)
965             report.path_ignored(Path("wat"), "no match")
966             self.assertEqual(len(out_lines), 5)
967             self.assertEqual(len(err_lines), 2)
968             self.assertEqual(out_lines[-1], "wat ignored: no match")
969             self.assertEqual(
970                 unstyle(str(report)),
971                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
972                 " reformat.",
973             )
974             self.assertEqual(report.return_code, 123)
975             report.done(Path("f4"), black.Changed.NO)
976             self.assertEqual(len(out_lines), 6)
977             self.assertEqual(len(err_lines), 2)
978             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
979             self.assertEqual(
980                 unstyle(str(report)),
981                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
982                 " reformat.",
983             )
984             self.assertEqual(report.return_code, 123)
985             report.check = True
986             self.assertEqual(
987                 unstyle(str(report)),
988                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
989                 " would fail to reformat.",
990             )
991             report.check = False
992             report.diff = True
993             self.assertEqual(
994                 unstyle(str(report)),
995                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
996                 " would fail to reformat.",
997             )
998
999     def test_report_quiet(self) -> None:
1000         report = black.Report(quiet=True)
1001         out_lines = []
1002         err_lines = []
1003
1004         def out(msg: str, **kwargs: Any) -> None:
1005             out_lines.append(msg)
1006
1007         def err(msg: str, **kwargs: Any) -> None:
1008             err_lines.append(msg)
1009
1010         with patch("black.out", out), patch("black.err", err):
1011             report.done(Path("f1"), black.Changed.NO)
1012             self.assertEqual(len(out_lines), 0)
1013             self.assertEqual(len(err_lines), 0)
1014             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
1015             self.assertEqual(report.return_code, 0)
1016             report.done(Path("f2"), black.Changed.YES)
1017             self.assertEqual(len(out_lines), 0)
1018             self.assertEqual(len(err_lines), 0)
1019             self.assertEqual(
1020                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
1021             )
1022             report.done(Path("f3"), black.Changed.CACHED)
1023             self.assertEqual(len(out_lines), 0)
1024             self.assertEqual(len(err_lines), 0)
1025             self.assertEqual(
1026                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
1027             )
1028             self.assertEqual(report.return_code, 0)
1029             report.check = True
1030             self.assertEqual(report.return_code, 1)
1031             report.check = False
1032             report.failed(Path("e1"), "boom")
1033             self.assertEqual(len(out_lines), 0)
1034             self.assertEqual(len(err_lines), 1)
1035             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
1036             self.assertEqual(
1037                 unstyle(str(report)),
1038                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
1039                 " reformat.",
1040             )
1041             self.assertEqual(report.return_code, 123)
1042             report.done(Path("f3"), black.Changed.YES)
1043             self.assertEqual(len(out_lines), 0)
1044             self.assertEqual(len(err_lines), 1)
1045             self.assertEqual(
1046                 unstyle(str(report)),
1047                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
1048                 " reformat.",
1049             )
1050             self.assertEqual(report.return_code, 123)
1051             report.failed(Path("e2"), "boom")
1052             self.assertEqual(len(out_lines), 0)
1053             self.assertEqual(len(err_lines), 2)
1054             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
1055             self.assertEqual(
1056                 unstyle(str(report)),
1057                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
1058                 " reformat.",
1059             )
1060             self.assertEqual(report.return_code, 123)
1061             report.path_ignored(Path("wat"), "no match")
1062             self.assertEqual(len(out_lines), 0)
1063             self.assertEqual(len(err_lines), 2)
1064             self.assertEqual(
1065                 unstyle(str(report)),
1066                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
1067                 " reformat.",
1068             )
1069             self.assertEqual(report.return_code, 123)
1070             report.done(Path("f4"), black.Changed.NO)
1071             self.assertEqual(len(out_lines), 0)
1072             self.assertEqual(len(err_lines), 2)
1073             self.assertEqual(
1074                 unstyle(str(report)),
1075                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
1076                 " reformat.",
1077             )
1078             self.assertEqual(report.return_code, 123)
1079             report.check = True
1080             self.assertEqual(
1081                 unstyle(str(report)),
1082                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
1083                 " would fail to reformat.",
1084             )
1085             report.check = False
1086             report.diff = True
1087             self.assertEqual(
1088                 unstyle(str(report)),
1089                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
1090                 " would fail to reformat.",
1091             )
1092
1093     def test_report_normal(self) -> None:
1094         report = black.Report()
1095         out_lines = []
1096         err_lines = []
1097
1098         def out(msg: str, **kwargs: Any) -> None:
1099             out_lines.append(msg)
1100
1101         def err(msg: str, **kwargs: Any) -> None:
1102             err_lines.append(msg)
1103
1104         with patch("black.out", out), patch("black.err", err):
1105             report.done(Path("f1"), black.Changed.NO)
1106             self.assertEqual(len(out_lines), 0)
1107             self.assertEqual(len(err_lines), 0)
1108             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
1109             self.assertEqual(report.return_code, 0)
1110             report.done(Path("f2"), black.Changed.YES)
1111             self.assertEqual(len(out_lines), 1)
1112             self.assertEqual(len(err_lines), 0)
1113             self.assertEqual(out_lines[-1], "reformatted f2")
1114             self.assertEqual(
1115                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
1116             )
1117             report.done(Path("f3"), black.Changed.CACHED)
1118             self.assertEqual(len(out_lines), 1)
1119             self.assertEqual(len(err_lines), 0)
1120             self.assertEqual(out_lines[-1], "reformatted f2")
1121             self.assertEqual(
1122                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
1123             )
1124             self.assertEqual(report.return_code, 0)
1125             report.check = True
1126             self.assertEqual(report.return_code, 1)
1127             report.check = False
1128             report.failed(Path("e1"), "boom")
1129             self.assertEqual(len(out_lines), 1)
1130             self.assertEqual(len(err_lines), 1)
1131             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
1132             self.assertEqual(
1133                 unstyle(str(report)),
1134                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
1135                 " reformat.",
1136             )
1137             self.assertEqual(report.return_code, 123)
1138             report.done(Path("f3"), black.Changed.YES)
1139             self.assertEqual(len(out_lines), 2)
1140             self.assertEqual(len(err_lines), 1)
1141             self.assertEqual(out_lines[-1], "reformatted f3")
1142             self.assertEqual(
1143                 unstyle(str(report)),
1144                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
1145                 " reformat.",
1146             )
1147             self.assertEqual(report.return_code, 123)
1148             report.failed(Path("e2"), "boom")
1149             self.assertEqual(len(out_lines), 2)
1150             self.assertEqual(len(err_lines), 2)
1151             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
1152             self.assertEqual(
1153                 unstyle(str(report)),
1154                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
1155                 " reformat.",
1156             )
1157             self.assertEqual(report.return_code, 123)
1158             report.path_ignored(Path("wat"), "no match")
1159             self.assertEqual(len(out_lines), 2)
1160             self.assertEqual(len(err_lines), 2)
1161             self.assertEqual(
1162                 unstyle(str(report)),
1163                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
1164                 " reformat.",
1165             )
1166             self.assertEqual(report.return_code, 123)
1167             report.done(Path("f4"), black.Changed.NO)
1168             self.assertEqual(len(out_lines), 2)
1169             self.assertEqual(len(err_lines), 2)
1170             self.assertEqual(
1171                 unstyle(str(report)),
1172                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
1173                 " reformat.",
1174             )
1175             self.assertEqual(report.return_code, 123)
1176             report.check = True
1177             self.assertEqual(
1178                 unstyle(str(report)),
1179                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
1180                 " would fail to reformat.",
1181             )
1182             report.check = False
1183             report.diff = True
1184             self.assertEqual(
1185                 unstyle(str(report)),
1186                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
1187                 " would fail to reformat.",
1188             )
1189
1190     def test_lib2to3_parse(self) -> None:
1191         with self.assertRaises(black.InvalidInput):
1192             black.lib2to3_parse("invalid syntax")
1193
1194         straddling = "x + y"
1195         black.lib2to3_parse(straddling)
1196         black.lib2to3_parse(straddling, {TargetVersion.PY27})
1197         black.lib2to3_parse(straddling, {TargetVersion.PY36})
1198         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
1199
1200         py2_only = "print x"
1201         black.lib2to3_parse(py2_only)
1202         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
1203         with self.assertRaises(black.InvalidInput):
1204             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
1205         with self.assertRaises(black.InvalidInput):
1206             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
1207
1208         py3_only = "exec(x, end=y)"
1209         black.lib2to3_parse(py3_only)
1210         with self.assertRaises(black.InvalidInput):
1211             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
1212         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
1213         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
1214
1215     def test_get_features_used(self) -> None:
1216         node = black.lib2to3_parse("def f(*, arg): ...\n")
1217         self.assertEqual(black.get_features_used(node), set())
1218         node = black.lib2to3_parse("def f(*, arg,): ...\n")
1219         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
1220         node = black.lib2to3_parse("f(*arg,)\n")
1221         self.assertEqual(
1222             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
1223         )
1224         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
1225         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
1226         node = black.lib2to3_parse("123_456\n")
1227         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
1228         node = black.lib2to3_parse("123456\n")
1229         self.assertEqual(black.get_features_used(node), set())
1230         source, expected = read_data("function")
1231         node = black.lib2to3_parse(source)
1232         expected_features = {
1233             Feature.TRAILING_COMMA_IN_CALL,
1234             Feature.TRAILING_COMMA_IN_DEF,
1235             Feature.F_STRINGS,
1236         }
1237         self.assertEqual(black.get_features_used(node), expected_features)
1238         node = black.lib2to3_parse(expected)
1239         self.assertEqual(black.get_features_used(node), expected_features)
1240         source, expected = read_data("expression")
1241         node = black.lib2to3_parse(source)
1242         self.assertEqual(black.get_features_used(node), set())
1243         node = black.lib2to3_parse(expected)
1244         self.assertEqual(black.get_features_used(node), set())
1245
1246     def test_get_future_imports(self) -> None:
1247         node = black.lib2to3_parse("\n")
1248         self.assertEqual(set(), black.get_future_imports(node))
1249         node = black.lib2to3_parse("from __future__ import black\n")
1250         self.assertEqual({"black"}, black.get_future_imports(node))
1251         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
1252         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
1253         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
1254         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
1255         node = black.lib2to3_parse(
1256             "from __future__ import multiple\nfrom __future__ import imports\n"
1257         )
1258         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
1259         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
1260         self.assertEqual({"black"}, black.get_future_imports(node))
1261         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
1262         self.assertEqual({"black"}, black.get_future_imports(node))
1263         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
1264         self.assertEqual(set(), black.get_future_imports(node))
1265         node = black.lib2to3_parse("from some.module import black\n")
1266         self.assertEqual(set(), black.get_future_imports(node))
1267         node = black.lib2to3_parse(
1268             "from __future__ import unicode_literals as _unicode_literals"
1269         )
1270         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
1271         node = black.lib2to3_parse(
1272             "from __future__ import unicode_literals as _lol, print"
1273         )
1274         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
1275
1276     def test_debug_visitor(self) -> None:
1277         source, _ = read_data("debug_visitor.py")
1278         expected, _ = read_data("debug_visitor.out")
1279         out_lines = []
1280         err_lines = []
1281
1282         def out(msg: str, **kwargs: Any) -> None:
1283             out_lines.append(msg)
1284
1285         def err(msg: str, **kwargs: Any) -> None:
1286             err_lines.append(msg)
1287
1288         with patch("black.out", out), patch("black.err", err):
1289             black.DebugVisitor.show(source)
1290         actual = "\n".join(out_lines) + "\n"
1291         log_name = ""
1292         if expected != actual:
1293             log_name = black.dump_to_file(*out_lines)
1294         self.assertEqual(
1295             expected,
1296             actual,
1297             f"AST print out is different. Actual version dumped to {log_name}",
1298         )
1299
1300     def test_format_file_contents(self) -> None:
1301         empty = ""
1302         mode = DEFAULT_MODE
1303         with self.assertRaises(black.NothingChanged):
1304             black.format_file_contents(empty, mode=mode, fast=False)
1305         just_nl = "\n"
1306         with self.assertRaises(black.NothingChanged):
1307             black.format_file_contents(just_nl, mode=mode, fast=False)
1308         same = "j = [1, 2, 3]\n"
1309         with self.assertRaises(black.NothingChanged):
1310             black.format_file_contents(same, mode=mode, fast=False)
1311         different = "j = [1,2,3]"
1312         expected = same
1313         actual = black.format_file_contents(different, mode=mode, fast=False)
1314         self.assertEqual(expected, actual)
1315         invalid = "return if you can"
1316         with self.assertRaises(black.InvalidInput) as e:
1317             black.format_file_contents(invalid, mode=mode, fast=False)
1318         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1319
1320     def test_endmarker(self) -> None:
1321         n = black.lib2to3_parse("\n")
1322         self.assertEqual(n.type, black.syms.file_input)
1323         self.assertEqual(len(n.children), 1)
1324         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1325
1326     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1327     def test_assertFormatEqual(self) -> None:
1328         out_lines = []
1329         err_lines = []
1330
1331         def out(msg: str, **kwargs: Any) -> None:
1332             out_lines.append(msg)
1333
1334         def err(msg: str, **kwargs: Any) -> None:
1335             err_lines.append(msg)
1336
1337         with patch("black.out", out), patch("black.err", err):
1338             with self.assertRaises(AssertionError):
1339                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1340
1341         out_str = "".join(out_lines)
1342         self.assertTrue("Expected tree:" in out_str)
1343         self.assertTrue("Actual tree:" in out_str)
1344         self.assertEqual("".join(err_lines), "")
1345
1346     def test_cache_broken_file(self) -> None:
1347         mode = DEFAULT_MODE
1348         with cache_dir() as workspace:
1349             cache_file = black.get_cache_file(mode)
1350             with cache_file.open("w") as fobj:
1351                 fobj.write("this is not a pickle")
1352             self.assertEqual(black.read_cache(mode), {})
1353             src = (workspace / "test.py").resolve()
1354             with src.open("w") as fobj:
1355                 fobj.write("print('hello')")
1356             self.invokeBlack([str(src)])
1357             cache = black.read_cache(mode)
1358             self.assertIn(src, cache)
1359
1360     def test_cache_single_file_already_cached(self) -> None:
1361         mode = DEFAULT_MODE
1362         with cache_dir() as workspace:
1363             src = (workspace / "test.py").resolve()
1364             with src.open("w") as fobj:
1365                 fobj.write("print('hello')")
1366             black.write_cache({}, [src], mode)
1367             self.invokeBlack([str(src)])
1368             with src.open("r") as fobj:
1369                 self.assertEqual(fobj.read(), "print('hello')")
1370
1371     @event_loop()
1372     def test_cache_multiple_files(self) -> None:
1373         mode = DEFAULT_MODE
1374         with cache_dir() as workspace, patch(
1375             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1376         ):
1377             one = (workspace / "one.py").resolve()
1378             with one.open("w") as fobj:
1379                 fobj.write("print('hello')")
1380             two = (workspace / "two.py").resolve()
1381             with two.open("w") as fobj:
1382                 fobj.write("print('hello')")
1383             black.write_cache({}, [one], mode)
1384             self.invokeBlack([str(workspace)])
1385             with one.open("r") as fobj:
1386                 self.assertEqual(fobj.read(), "print('hello')")
1387             with two.open("r") as fobj:
1388                 self.assertEqual(fobj.read(), 'print("hello")\n')
1389             cache = black.read_cache(mode)
1390             self.assertIn(one, cache)
1391             self.assertIn(two, cache)
1392
1393     def test_no_cache_when_writeback_diff(self) -> None:
1394         mode = DEFAULT_MODE
1395         with cache_dir() as workspace:
1396             src = (workspace / "test.py").resolve()
1397             with src.open("w") as fobj:
1398                 fobj.write("print('hello')")
1399             with patch("black.read_cache") as read_cache, patch(
1400                 "black.write_cache"
1401             ) as write_cache:
1402                 self.invokeBlack([str(src), "--diff"])
1403                 cache_file = black.get_cache_file(mode)
1404                 self.assertFalse(cache_file.exists())
1405                 write_cache.assert_not_called()
1406                 read_cache.assert_not_called()
1407
1408     def test_no_cache_when_writeback_color_diff(self) -> None:
1409         mode = DEFAULT_MODE
1410         with cache_dir() as workspace:
1411             src = (workspace / "test.py").resolve()
1412             with src.open("w") as fobj:
1413                 fobj.write("print('hello')")
1414             with patch("black.read_cache") as read_cache, patch(
1415                 "black.write_cache"
1416             ) as write_cache:
1417                 self.invokeBlack([str(src), "--diff", "--color"])
1418                 cache_file = black.get_cache_file(mode)
1419                 self.assertFalse(cache_file.exists())
1420                 write_cache.assert_not_called()
1421                 read_cache.assert_not_called()
1422
1423     @event_loop()
1424     def test_output_locking_when_writeback_diff(self) -> None:
1425         with cache_dir() as workspace:
1426             for tag in range(0, 4):
1427                 src = (workspace / f"test{tag}.py").resolve()
1428                 with src.open("w") as fobj:
1429                     fobj.write("print('hello')")
1430             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1431                 self.invokeBlack(["--diff", str(workspace)], exit_code=0)
1432                 # this isn't quite doing what we want, but if it _isn't_
1433                 # called then we cannot be using the lock it provides
1434                 mgr.assert_called()
1435
1436     @event_loop()
1437     def test_output_locking_when_writeback_color_diff(self) -> None:
1438         with cache_dir() as workspace:
1439             for tag in range(0, 4):
1440                 src = (workspace / f"test{tag}.py").resolve()
1441                 with src.open("w") as fobj:
1442                     fobj.write("print('hello')")
1443             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1444                 self.invokeBlack(["--diff", "--color", str(workspace)], exit_code=0)
1445                 # this isn't quite doing what we want, but if it _isn't_
1446                 # called then we cannot be using the lock it provides
1447                 mgr.assert_called()
1448
1449     def test_no_cache_when_stdin(self) -> None:
1450         mode = DEFAULT_MODE
1451         with cache_dir():
1452             result = CliRunner().invoke(
1453                 black.main, ["-"], input=BytesIO(b"print('hello')")
1454             )
1455             self.assertEqual(result.exit_code, 0)
1456             cache_file = black.get_cache_file(mode)
1457             self.assertFalse(cache_file.exists())
1458
1459     def test_read_cache_no_cachefile(self) -> None:
1460         mode = DEFAULT_MODE
1461         with cache_dir():
1462             self.assertEqual(black.read_cache(mode), {})
1463
1464     def test_write_cache_read_cache(self) -> None:
1465         mode = DEFAULT_MODE
1466         with cache_dir() as workspace:
1467             src = (workspace / "test.py").resolve()
1468             src.touch()
1469             black.write_cache({}, [src], mode)
1470             cache = black.read_cache(mode)
1471             self.assertIn(src, cache)
1472             self.assertEqual(cache[src], black.get_cache_info(src))
1473
1474     def test_filter_cached(self) -> None:
1475         with TemporaryDirectory() as workspace:
1476             path = Path(workspace)
1477             uncached = (path / "uncached").resolve()
1478             cached = (path / "cached").resolve()
1479             cached_but_changed = (path / "changed").resolve()
1480             uncached.touch()
1481             cached.touch()
1482             cached_but_changed.touch()
1483             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
1484             todo, done = black.filter_cached(
1485                 cache, {uncached, cached, cached_but_changed}
1486             )
1487             self.assertEqual(todo, {uncached, cached_but_changed})
1488             self.assertEqual(done, {cached})
1489
1490     def test_write_cache_creates_directory_if_needed(self) -> None:
1491         mode = DEFAULT_MODE
1492         with cache_dir(exists=False) as workspace:
1493             self.assertFalse(workspace.exists())
1494             black.write_cache({}, [], mode)
1495             self.assertTrue(workspace.exists())
1496
1497     @event_loop()
1498     def test_failed_formatting_does_not_get_cached(self) -> None:
1499         mode = DEFAULT_MODE
1500         with cache_dir() as workspace, patch(
1501             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1502         ):
1503             failing = (workspace / "failing.py").resolve()
1504             with failing.open("w") as fobj:
1505                 fobj.write("not actually python")
1506             clean = (workspace / "clean.py").resolve()
1507             with clean.open("w") as fobj:
1508                 fobj.write('print("hello")\n')
1509             self.invokeBlack([str(workspace)], exit_code=123)
1510             cache = black.read_cache(mode)
1511             self.assertNotIn(failing, cache)
1512             self.assertIn(clean, cache)
1513
1514     def test_write_cache_write_fail(self) -> None:
1515         mode = DEFAULT_MODE
1516         with cache_dir(), patch.object(Path, "open") as mock:
1517             mock.side_effect = OSError
1518             black.write_cache({}, [], mode)
1519
1520     @event_loop()
1521     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
1522     def test_works_in_mono_process_only_environment(self) -> None:
1523         with cache_dir() as workspace:
1524             for f in [
1525                 (workspace / "one.py").resolve(),
1526                 (workspace / "two.py").resolve(),
1527             ]:
1528                 f.write_text('print("hello")\n')
1529             self.invokeBlack([str(workspace)])
1530
1531     @event_loop()
1532     def test_check_diff_use_together(self) -> None:
1533         with cache_dir():
1534             # Files which will be reformatted.
1535             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1536             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1537             # Files which will not be reformatted.
1538             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1539             self.invokeBlack([str(src2), "--diff", "--check"])
1540             # Multi file command.
1541             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1542
1543     def test_no_files(self) -> None:
1544         with cache_dir():
1545             # Without an argument, black exits with error code 0.
1546             self.invokeBlack([])
1547
1548     def test_broken_symlink(self) -> None:
1549         with cache_dir() as workspace:
1550             symlink = workspace / "broken_link.py"
1551             try:
1552                 symlink.symlink_to("nonexistent.py")
1553             except OSError as e:
1554                 self.skipTest(f"Can't create symlinks: {e}")
1555             self.invokeBlack([str(workspace.resolve())])
1556
1557     def test_read_cache_line_lengths(self) -> None:
1558         mode = DEFAULT_MODE
1559         short_mode = replace(DEFAULT_MODE, line_length=1)
1560         with cache_dir() as workspace:
1561             path = (workspace / "file.py").resolve()
1562             path.touch()
1563             black.write_cache({}, [path], mode)
1564             one = black.read_cache(mode)
1565             self.assertIn(path, one)
1566             two = black.read_cache(short_mode)
1567             self.assertNotIn(path, two)
1568
1569     def test_tricky_unicode_symbols(self) -> None:
1570         source, expected = read_data("tricky_unicode_symbols")
1571         actual = fs(source)
1572         self.assertFormatEqual(expected, actual)
1573         black.assert_equivalent(source, actual)
1574         black.assert_stable(source, actual, DEFAULT_MODE)
1575
1576     def test_single_file_force_pyi(self) -> None:
1577         reg_mode = DEFAULT_MODE
1578         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1579         contents, expected = read_data("force_pyi")
1580         with cache_dir() as workspace:
1581             path = (workspace / "file.py").resolve()
1582             with open(path, "w") as fh:
1583                 fh.write(contents)
1584             self.invokeBlack([str(path), "--pyi"])
1585             with open(path, "r") as fh:
1586                 actual = fh.read()
1587             # verify cache with --pyi is separate
1588             pyi_cache = black.read_cache(pyi_mode)
1589             self.assertIn(path, pyi_cache)
1590             normal_cache = black.read_cache(reg_mode)
1591             self.assertNotIn(path, normal_cache)
1592         self.assertEqual(actual, expected)
1593
1594     @event_loop()
1595     def test_multi_file_force_pyi(self) -> None:
1596         reg_mode = DEFAULT_MODE
1597         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1598         contents, expected = read_data("force_pyi")
1599         with cache_dir() as workspace:
1600             paths = [
1601                 (workspace / "file1.py").resolve(),
1602                 (workspace / "file2.py").resolve(),
1603             ]
1604             for path in paths:
1605                 with open(path, "w") as fh:
1606                     fh.write(contents)
1607             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1608             for path in paths:
1609                 with open(path, "r") as fh:
1610                     actual = fh.read()
1611                 self.assertEqual(actual, expected)
1612             # verify cache with --pyi is separate
1613             pyi_cache = black.read_cache(pyi_mode)
1614             normal_cache = black.read_cache(reg_mode)
1615             for path in paths:
1616                 self.assertIn(path, pyi_cache)
1617                 self.assertNotIn(path, normal_cache)
1618
1619     def test_pipe_force_pyi(self) -> None:
1620         source, expected = read_data("force_pyi")
1621         result = CliRunner().invoke(
1622             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1623         )
1624         self.assertEqual(result.exit_code, 0)
1625         actual = result.output
1626         self.assertFormatEqual(actual, expected)
1627
1628     def test_single_file_force_py36(self) -> None:
1629         reg_mode = DEFAULT_MODE
1630         py36_mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
1631         source, expected = read_data("force_py36")
1632         with cache_dir() as workspace:
1633             path = (workspace / "file.py").resolve()
1634             with open(path, "w") as fh:
1635                 fh.write(source)
1636             self.invokeBlack([str(path), *PY36_ARGS])
1637             with open(path, "r") as fh:
1638                 actual = fh.read()
1639             # verify cache with --target-version is separate
1640             py36_cache = black.read_cache(py36_mode)
1641             self.assertIn(path, py36_cache)
1642             normal_cache = black.read_cache(reg_mode)
1643             self.assertNotIn(path, normal_cache)
1644         self.assertEqual(actual, expected)
1645
1646     @event_loop()
1647     def test_multi_file_force_py36(self) -> None:
1648         reg_mode = DEFAULT_MODE
1649         py36_mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
1650         source, expected = read_data("force_py36")
1651         with cache_dir() as workspace:
1652             paths = [
1653                 (workspace / "file1.py").resolve(),
1654                 (workspace / "file2.py").resolve(),
1655             ]
1656             for path in paths:
1657                 with open(path, "w") as fh:
1658                     fh.write(source)
1659             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1660             for path in paths:
1661                 with open(path, "r") as fh:
1662                     actual = fh.read()
1663                 self.assertEqual(actual, expected)
1664             # verify cache with --target-version is separate
1665             pyi_cache = black.read_cache(py36_mode)
1666             normal_cache = black.read_cache(reg_mode)
1667             for path in paths:
1668                 self.assertIn(path, pyi_cache)
1669                 self.assertNotIn(path, normal_cache)
1670
1671     def test_collections(self) -> None:
1672         source, expected = read_data("collections")
1673         actual = fs(source)
1674         self.assertFormatEqual(expected, actual)
1675         black.assert_equivalent(source, actual)
1676         black.assert_stable(source, actual, DEFAULT_MODE)
1677
1678     def test_pipe_force_py36(self) -> None:
1679         source, expected = read_data("force_py36")
1680         result = CliRunner().invoke(
1681             black.main,
1682             ["-", "-q", "--target-version=py36"],
1683             input=BytesIO(source.encode("utf8")),
1684         )
1685         self.assertEqual(result.exit_code, 0)
1686         actual = result.output
1687         self.assertFormatEqual(actual, expected)
1688
1689     def test_include_exclude(self) -> None:
1690         path = THIS_DIR / "data" / "include_exclude_tests"
1691         include = re.compile(r"\.pyi?$")
1692         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1693         report = black.Report()
1694         gitignore = PathSpec.from_lines("gitwildmatch", [])
1695         sources: List[Path] = []
1696         expected = [
1697             Path(path / "b/dont_exclude/a.py"),
1698             Path(path / "b/dont_exclude/a.pyi"),
1699         ]
1700         this_abs = THIS_DIR.resolve()
1701         sources.extend(
1702             black.gen_python_files(
1703                 path.iterdir(), this_abs, include, exclude, None, report, gitignore
1704             )
1705         )
1706         self.assertEqual(sorted(expected), sorted(sources))
1707
1708     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1709     def test_exclude_for_issue_1572(self) -> None:
1710         # Exclude shouldn't touch files that were explicitly given to Black through the
1711         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
1712         # https://github.com/psf/black/issues/1572
1713         path = THIS_DIR / "data" / "include_exclude_tests"
1714         include = ""
1715         exclude = r"/exclude/|a\.py"
1716         src = str(path / "b/exclude/a.py")
1717         report = black.Report()
1718         expected = [Path(path / "b/exclude/a.py")]
1719         sources = list(
1720             black.get_sources(
1721                 ctx=FakeContext(),
1722                 src=(src,),
1723                 quiet=True,
1724                 verbose=False,
1725                 include=include,
1726                 exclude=exclude,
1727                 force_exclude=None,
1728                 report=report,
1729             )
1730         )
1731         self.assertEqual(sorted(expected), sorted(sources))
1732
1733     def test_gitignore_exclude(self) -> None:
1734         path = THIS_DIR / "data" / "include_exclude_tests"
1735         include = re.compile(r"\.pyi?$")
1736         exclude = re.compile(r"")
1737         report = black.Report()
1738         gitignore = PathSpec.from_lines(
1739             "gitwildmatch", ["exclude/", ".definitely_exclude"]
1740         )
1741         sources: List[Path] = []
1742         expected = [
1743             Path(path / "b/dont_exclude/a.py"),
1744             Path(path / "b/dont_exclude/a.pyi"),
1745         ]
1746         this_abs = THIS_DIR.resolve()
1747         sources.extend(
1748             black.gen_python_files(
1749                 path.iterdir(), this_abs, include, exclude, None, report, gitignore
1750             )
1751         )
1752         self.assertEqual(sorted(expected), sorted(sources))
1753
1754     def test_empty_include(self) -> None:
1755         path = THIS_DIR / "data" / "include_exclude_tests"
1756         report = black.Report()
1757         gitignore = PathSpec.from_lines("gitwildmatch", [])
1758         empty = re.compile(r"")
1759         sources: List[Path] = []
1760         expected = [
1761             Path(path / "b/exclude/a.pie"),
1762             Path(path / "b/exclude/a.py"),
1763             Path(path / "b/exclude/a.pyi"),
1764             Path(path / "b/dont_exclude/a.pie"),
1765             Path(path / "b/dont_exclude/a.py"),
1766             Path(path / "b/dont_exclude/a.pyi"),
1767             Path(path / "b/.definitely_exclude/a.pie"),
1768             Path(path / "b/.definitely_exclude/a.py"),
1769             Path(path / "b/.definitely_exclude/a.pyi"),
1770         ]
1771         this_abs = THIS_DIR.resolve()
1772         sources.extend(
1773             black.gen_python_files(
1774                 path.iterdir(),
1775                 this_abs,
1776                 empty,
1777                 re.compile(black.DEFAULT_EXCLUDES),
1778                 None,
1779                 report,
1780                 gitignore,
1781             )
1782         )
1783         self.assertEqual(sorted(expected), sorted(sources))
1784
1785     def test_empty_exclude(self) -> None:
1786         path = THIS_DIR / "data" / "include_exclude_tests"
1787         report = black.Report()
1788         gitignore = PathSpec.from_lines("gitwildmatch", [])
1789         empty = re.compile(r"")
1790         sources: List[Path] = []
1791         expected = [
1792             Path(path / "b/dont_exclude/a.py"),
1793             Path(path / "b/dont_exclude/a.pyi"),
1794             Path(path / "b/exclude/a.py"),
1795             Path(path / "b/exclude/a.pyi"),
1796             Path(path / "b/.definitely_exclude/a.py"),
1797             Path(path / "b/.definitely_exclude/a.pyi"),
1798         ]
1799         this_abs = THIS_DIR.resolve()
1800         sources.extend(
1801             black.gen_python_files(
1802                 path.iterdir(),
1803                 this_abs,
1804                 re.compile(black.DEFAULT_INCLUDES),
1805                 empty,
1806                 None,
1807                 report,
1808                 gitignore,
1809             )
1810         )
1811         self.assertEqual(sorted(expected), sorted(sources))
1812
1813     def test_invalid_include_exclude(self) -> None:
1814         for option in ["--include", "--exclude"]:
1815             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1816
1817     def test_preserves_line_endings(self) -> None:
1818         with TemporaryDirectory() as workspace:
1819             test_file = Path(workspace) / "test.py"
1820             for nl in ["\n", "\r\n"]:
1821                 contents = nl.join(["def f(  ):", "    pass"])
1822                 test_file.write_bytes(contents.encode())
1823                 ff(test_file, write_back=black.WriteBack.YES)
1824                 updated_contents: bytes = test_file.read_bytes()
1825                 self.assertIn(nl.encode(), updated_contents)
1826                 if nl == "\n":
1827                     self.assertNotIn(b"\r\n", updated_contents)
1828
1829     def test_preserves_line_endings_via_stdin(self) -> None:
1830         for nl in ["\n", "\r\n"]:
1831             contents = nl.join(["def f(  ):", "    pass"])
1832             runner = BlackRunner()
1833             result = runner.invoke(
1834                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1835             )
1836             self.assertEqual(result.exit_code, 0)
1837             output = runner.stdout_bytes
1838             self.assertIn(nl.encode("utf8"), output)
1839             if nl == "\n":
1840                 self.assertNotIn(b"\r\n", output)
1841
1842     def test_assert_equivalent_different_asts(self) -> None:
1843         with self.assertRaises(AssertionError):
1844             black.assert_equivalent("{}", "None")
1845
1846     def test_symlink_out_of_root_directory(self) -> None:
1847         path = MagicMock()
1848         root = THIS_DIR.resolve()
1849         child = MagicMock()
1850         include = re.compile(black.DEFAULT_INCLUDES)
1851         exclude = re.compile(black.DEFAULT_EXCLUDES)
1852         report = black.Report()
1853         gitignore = PathSpec.from_lines("gitwildmatch", [])
1854         # `child` should behave like a symlink which resolved path is clearly
1855         # outside of the `root` directory.
1856         path.iterdir.return_value = [child]
1857         child.resolve.return_value = Path("/a/b/c")
1858         child.as_posix.return_value = "/a/b/c"
1859         child.is_symlink.return_value = True
1860         try:
1861             list(
1862                 black.gen_python_files(
1863                     path.iterdir(), root, include, exclude, None, report, gitignore
1864                 )
1865             )
1866         except ValueError as ve:
1867             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1868         path.iterdir.assert_called_once()
1869         child.resolve.assert_called_once()
1870         child.is_symlink.assert_called_once()
1871         # `child` should behave like a strange file which resolved path is clearly
1872         # outside of the `root` directory.
1873         child.is_symlink.return_value = False
1874         with self.assertRaises(ValueError):
1875             list(
1876                 black.gen_python_files(
1877                     path.iterdir(), root, include, exclude, None, report, gitignore
1878                 )
1879             )
1880         path.iterdir.assert_called()
1881         self.assertEqual(path.iterdir.call_count, 2)
1882         child.resolve.assert_called()
1883         self.assertEqual(child.resolve.call_count, 2)
1884         child.is_symlink.assert_called()
1885         self.assertEqual(child.is_symlink.call_count, 2)
1886
1887     def test_shhh_click(self) -> None:
1888         try:
1889             from click import _unicodefun  # type: ignore
1890         except ModuleNotFoundError:
1891             self.skipTest("Incompatible Click version")
1892         if not hasattr(_unicodefun, "_verify_python3_env"):
1893             self.skipTest("Incompatible Click version")
1894         # First, let's see if Click is crashing with a preferred ASCII charset.
1895         with patch("locale.getpreferredencoding") as gpe:
1896             gpe.return_value = "ASCII"
1897             with self.assertRaises(RuntimeError):
1898                 _unicodefun._verify_python3_env()
1899         # Now, let's silence Click...
1900         black.patch_click()
1901         # ...and confirm it's silent.
1902         with patch("locale.getpreferredencoding") as gpe:
1903             gpe.return_value = "ASCII"
1904             try:
1905                 _unicodefun._verify_python3_env()
1906             except RuntimeError as re:
1907                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1908
1909     def test_root_logger_not_used_directly(self) -> None:
1910         def fail(*args: Any, **kwargs: Any) -> None:
1911             self.fail("Record created with root logger")
1912
1913         with patch.multiple(
1914             logging.root,
1915             debug=fail,
1916             info=fail,
1917             warning=fail,
1918             error=fail,
1919             critical=fail,
1920             log=fail,
1921         ):
1922             ff(THIS_FILE)
1923
1924     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1925     def test_blackd_main(self) -> None:
1926         with patch("blackd.web.run_app"):
1927             result = CliRunner().invoke(blackd.main, [])
1928             if result.exception is not None:
1929                 raise result.exception
1930             self.assertEqual(result.exit_code, 0)
1931
1932     def test_invalid_config_return_code(self) -> None:
1933         tmp_file = Path(black.dump_to_file())
1934         try:
1935             tmp_config = Path(black.dump_to_file())
1936             tmp_config.unlink()
1937             args = ["--config", str(tmp_config), str(tmp_file)]
1938             self.invokeBlack(args, exit_code=2, ignore_config=False)
1939         finally:
1940             tmp_file.unlink()
1941
1942     def test_parse_pyproject_toml(self) -> None:
1943         test_toml_file = THIS_DIR / "test.toml"
1944         config = black.parse_pyproject_toml(str(test_toml_file))
1945         self.assertEqual(config["verbose"], 1)
1946         self.assertEqual(config["check"], "no")
1947         self.assertEqual(config["diff"], "y")
1948         self.assertEqual(config["color"], True)
1949         self.assertEqual(config["line_length"], 79)
1950         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1951         self.assertEqual(config["exclude"], r"\.pyi?$")
1952         self.assertEqual(config["include"], r"\.py?$")
1953
1954     def test_read_pyproject_toml(self) -> None:
1955         test_toml_file = THIS_DIR / "test.toml"
1956         fake_ctx = FakeContext()
1957         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1958         config = fake_ctx.default_map
1959         self.assertEqual(config["verbose"], "1")
1960         self.assertEqual(config["check"], "no")
1961         self.assertEqual(config["diff"], "y")
1962         self.assertEqual(config["color"], "True")
1963         self.assertEqual(config["line_length"], "79")
1964         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1965         self.assertEqual(config["exclude"], r"\.pyi?$")
1966         self.assertEqual(config["include"], r"\.py?$")
1967
1968     def test_find_project_root(self) -> None:
1969         with TemporaryDirectory() as workspace:
1970             root = Path(workspace)
1971             test_dir = root / "test"
1972             test_dir.mkdir()
1973
1974             src_dir = root / "src"
1975             src_dir.mkdir()
1976
1977             root_pyproject = root / "pyproject.toml"
1978             root_pyproject.touch()
1979             src_pyproject = src_dir / "pyproject.toml"
1980             src_pyproject.touch()
1981             src_python = src_dir / "foo.py"
1982             src_python.touch()
1983
1984             self.assertEqual(
1985                 black.find_project_root((src_dir, test_dir)), root.resolve()
1986             )
1987             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
1988             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
1989
1990     def test_bpo_33660_workaround(self) -> None:
1991         if system() == "Windows":
1992             return
1993
1994         # https://bugs.python.org/issue33660
1995
1996         old_cwd = Path.cwd()
1997         try:
1998             root = Path("/")
1999             os.chdir(str(root))
2000             path = Path("workspace") / "project"
2001             report = black.Report(verbose=True)
2002             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
2003             self.assertEqual(normalized_path, "workspace/project")
2004         finally:
2005             os.chdir(str(old_cwd))
2006
2007
2008 class BlackDTestCase(AioHTTPTestCase):
2009     async def get_application(self) -> web.Application:
2010         return blackd.make_app()
2011
2012     # TODO: remove these decorators once the below is released
2013     # https://github.com/aio-libs/aiohttp/pull/3727
2014     @skip_if_exception("ClientOSError")
2015     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
2016     @unittest_run_loop
2017     async def test_blackd_request_needs_formatting(self) -> None:
2018         response = await self.client.post("/", data=b"print('hello world')")
2019         self.assertEqual(response.status, 200)
2020         self.assertEqual(response.charset, "utf8")
2021         self.assertEqual(await response.read(), b'print("hello world")\n')
2022
2023     @skip_if_exception("ClientOSError")
2024     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
2025     @unittest_run_loop
2026     async def test_blackd_request_no_change(self) -> None:
2027         response = await self.client.post("/", data=b'print("hello world")\n')
2028         self.assertEqual(response.status, 204)
2029         self.assertEqual(await response.read(), b"")
2030
2031     @skip_if_exception("ClientOSError")
2032     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
2033     @unittest_run_loop
2034     async def test_blackd_request_syntax_error(self) -> None:
2035         response = await self.client.post("/", data=b"what even ( is")
2036         self.assertEqual(response.status, 400)
2037         content = await response.text()
2038         self.assertTrue(
2039             content.startswith("Cannot parse"),
2040             msg=f"Expected error to start with 'Cannot parse', got {repr(content)}",
2041         )
2042
2043     @skip_if_exception("ClientOSError")
2044     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
2045     @unittest_run_loop
2046     async def test_blackd_unsupported_version(self) -> None:
2047         response = await self.client.post(
2048             "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "2"}
2049         )
2050         self.assertEqual(response.status, 501)
2051
2052     @skip_if_exception("ClientOSError")
2053     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
2054     @unittest_run_loop
2055     async def test_blackd_supported_version(self) -> None:
2056         response = await self.client.post(
2057             "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "1"}
2058         )
2059         self.assertEqual(response.status, 200)
2060
2061     @skip_if_exception("ClientOSError")
2062     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
2063     @unittest_run_loop
2064     async def test_blackd_invalid_python_variant(self) -> None:
2065         async def check(header_value: str, expected_status: int = 400) -> None:
2066             response = await self.client.post(
2067                 "/", data=b"what", headers={blackd.PYTHON_VARIANT_HEADER: header_value}
2068             )
2069             self.assertEqual(response.status, expected_status)
2070
2071         await check("lol")
2072         await check("ruby3.5")
2073         await check("pyi3.6")
2074         await check("py1.5")
2075         await check("2.8")
2076         await check("py2.8")
2077         await check("3.0")
2078         await check("pypy3.0")
2079         await check("jython3.4")
2080
2081     @skip_if_exception("ClientOSError")
2082     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
2083     @unittest_run_loop
2084     async def test_blackd_pyi(self) -> None:
2085         source, expected = read_data("stub.pyi")
2086         response = await self.client.post(
2087             "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
2088         )
2089         self.assertEqual(response.status, 200)
2090         self.assertEqual(await response.text(), expected)
2091
2092     @skip_if_exception("ClientOSError")
2093     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
2094     @unittest_run_loop
2095     async def test_blackd_diff(self) -> None:
2096         diff_header = re.compile(
2097             r"(In|Out)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
2098         )
2099
2100         source, _ = read_data("blackd_diff.py")
2101         expected, _ = read_data("blackd_diff.diff")
2102
2103         response = await self.client.post(
2104             "/", data=source, headers={blackd.DIFF_HEADER: "true"}
2105         )
2106         self.assertEqual(response.status, 200)
2107
2108         actual = await response.text()
2109         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
2110         self.assertEqual(actual, expected)
2111
2112     @skip_if_exception("ClientOSError")
2113     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
2114     @unittest_run_loop
2115     async def test_blackd_python_variant(self) -> None:
2116         code = (
2117             "def f(\n"
2118             "    and_has_a_bunch_of,\n"
2119             "    very_long_arguments_too,\n"
2120             "    and_lots_of_them_as_well_lol,\n"
2121             "    **and_very_long_keyword_arguments\n"
2122             "):\n"
2123             "    pass\n"
2124         )
2125
2126         async def check(header_value: str, expected_status: int) -> None:
2127             response = await self.client.post(
2128                 "/", data=code, headers={blackd.PYTHON_VARIANT_HEADER: header_value}
2129             )
2130             self.assertEqual(
2131                 response.status, expected_status, msg=await response.text()
2132             )
2133
2134         await check("3.6", 200)
2135         await check("py3.6", 200)
2136         await check("3.6,3.7", 200)
2137         await check("3.6,py3.7", 200)
2138         await check("py36,py37", 200)
2139         await check("36", 200)
2140         await check("3.6.4", 200)
2141
2142         await check("2", 204)
2143         await check("2.7", 204)
2144         await check("py2.7", 204)
2145         await check("3.4", 204)
2146         await check("py3.4", 204)
2147         await check("py34,py36", 204)
2148         await check("34", 204)
2149
2150     @skip_if_exception("ClientOSError")
2151     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
2152     @unittest_run_loop
2153     async def test_blackd_line_length(self) -> None:
2154         response = await self.client.post(
2155             "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "7"}
2156         )
2157         self.assertEqual(response.status, 200)
2158
2159     @skip_if_exception("ClientOSError")
2160     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
2161     @unittest_run_loop
2162     async def test_blackd_invalid_line_length(self) -> None:
2163         response = await self.client.post(
2164             "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "NaN"}
2165         )
2166         self.assertEqual(response.status, 400)
2167
2168     @skip_if_exception("ClientOSError")
2169     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
2170     @unittest_run_loop
2171     async def test_blackd_response_black_version_header(self) -> None:
2172         response = await self.client.post("/")
2173         self.assertIsNotNone(response.headers.get(blackd.BLACK_VERSION_HEADER))
2174
2175
# Lines of black's main module (black.__file__), read once at import time so
# tracefunc can look up the source line a traced call starts on.
with open(black.__file__, encoding="utf-8") as _bf:
    black_source_lines = list(_bf)
2178
2179
def tracefunc(frame: types.FrameType, event: str, arg: Any) -> Callable:
    """Show function calls from `black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    For every "call" event whose code object's file path contains
    `black/__init__.py`, print the call's line number and the stripped source
    line of the function signature (decorator lines are skipped). Each line is
    indented in proportion to the current stack depth. All other events and
    files are ignored. The function always returns itself so it stays
    installed as the local trace function.

    Args:
        frame: The frame being entered.
        event: The trace event name; only "call" is acted on.
        arg: Event-specific argument (unused).

    Returns:
        This function, so tracing stays active.
    """
    if event != "call":
        return tracefunc

    # Normalise Windows path separators so the filter matches on every platform.
    filename = frame.f_code.co_filename.replace("\\", "/")
    if "black/__init__.py" not in filename:
        # Filter before touching black_source_lines. Line numbers from other
        # files may lie past its end. An IndexError raised inside a trace
        # function makes Python uninstall it, as if sys.settrace(None) were
        # called, so tracing would silently stop.
        return tracefunc

    # context=0 yields the same frame count without reading source context for
    # every frame on each call. The -19 offset presumably discounts the
    # test-runner frames above the traced code; TODO confirm for your runner.
    stack = (len(inspect.stack(0)) - 19) * 2
    lineno = frame.f_lineno
    # f_lineno may point at a decorator line, so walk forward to the first
    # non-decorator line. Stay within the bounds of the loaded source.
    func_sig_lineno = lineno - 1
    funcname = ""
    while 0 <= func_sig_lineno < len(black_source_lines):
        funcname = black_source_lines[func_sig_lineno].strip()
        if not funcname.startswith("@"):
            break
        func_sig_lineno += 1
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc
2200
2201
if __name__ == "__main__":
    # Run the test suite when this file is executed directly as a script.
    unittest.main(module="test_black")