git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Refactor logic for stub empty lines (#2796)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import inspect
5 import io
6 import logging
7 import multiprocessing
8 import os
9 import sys
10 import types
11 import unittest
12 from concurrent.futures import ThreadPoolExecutor
13 from contextlib import contextmanager
14 from dataclasses import replace
15 from io import BytesIO
16 from pathlib import Path
17 from platform import system
18 from tempfile import TemporaryDirectory
19 from typing import (
20     Any,
21     Callable,
22     Dict,
23     Iterator,
24     List,
25     Optional,
26     Sequence,
27     TypeVar,
28     Union,
29 )
30 from unittest.mock import MagicMock, patch
31
32 import click
33 import pytest
34 import re
35 from click import unstyle
36 from click.testing import CliRunner
37 from pathspec import PathSpec
38
39 import black
40 import black.files
41 from black import Feature, TargetVersion
42 from black import re_compile_maybe_verbose as compile_pattern
43 from black.cache import get_cache_dir, get_cache_file
44 from black.debug import DebugVisitor
45 from black.output import color_diff, diff
46 from black.report import Report
47
48 # Import other test classes
49 from tests.util import (
50     DATA_DIR,
51     DEFAULT_MODE,
52     DETERMINISTIC_HEADER,
53     PROJECT_ROOT,
54     PY36_VERSIONS,
55     THIS_DIR,
56     BlackBaseTestCase,
57     assert_format,
58     change_directory,
59     dump_to_stderr,
60     ff,
61     fs,
62     read_data,
63 )
64
THIS_FILE = Path(__file__)
# One "--target-version=<name>" CLI argument per version in PY36_VERSIONS.
PY36_ARGS = ["--target-version=" + v.name.lower() for v in PY36_VERSIONS]
# Black's default include/exclude regexes, compiled once for the whole module.
DEFAULT_EXCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_EXCLUDES)
DEFAULT_INCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_INCLUDES)
# Generic type variables.
T = TypeVar("T")
R = TypeVar("R")

# Match the time output in a diff, but nothing else
DIFF_TIME = re.compile(r"\t[\d\-:+\. ]+")
74
75
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point ``black.cache.CACHE_DIR`` at a throwaway temporary directory.

    With ``exists=False`` the yielded path is a not-yet-created child of the
    fresh temporary directory, so tests can exercise cache-dir creation.
    The temporary directory is removed when the context exits.
    """
    with TemporaryDirectory() as tmp:
        root = Path(tmp)
        target = root if exists else root / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
84
85
@contextmanager
def event_loop() -> Iterator[None]:
    """Make a brand-new asyncio event loop current for the duration of the block.

    The loop comes from the active event-loop policy and is closed on exit,
    even if the body raises. It is not unset afterwards, so the closed loop
    remains the current one.
    """
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        fresh_loop.close()
96
97
class FakeContext(click.Context):
    """Stand-in click Context for calling helpers that require one.

    ``click.Context.__init__`` is deliberately not called; only the two
    attributes assigned below are provided.
    """

    def __init__(self) -> None:
        # Placeholder root: most tests never look at it.
        self.obj: Dict[str, Any] = {"root": PROJECT_ROOT}
        # No defaults coming from a configuration file.
        self.default_map: Dict[str, Any] = {}
105
106
class FakeParameter(click.Parameter):
    """Stand-in click Parameter for calling helpers that require one."""

    def __init__(self) -> None:
        """Intentionally skip ``click.Parameter.__init__``; no state is needed."""
112
113
class BlackRunner(CliRunner):
    """CliRunner that captures Black's STDOUT and STDERR as separate streams."""

    def __init__(self) -> None:
        # mix_stderr=False keeps the two streams apart, so callers can inspect
        # result.stdout_bytes and result.stderr_bytes independently.
        super().__init__(mix_stderr=False)
119
120
def invokeBlack(
    args: List[str], exit_code: int = 0, ignore_config: bool = True
) -> None:
    """Run Black's CLI in-process and assert it exits with ``exit_code``.

    When ``ignore_config`` is true, ``--verbose`` and
    ``--config THIS_DIR/empty.toml`` are prepended to ``args``. Exceptions
    raised by Black are not caught (``catch_exceptions=False``). On a
    mismatched exit code, the assertion message includes the args, the
    captured stdout/stderr and any recorded exception.
    """
    if ignore_config:
        empty_config = str(THIS_DIR / "empty.toml")
        args = ["--verbose", "--config", empty_config, *args]
    outcome = BlackRunner().invoke(black.main, args, catch_exceptions=False)
    out_bytes, err_bytes = outcome.stdout_bytes, outcome.stderr_bytes
    assert out_bytes is not None
    assert err_bytes is not None
    failure_message = "\n".join(
        [
            f"Failed with args: {args}",
            f"stdout: {out_bytes.decode()!r}",
            f"stderr: {err_bytes.decode()!r}",
            f"exception: {outcome.exception}",
        ]
    )
    assert outcome.exit_code == exit_code, failure_message
137
138
139 class BlackTestCase(BlackBaseTestCase):
140     invokeBlack = staticmethod(invokeBlack)
141
142     def test_empty_ff(self) -> None:
143         expected = ""
144         tmp_file = Path(black.dump_to_file())
145         try:
146             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
147             with open(tmp_file, encoding="utf8") as f:
148                 actual = f.read()
149         finally:
150             os.unlink(tmp_file)
151         self.assertFormatEqual(expected, actual)
152
153     def test_experimental_string_processing_warns(self) -> None:
154         self.assertWarns(
155             black.mode.Deprecated, black.Mode, experimental_string_processing=True
156         )
157
158     def test_piping(self) -> None:
159         source, expected = read_data("src/black/__init__", data=False)
160         result = BlackRunner().invoke(
161             black.main,
162             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
163             input=BytesIO(source.encode("utf8")),
164         )
165         self.assertEqual(result.exit_code, 0)
166         self.assertFormatEqual(expected, result.output)
167         if source != result.output:
168             black.assert_equivalent(source, result.output)
169             black.assert_stable(source, result.output, DEFAULT_MODE)
170
171     def test_piping_diff(self) -> None:
172         diff_header = re.compile(
173             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
174             r"\+\d\d\d\d"
175         )
176         source, _ = read_data("expression.py")
177         expected, _ = read_data("expression.diff")
178         config = THIS_DIR / "data" / "empty_pyproject.toml"
179         args = [
180             "-",
181             "--fast",
182             f"--line-length={black.DEFAULT_LINE_LENGTH}",
183             "--diff",
184             f"--config={config}",
185         ]
186         result = BlackRunner().invoke(
187             black.main, args, input=BytesIO(source.encode("utf8"))
188         )
189         self.assertEqual(result.exit_code, 0)
190         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
191         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
192         self.assertEqual(expected, actual)
193
194     def test_piping_diff_with_color(self) -> None:
195         source, _ = read_data("expression.py")
196         config = THIS_DIR / "data" / "empty_pyproject.toml"
197         args = [
198             "-",
199             "--fast",
200             f"--line-length={black.DEFAULT_LINE_LENGTH}",
201             "--diff",
202             "--color",
203             f"--config={config}",
204         ]
205         result = BlackRunner().invoke(
206             black.main, args, input=BytesIO(source.encode("utf8"))
207         )
208         actual = result.output
209         # Again, the contents are checked in a different test, so only look for colors.
210         self.assertIn("\033[1m", actual)
211         self.assertIn("\033[36m", actual)
212         self.assertIn("\033[32m", actual)
213         self.assertIn("\033[31m", actual)
214         self.assertIn("\033[0m", actual)
215
216     @patch("black.dump_to_file", dump_to_stderr)
217     def _test_wip(self) -> None:
218         source, expected = read_data("wip")
219         sys.settrace(tracefunc)
220         mode = replace(
221             DEFAULT_MODE,
222             experimental_string_processing=False,
223             target_versions={black.TargetVersion.PY38},
224         )
225         actual = fs(source, mode=mode)
226         sys.settrace(None)
227         self.assertFormatEqual(expected, actual)
228         black.assert_equivalent(source, actual)
229         black.assert_stable(source, actual, black.FileMode())
230
231     @unittest.expectedFailure
232     @patch("black.dump_to_file", dump_to_stderr)
233     def test_trailing_comma_optional_parens_stability1(self) -> None:
234         source, _expected = read_data("trailing_comma_optional_parens1")
235         actual = fs(source)
236         black.assert_stable(source, actual, DEFAULT_MODE)
237
238     @unittest.expectedFailure
239     @patch("black.dump_to_file", dump_to_stderr)
240     def test_trailing_comma_optional_parens_stability2(self) -> None:
241         source, _expected = read_data("trailing_comma_optional_parens2")
242         actual = fs(source)
243         black.assert_stable(source, actual, DEFAULT_MODE)
244
245     @unittest.expectedFailure
246     @patch("black.dump_to_file", dump_to_stderr)
247     def test_trailing_comma_optional_parens_stability3(self) -> None:
248         source, _expected = read_data("trailing_comma_optional_parens3")
249         actual = fs(source)
250         black.assert_stable(source, actual, DEFAULT_MODE)
251
252     @patch("black.dump_to_file", dump_to_stderr)
253     def test_trailing_comma_optional_parens_stability1_pass2(self) -> None:
254         source, _expected = read_data("trailing_comma_optional_parens1")
255         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
256         black.assert_stable(source, actual, DEFAULT_MODE)
257
258     @patch("black.dump_to_file", dump_to_stderr)
259     def test_trailing_comma_optional_parens_stability2_pass2(self) -> None:
260         source, _expected = read_data("trailing_comma_optional_parens2")
261         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
262         black.assert_stable(source, actual, DEFAULT_MODE)
263
264     @patch("black.dump_to_file", dump_to_stderr)
265     def test_trailing_comma_optional_parens_stability3_pass2(self) -> None:
266         source, _expected = read_data("trailing_comma_optional_parens3")
267         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
268         black.assert_stable(source, actual, DEFAULT_MODE)
269
270     def test_pep_572_version_detection(self) -> None:
271         source, _ = read_data("pep_572")
272         root = black.lib2to3_parse(source)
273         features = black.get_features_used(root)
274         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
275         versions = black.detect_target_versions(root)
276         self.assertIn(black.TargetVersion.PY38, versions)
277
278     def test_expression_ff(self) -> None:
279         source, expected = read_data("expression")
280         tmp_file = Path(black.dump_to_file(source))
281         try:
282             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
283             with open(tmp_file, encoding="utf8") as f:
284                 actual = f.read()
285         finally:
286             os.unlink(tmp_file)
287         self.assertFormatEqual(expected, actual)
288         with patch("black.dump_to_file", dump_to_stderr):
289             black.assert_equivalent(source, actual)
290             black.assert_stable(source, actual, DEFAULT_MODE)
291
292     def test_expression_diff(self) -> None:
293         source, _ = read_data("expression.py")
294         config = THIS_DIR / "data" / "empty_pyproject.toml"
295         expected, _ = read_data("expression.diff")
296         tmp_file = Path(black.dump_to_file(source))
297         diff_header = re.compile(
298             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
299             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
300         )
301         try:
302             result = BlackRunner().invoke(
303                 black.main, ["--diff", str(tmp_file), f"--config={config}"]
304             )
305             self.assertEqual(result.exit_code, 0)
306         finally:
307             os.unlink(tmp_file)
308         actual = result.output
309         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
310         if expected != actual:
311             dump = black.dump_to_file(actual)
312             msg = (
313                 "Expected diff isn't equal to the actual. If you made changes to"
314                 " expression.py and this is an anticipated difference, overwrite"
315                 f" tests/data/expression.diff with {dump}"
316             )
317             self.assertEqual(expected, actual, msg)
318
319     def test_expression_diff_with_color(self) -> None:
320         source, _ = read_data("expression.py")
321         config = THIS_DIR / "data" / "empty_pyproject.toml"
322         expected, _ = read_data("expression.diff")
323         tmp_file = Path(black.dump_to_file(source))
324         try:
325             result = BlackRunner().invoke(
326                 black.main, ["--diff", "--color", str(tmp_file), f"--config={config}"]
327             )
328         finally:
329             os.unlink(tmp_file)
330         actual = result.output
331         # We check the contents of the diff in `test_expression_diff`. All
332         # we need to check here is that color codes exist in the result.
333         self.assertIn("\033[1m", actual)
334         self.assertIn("\033[36m", actual)
335         self.assertIn("\033[32m", actual)
336         self.assertIn("\033[31m", actual)
337         self.assertIn("\033[0m", actual)
338
339     def test_detect_pos_only_arguments(self) -> None:
340         source, _ = read_data("pep_570")
341         root = black.lib2to3_parse(source)
342         features = black.get_features_used(root)
343         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
344         versions = black.detect_target_versions(root)
345         self.assertIn(black.TargetVersion.PY38, versions)
346
347     @patch("black.dump_to_file", dump_to_stderr)
348     def test_string_quotes(self) -> None:
349         source, expected = read_data("string_quotes")
350         mode = black.Mode(preview=True)
351         assert_format(source, expected, mode)
352         mode = replace(mode, string_normalization=False)
353         not_normalized = fs(source, mode=mode)
354         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
355         black.assert_equivalent(source, not_normalized)
356         black.assert_stable(source, not_normalized, mode=mode)
357
358     def test_skip_magic_trailing_comma(self) -> None:
359         source, _ = read_data("expression.py")
360         expected, _ = read_data("expression_skip_magic_trailing_comma.diff")
361         tmp_file = Path(black.dump_to_file(source))
362         diff_header = re.compile(
363             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
364             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
365         )
366         try:
367             result = BlackRunner().invoke(black.main, ["-C", "--diff", str(tmp_file)])
368             self.assertEqual(result.exit_code, 0)
369         finally:
370             os.unlink(tmp_file)
371         actual = result.output
372         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
373         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
374         if expected != actual:
375             dump = black.dump_to_file(actual)
376             msg = (
377                 "Expected diff isn't equal to the actual. If you made changes to"
378                 " expression.py and this is an anticipated difference, overwrite"
379                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
380             )
381             self.assertEqual(expected, actual, msg)
382
383     @patch("black.dump_to_file", dump_to_stderr)
384     def test_async_as_identifier(self) -> None:
385         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
386         source, expected = read_data("async_as_identifier")
387         actual = fs(source)
388         self.assertFormatEqual(expected, actual)
389         major, minor = sys.version_info[:2]
390         if major < 3 or (major <= 3 and minor < 7):
391             black.assert_equivalent(source, actual)
392         black.assert_stable(source, actual, DEFAULT_MODE)
393         # ensure black can parse this when the target is 3.6
394         self.invokeBlack([str(source_path), "--target-version", "py36"])
395         # but not on 3.7, because async/await is no longer an identifier
396         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
397
398     @patch("black.dump_to_file", dump_to_stderr)
399     def test_python37(self) -> None:
400         source_path = (THIS_DIR / "data" / "python37.py").resolve()
401         source, expected = read_data("python37")
402         actual = fs(source)
403         self.assertFormatEqual(expected, actual)
404         major, minor = sys.version_info[:2]
405         if major > 3 or (major == 3 and minor >= 7):
406             black.assert_equivalent(source, actual)
407         black.assert_stable(source, actual, DEFAULT_MODE)
408         # ensure black can parse this when the target is 3.7
409         self.invokeBlack([str(source_path), "--target-version", "py37"])
410         # but not on 3.6, because we use async as a reserved keyword
411         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
412
413     def test_tab_comment_indentation(self) -> None:
414         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
415         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
416         self.assertFormatEqual(contents_spc, fs(contents_spc))
417         self.assertFormatEqual(contents_spc, fs(contents_tab))
418
419         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
420         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
421         self.assertFormatEqual(contents_spc, fs(contents_spc))
422         self.assertFormatEqual(contents_spc, fs(contents_tab))
423
424         # mixed tabs and spaces (valid Python 2 code)
425         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
426         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
427         self.assertFormatEqual(contents_spc, fs(contents_spc))
428         self.assertFormatEqual(contents_spc, fs(contents_tab))
429
430         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
431         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
432         self.assertFormatEqual(contents_spc, fs(contents_spc))
433         self.assertFormatEqual(contents_spc, fs(contents_tab))
434
435     def test_report_verbose(self) -> None:
436         report = Report(verbose=True)
437         out_lines = []
438         err_lines = []
439
440         def out(msg: str, **kwargs: Any) -> None:
441             out_lines.append(msg)
442
443         def err(msg: str, **kwargs: Any) -> None:
444             err_lines.append(msg)
445
446         with patch("black.output._out", out), patch("black.output._err", err):
447             report.done(Path("f1"), black.Changed.NO)
448             self.assertEqual(len(out_lines), 1)
449             self.assertEqual(len(err_lines), 0)
450             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
451             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
452             self.assertEqual(report.return_code, 0)
453             report.done(Path("f2"), black.Changed.YES)
454             self.assertEqual(len(out_lines), 2)
455             self.assertEqual(len(err_lines), 0)
456             self.assertEqual(out_lines[-1], "reformatted f2")
457             self.assertEqual(
458                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
459             )
460             report.done(Path("f3"), black.Changed.CACHED)
461             self.assertEqual(len(out_lines), 3)
462             self.assertEqual(len(err_lines), 0)
463             self.assertEqual(
464                 out_lines[-1], "f3 wasn't modified on disk since last run."
465             )
466             self.assertEqual(
467                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
468             )
469             self.assertEqual(report.return_code, 0)
470             report.check = True
471             self.assertEqual(report.return_code, 1)
472             report.check = False
473             report.failed(Path("e1"), "boom")
474             self.assertEqual(len(out_lines), 3)
475             self.assertEqual(len(err_lines), 1)
476             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
477             self.assertEqual(
478                 unstyle(str(report)),
479                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
480                 " reformat.",
481             )
482             self.assertEqual(report.return_code, 123)
483             report.done(Path("f3"), black.Changed.YES)
484             self.assertEqual(len(out_lines), 4)
485             self.assertEqual(len(err_lines), 1)
486             self.assertEqual(out_lines[-1], "reformatted f3")
487             self.assertEqual(
488                 unstyle(str(report)),
489                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
490                 " reformat.",
491             )
492             self.assertEqual(report.return_code, 123)
493             report.failed(Path("e2"), "boom")
494             self.assertEqual(len(out_lines), 4)
495             self.assertEqual(len(err_lines), 2)
496             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
497             self.assertEqual(
498                 unstyle(str(report)),
499                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
500                 " reformat.",
501             )
502             self.assertEqual(report.return_code, 123)
503             report.path_ignored(Path("wat"), "no match")
504             self.assertEqual(len(out_lines), 5)
505             self.assertEqual(len(err_lines), 2)
506             self.assertEqual(out_lines[-1], "wat ignored: no match")
507             self.assertEqual(
508                 unstyle(str(report)),
509                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
510                 " reformat.",
511             )
512             self.assertEqual(report.return_code, 123)
513             report.done(Path("f4"), black.Changed.NO)
514             self.assertEqual(len(out_lines), 6)
515             self.assertEqual(len(err_lines), 2)
516             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
517             self.assertEqual(
518                 unstyle(str(report)),
519                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
520                 " reformat.",
521             )
522             self.assertEqual(report.return_code, 123)
523             report.check = True
524             self.assertEqual(
525                 unstyle(str(report)),
526                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
527                 " would fail to reformat.",
528             )
529             report.check = False
530             report.diff = True
531             self.assertEqual(
532                 unstyle(str(report)),
533                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
534                 " would fail to reformat.",
535             )
536
537     def test_report_quiet(self) -> None:
538         report = Report(quiet=True)
539         out_lines = []
540         err_lines = []
541
542         def out(msg: str, **kwargs: Any) -> None:
543             out_lines.append(msg)
544
545         def err(msg: str, **kwargs: Any) -> None:
546             err_lines.append(msg)
547
548         with patch("black.output._out", out), patch("black.output._err", err):
549             report.done(Path("f1"), black.Changed.NO)
550             self.assertEqual(len(out_lines), 0)
551             self.assertEqual(len(err_lines), 0)
552             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
553             self.assertEqual(report.return_code, 0)
554             report.done(Path("f2"), black.Changed.YES)
555             self.assertEqual(len(out_lines), 0)
556             self.assertEqual(len(err_lines), 0)
557             self.assertEqual(
558                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
559             )
560             report.done(Path("f3"), black.Changed.CACHED)
561             self.assertEqual(len(out_lines), 0)
562             self.assertEqual(len(err_lines), 0)
563             self.assertEqual(
564                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
565             )
566             self.assertEqual(report.return_code, 0)
567             report.check = True
568             self.assertEqual(report.return_code, 1)
569             report.check = False
570             report.failed(Path("e1"), "boom")
571             self.assertEqual(len(out_lines), 0)
572             self.assertEqual(len(err_lines), 1)
573             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
574             self.assertEqual(
575                 unstyle(str(report)),
576                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
577                 " reformat.",
578             )
579             self.assertEqual(report.return_code, 123)
580             report.done(Path("f3"), black.Changed.YES)
581             self.assertEqual(len(out_lines), 0)
582             self.assertEqual(len(err_lines), 1)
583             self.assertEqual(
584                 unstyle(str(report)),
585                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
586                 " reformat.",
587             )
588             self.assertEqual(report.return_code, 123)
589             report.failed(Path("e2"), "boom")
590             self.assertEqual(len(out_lines), 0)
591             self.assertEqual(len(err_lines), 2)
592             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
593             self.assertEqual(
594                 unstyle(str(report)),
595                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
596                 " reformat.",
597             )
598             self.assertEqual(report.return_code, 123)
599             report.path_ignored(Path("wat"), "no match")
600             self.assertEqual(len(out_lines), 0)
601             self.assertEqual(len(err_lines), 2)
602             self.assertEqual(
603                 unstyle(str(report)),
604                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
605                 " reformat.",
606             )
607             self.assertEqual(report.return_code, 123)
608             report.done(Path("f4"), black.Changed.NO)
609             self.assertEqual(len(out_lines), 0)
610             self.assertEqual(len(err_lines), 2)
611             self.assertEqual(
612                 unstyle(str(report)),
613                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
614                 " reformat.",
615             )
616             self.assertEqual(report.return_code, 123)
617             report.check = True
618             self.assertEqual(
619                 unstyle(str(report)),
620                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
621                 " would fail to reformat.",
622             )
623             report.check = False
624             report.diff = True
625             self.assertEqual(
626                 unstyle(str(report)),
627                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
628                 " would fail to reformat.",
629             )
630
631     def test_report_normal(self) -> None:
632         report = black.Report()
633         out_lines = []
634         err_lines = []
635
636         def out(msg: str, **kwargs: Any) -> None:
637             out_lines.append(msg)
638
639         def err(msg: str, **kwargs: Any) -> None:
640             err_lines.append(msg)
641
642         with patch("black.output._out", out), patch("black.output._err", err):
643             report.done(Path("f1"), black.Changed.NO)
644             self.assertEqual(len(out_lines), 0)
645             self.assertEqual(len(err_lines), 0)
646             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
647             self.assertEqual(report.return_code, 0)
648             report.done(Path("f2"), black.Changed.YES)
649             self.assertEqual(len(out_lines), 1)
650             self.assertEqual(len(err_lines), 0)
651             self.assertEqual(out_lines[-1], "reformatted f2")
652             self.assertEqual(
653                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
654             )
655             report.done(Path("f3"), black.Changed.CACHED)
656             self.assertEqual(len(out_lines), 1)
657             self.assertEqual(len(err_lines), 0)
658             self.assertEqual(out_lines[-1], "reformatted f2")
659             self.assertEqual(
660                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
661             )
662             self.assertEqual(report.return_code, 0)
663             report.check = True
664             self.assertEqual(report.return_code, 1)
665             report.check = False
666             report.failed(Path("e1"), "boom")
667             self.assertEqual(len(out_lines), 1)
668             self.assertEqual(len(err_lines), 1)
669             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
670             self.assertEqual(
671                 unstyle(str(report)),
672                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
673                 " reformat.",
674             )
675             self.assertEqual(report.return_code, 123)
676             report.done(Path("f3"), black.Changed.YES)
677             self.assertEqual(len(out_lines), 2)
678             self.assertEqual(len(err_lines), 1)
679             self.assertEqual(out_lines[-1], "reformatted f3")
680             self.assertEqual(
681                 unstyle(str(report)),
682                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
683                 " reformat.",
684             )
685             self.assertEqual(report.return_code, 123)
686             report.failed(Path("e2"), "boom")
687             self.assertEqual(len(out_lines), 2)
688             self.assertEqual(len(err_lines), 2)
689             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
690             self.assertEqual(
691                 unstyle(str(report)),
692                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
693                 " reformat.",
694             )
695             self.assertEqual(report.return_code, 123)
696             report.path_ignored(Path("wat"), "no match")
697             self.assertEqual(len(out_lines), 2)
698             self.assertEqual(len(err_lines), 2)
699             self.assertEqual(
700                 unstyle(str(report)),
701                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
702                 " reformat.",
703             )
704             self.assertEqual(report.return_code, 123)
705             report.done(Path("f4"), black.Changed.NO)
706             self.assertEqual(len(out_lines), 2)
707             self.assertEqual(len(err_lines), 2)
708             self.assertEqual(
709                 unstyle(str(report)),
710                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
711                 " reformat.",
712             )
713             self.assertEqual(report.return_code, 123)
714             report.check = True
715             self.assertEqual(
716                 unstyle(str(report)),
717                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
718                 " would fail to reformat.",
719             )
720             report.check = False
721             report.diff = True
722             self.assertEqual(
723                 unstyle(str(report)),
724                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
725                 " would fail to reformat.",
726             )
727
728     def test_lib2to3_parse(self) -> None:
729         with self.assertRaises(black.InvalidInput):
730             black.lib2to3_parse("invalid syntax")
731
732         straddling = "x + y"
733         black.lib2to3_parse(straddling)
734         black.lib2to3_parse(straddling, {TargetVersion.PY36})
735
736         py2_only = "print x"
737         with self.assertRaises(black.InvalidInput):
738             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
739
740         py3_only = "exec(x, end=y)"
741         black.lib2to3_parse(py3_only)
742         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
743
744     def test_get_features_used_decorator(self) -> None:
745         # Test the feature detection of new decorator syntax
746         # since this makes some test cases of test_get_features_used()
747         # fails if it fails, this is tested first so that a useful case
748         # is identified
749         simples, relaxed = read_data("decorators")
750         # skip explanation comments at the top of the file
751         for simple_test in simples.split("##")[1:]:
752             node = black.lib2to3_parse(simple_test)
753             decorator = str(node.children[0].children[0]).strip()
754             self.assertNotIn(
755                 Feature.RELAXED_DECORATORS,
756                 black.get_features_used(node),
757                 msg=(
758                     f"decorator '{decorator}' follows python<=3.8 syntax"
759                     "but is detected as 3.9+"
760                     # f"The full node is\n{node!r}"
761                 ),
762             )
763         # skip the '# output' comment at the top of the output part
764         for relaxed_test in relaxed.split("##")[1:]:
765             node = black.lib2to3_parse(relaxed_test)
766             decorator = str(node.children[0].children[0]).strip()
767             self.assertIn(
768                 Feature.RELAXED_DECORATORS,
769                 black.get_features_used(node),
770                 msg=(
771                     f"decorator '{decorator}' uses python3.9+ syntax"
772                     "but is detected as python<=3.8"
773                     # f"The full node is\n{node!r}"
774                 ),
775             )
776
777     def test_get_features_used(self) -> None:
778         node = black.lib2to3_parse("def f(*, arg): ...\n")
779         self.assertEqual(black.get_features_used(node), set())
780         node = black.lib2to3_parse("def f(*, arg,): ...\n")
781         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
782         node = black.lib2to3_parse("f(*arg,)\n")
783         self.assertEqual(
784             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
785         )
786         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
787         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
788         node = black.lib2to3_parse("123_456\n")
789         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
790         node = black.lib2to3_parse("123456\n")
791         self.assertEqual(black.get_features_used(node), set())
792         source, expected = read_data("function")
793         node = black.lib2to3_parse(source)
794         expected_features = {
795             Feature.TRAILING_COMMA_IN_CALL,
796             Feature.TRAILING_COMMA_IN_DEF,
797             Feature.F_STRINGS,
798         }
799         self.assertEqual(black.get_features_used(node), expected_features)
800         node = black.lib2to3_parse(expected)
801         self.assertEqual(black.get_features_used(node), expected_features)
802         source, expected = read_data("expression")
803         node = black.lib2to3_parse(source)
804         self.assertEqual(black.get_features_used(node), set())
805         node = black.lib2to3_parse(expected)
806         self.assertEqual(black.get_features_used(node), set())
807         node = black.lib2to3_parse("lambda a, /, b: ...")
808         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
809         node = black.lib2to3_parse("def fn(a, /, b): ...")
810         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
811         node = black.lib2to3_parse("def fn(): yield a, b")
812         self.assertEqual(black.get_features_used(node), set())
813         node = black.lib2to3_parse("def fn(): return a, b")
814         self.assertEqual(black.get_features_used(node), set())
815         node = black.lib2to3_parse("def fn(): yield *b, c")
816         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
817         node = black.lib2to3_parse("def fn(): return a, *b, c")
818         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
819         node = black.lib2to3_parse("x = a, *b, c")
820         self.assertEqual(black.get_features_used(node), set())
821         node = black.lib2to3_parse("x: Any = regular")
822         self.assertEqual(black.get_features_used(node), set())
823         node = black.lib2to3_parse("x: Any = (regular, regular)")
824         self.assertEqual(black.get_features_used(node), set())
825         node = black.lib2to3_parse("x: Any = Complex(Type(1))[something]")
826         self.assertEqual(black.get_features_used(node), set())
827         node = black.lib2to3_parse("x: Tuple[int, ...] = a, b, c")
828         self.assertEqual(
829             black.get_features_used(node), {Feature.ANN_ASSIGN_EXTENDED_RHS}
830         )
831
832     def test_get_features_used_for_future_flags(self) -> None:
833         for src, features in [
834             ("from __future__ import annotations", {Feature.FUTURE_ANNOTATIONS}),
835             (
836                 "from __future__ import (other, annotations)",
837                 {Feature.FUTURE_ANNOTATIONS},
838             ),
839             ("a = 1 + 2\nfrom something import annotations", set()),
840             ("from __future__ import x, y", set()),
841         ]:
842             with self.subTest(src=src, features=features):
843                 node = black.lib2to3_parse(src)
844                 future_imports = black.get_future_imports(node)
845                 self.assertEqual(
846                     black.get_features_used(node, future_imports=future_imports),
847                     features,
848                 )
849
850     def test_get_future_imports(self) -> None:
851         node = black.lib2to3_parse("\n")
852         self.assertEqual(set(), black.get_future_imports(node))
853         node = black.lib2to3_parse("from __future__ import black\n")
854         self.assertEqual({"black"}, black.get_future_imports(node))
855         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
856         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
857         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
858         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
859         node = black.lib2to3_parse(
860             "from __future__ import multiple\nfrom __future__ import imports\n"
861         )
862         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
863         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
864         self.assertEqual({"black"}, black.get_future_imports(node))
865         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
866         self.assertEqual({"black"}, black.get_future_imports(node))
867         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
868         self.assertEqual(set(), black.get_future_imports(node))
869         node = black.lib2to3_parse("from some.module import black\n")
870         self.assertEqual(set(), black.get_future_imports(node))
871         node = black.lib2to3_parse(
872             "from __future__ import unicode_literals as _unicode_literals"
873         )
874         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
875         node = black.lib2to3_parse(
876             "from __future__ import unicode_literals as _lol, print"
877         )
878         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
879
880     @pytest.mark.incompatible_with_mypyc
881     def test_debug_visitor(self) -> None:
882         source, _ = read_data("debug_visitor.py")
883         expected, _ = read_data("debug_visitor.out")
884         out_lines = []
885         err_lines = []
886
887         def out(msg: str, **kwargs: Any) -> None:
888             out_lines.append(msg)
889
890         def err(msg: str, **kwargs: Any) -> None:
891             err_lines.append(msg)
892
893         with patch("black.debug.out", out):
894             DebugVisitor.show(source)
895         actual = "\n".join(out_lines) + "\n"
896         log_name = ""
897         if expected != actual:
898             log_name = black.dump_to_file(*out_lines)
899         self.assertEqual(
900             expected,
901             actual,
902             f"AST print out is different. Actual version dumped to {log_name}",
903         )
904
905     def test_format_file_contents(self) -> None:
906         empty = ""
907         mode = DEFAULT_MODE
908         with self.assertRaises(black.NothingChanged):
909             black.format_file_contents(empty, mode=mode, fast=False)
910         just_nl = "\n"
911         with self.assertRaises(black.NothingChanged):
912             black.format_file_contents(just_nl, mode=mode, fast=False)
913         same = "j = [1, 2, 3]\n"
914         with self.assertRaises(black.NothingChanged):
915             black.format_file_contents(same, mode=mode, fast=False)
916         different = "j = [1,2,3]"
917         expected = same
918         actual = black.format_file_contents(different, mode=mode, fast=False)
919         self.assertEqual(expected, actual)
920         invalid = "return if you can"
921         with self.assertRaises(black.InvalidInput) as e:
922             black.format_file_contents(invalid, mode=mode, fast=False)
923         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
924
925     def test_endmarker(self) -> None:
926         n = black.lib2to3_parse("\n")
927         self.assertEqual(n.type, black.syms.file_input)
928         self.assertEqual(len(n.children), 1)
929         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
930
931     @pytest.mark.incompatible_with_mypyc
932     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
933     def test_assertFormatEqual(self) -> None:
934         out_lines = []
935         err_lines = []
936
937         def out(msg: str, **kwargs: Any) -> None:
938             out_lines.append(msg)
939
940         def err(msg: str, **kwargs: Any) -> None:
941             err_lines.append(msg)
942
943         with patch("black.output._out", out), patch("black.output._err", err):
944             with self.assertRaises(AssertionError):
945                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
946
947         out_str = "".join(out_lines)
948         self.assertIn("Expected tree:", out_str)
949         self.assertIn("Actual tree:", out_str)
950         self.assertEqual("".join(err_lines), "")
951
952     @event_loop()
953     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
954     def test_works_in_mono_process_only_environment(self) -> None:
955         with cache_dir() as workspace:
956             for f in [
957                 (workspace / "one.py").resolve(),
958                 (workspace / "two.py").resolve(),
959             ]:
960                 f.write_text('print("hello")\n')
961             self.invokeBlack([str(workspace)])
962
963     @event_loop()
964     def test_check_diff_use_together(self) -> None:
965         with cache_dir():
966             # Files which will be reformatted.
967             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
968             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
969             # Files which will not be reformatted.
970             src2 = (THIS_DIR / "data" / "composition.py").resolve()
971             self.invokeBlack([str(src2), "--diff", "--check"])
972             # Multi file command.
973             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
974
975     def test_no_files(self) -> None:
976         with cache_dir():
977             # Without an argument, black exits with error code 0.
978             self.invokeBlack([])
979
980     def test_broken_symlink(self) -> None:
981         with cache_dir() as workspace:
982             symlink = workspace / "broken_link.py"
983             try:
984                 symlink.symlink_to("nonexistent.py")
985             except (OSError, NotImplementedError) as e:
986                 self.skipTest(f"Can't create symlinks: {e}")
987             self.invokeBlack([str(workspace.resolve())])
988
989     def test_single_file_force_pyi(self) -> None:
990         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
991         contents, expected = read_data("force_pyi")
992         with cache_dir() as workspace:
993             path = (workspace / "file.py").resolve()
994             with open(path, "w") as fh:
995                 fh.write(contents)
996             self.invokeBlack([str(path), "--pyi"])
997             with open(path, "r") as fh:
998                 actual = fh.read()
999             # verify cache with --pyi is separate
1000             pyi_cache = black.read_cache(pyi_mode)
1001             self.assertIn(str(path), pyi_cache)
1002             normal_cache = black.read_cache(DEFAULT_MODE)
1003             self.assertNotIn(str(path), normal_cache)
1004         self.assertFormatEqual(expected, actual)
1005         black.assert_equivalent(contents, actual)
1006         black.assert_stable(contents, actual, pyi_mode)
1007
1008     @event_loop()
1009     def test_multi_file_force_pyi(self) -> None:
1010         reg_mode = DEFAULT_MODE
1011         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1012         contents, expected = read_data("force_pyi")
1013         with cache_dir() as workspace:
1014             paths = [
1015                 (workspace / "file1.py").resolve(),
1016                 (workspace / "file2.py").resolve(),
1017             ]
1018             for path in paths:
1019                 with open(path, "w") as fh:
1020                     fh.write(contents)
1021             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1022             for path in paths:
1023                 with open(path, "r") as fh:
1024                     actual = fh.read()
1025                 self.assertEqual(actual, expected)
1026             # verify cache with --pyi is separate
1027             pyi_cache = black.read_cache(pyi_mode)
1028             normal_cache = black.read_cache(reg_mode)
1029             for path in paths:
1030                 self.assertIn(str(path), pyi_cache)
1031                 self.assertNotIn(str(path), normal_cache)
1032
1033     def test_pipe_force_pyi(self) -> None:
1034         source, expected = read_data("force_pyi")
1035         result = CliRunner().invoke(
1036             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1037         )
1038         self.assertEqual(result.exit_code, 0)
1039         actual = result.output
1040         self.assertFormatEqual(actual, expected)
1041
1042     def test_single_file_force_py36(self) -> None:
1043         reg_mode = DEFAULT_MODE
1044         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1045         source, expected = read_data("force_py36")
1046         with cache_dir() as workspace:
1047             path = (workspace / "file.py").resolve()
1048             with open(path, "w") as fh:
1049                 fh.write(source)
1050             self.invokeBlack([str(path), *PY36_ARGS])
1051             with open(path, "r") as fh:
1052                 actual = fh.read()
1053             # verify cache with --target-version is separate
1054             py36_cache = black.read_cache(py36_mode)
1055             self.assertIn(str(path), py36_cache)
1056             normal_cache = black.read_cache(reg_mode)
1057             self.assertNotIn(str(path), normal_cache)
1058         self.assertEqual(actual, expected)
1059
1060     @event_loop()
1061     def test_multi_file_force_py36(self) -> None:
1062         reg_mode = DEFAULT_MODE
1063         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1064         source, expected = read_data("force_py36")
1065         with cache_dir() as workspace:
1066             paths = [
1067                 (workspace / "file1.py").resolve(),
1068                 (workspace / "file2.py").resolve(),
1069             ]
1070             for path in paths:
1071                 with open(path, "w") as fh:
1072                     fh.write(source)
1073             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1074             for path in paths:
1075                 with open(path, "r") as fh:
1076                     actual = fh.read()
1077                 self.assertEqual(actual, expected)
1078             # verify cache with --target-version is separate
1079             pyi_cache = black.read_cache(py36_mode)
1080             normal_cache = black.read_cache(reg_mode)
1081             for path in paths:
1082                 self.assertIn(str(path), pyi_cache)
1083                 self.assertNotIn(str(path), normal_cache)
1084
1085     def test_pipe_force_py36(self) -> None:
1086         source, expected = read_data("force_py36")
1087         result = CliRunner().invoke(
1088             black.main,
1089             ["-", "-q", "--target-version=py36"],
1090             input=BytesIO(source.encode("utf8")),
1091         )
1092         self.assertEqual(result.exit_code, 0)
1093         actual = result.output
1094         self.assertFormatEqual(actual, expected)
1095
1096     @pytest.mark.incompatible_with_mypyc
1097     def test_reformat_one_with_stdin(self) -> None:
1098         with patch(
1099             "black.format_stdin_to_stdout",
1100             return_value=lambda *args, **kwargs: black.Changed.YES,
1101         ) as fsts:
1102             report = MagicMock()
1103             path = Path("-")
1104             black.reformat_one(
1105                 path,
1106                 fast=True,
1107                 write_back=black.WriteBack.YES,
1108                 mode=DEFAULT_MODE,
1109                 report=report,
1110             )
1111             fsts.assert_called_once()
1112             report.done.assert_called_with(path, black.Changed.YES)
1113
1114     @pytest.mark.incompatible_with_mypyc
1115     def test_reformat_one_with_stdin_filename(self) -> None:
1116         with patch(
1117             "black.format_stdin_to_stdout",
1118             return_value=lambda *args, **kwargs: black.Changed.YES,
1119         ) as fsts:
1120             report = MagicMock()
1121             p = "foo.py"
1122             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1123             expected = Path(p)
1124             black.reformat_one(
1125                 path,
1126                 fast=True,
1127                 write_back=black.WriteBack.YES,
1128                 mode=DEFAULT_MODE,
1129                 report=report,
1130             )
1131             fsts.assert_called_once_with(
1132                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1133             )
1134             # __BLACK_STDIN_FILENAME__ should have been stripped
1135             report.done.assert_called_with(expected, black.Changed.YES)
1136
1137     @pytest.mark.incompatible_with_mypyc
1138     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1139         with patch(
1140             "black.format_stdin_to_stdout",
1141             return_value=lambda *args, **kwargs: black.Changed.YES,
1142         ) as fsts:
1143             report = MagicMock()
1144             p = "foo.pyi"
1145             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1146             expected = Path(p)
1147             black.reformat_one(
1148                 path,
1149                 fast=True,
1150                 write_back=black.WriteBack.YES,
1151                 mode=DEFAULT_MODE,
1152                 report=report,
1153             )
1154             fsts.assert_called_once_with(
1155                 fast=True,
1156                 write_back=black.WriteBack.YES,
1157                 mode=replace(DEFAULT_MODE, is_pyi=True),
1158             )
1159             # __BLACK_STDIN_FILENAME__ should have been stripped
1160             report.done.assert_called_with(expected, black.Changed.YES)
1161
1162     @pytest.mark.incompatible_with_mypyc
1163     def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
1164         with patch(
1165             "black.format_stdin_to_stdout",
1166             return_value=lambda *args, **kwargs: black.Changed.YES,
1167         ) as fsts:
1168             report = MagicMock()
1169             p = "foo.ipynb"
1170             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1171             expected = Path(p)
1172             black.reformat_one(
1173                 path,
1174                 fast=True,
1175                 write_back=black.WriteBack.YES,
1176                 mode=DEFAULT_MODE,
1177                 report=report,
1178             )
1179             fsts.assert_called_once_with(
1180                 fast=True,
1181                 write_back=black.WriteBack.YES,
1182                 mode=replace(DEFAULT_MODE, is_ipynb=True),
1183             )
1184             # __BLACK_STDIN_FILENAME__ should have been stripped
1185             report.done.assert_called_with(expected, black.Changed.YES)
1186
1187     @pytest.mark.incompatible_with_mypyc
1188     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1189         with patch(
1190             "black.format_stdin_to_stdout",
1191             return_value=lambda *args, **kwargs: black.Changed.YES,
1192         ) as fsts:
1193             report = MagicMock()
1194             # Even with an existing file, since we are forcing stdin, black
1195             # should output to stdout and not modify the file inplace
1196             p = Path(str(THIS_DIR / "data/collections.py"))
1197             # Make sure is_file actually returns True
1198             self.assertTrue(p.is_file())
1199             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1200             expected = Path(p)
1201             black.reformat_one(
1202                 path,
1203                 fast=True,
1204                 write_back=black.WriteBack.YES,
1205                 mode=DEFAULT_MODE,
1206                 report=report,
1207             )
1208             fsts.assert_called_once()
1209             # __BLACK_STDIN_FILENAME__ should have been stripped
1210             report.done.assert_called_with(expected, black.Changed.YES)
1211
1212     def test_reformat_one_with_stdin_empty(self) -> None:
1213         output = io.StringIO()
1214         with patch("io.TextIOWrapper", lambda *args, **kwargs: output):
1215             try:
1216                 black.format_stdin_to_stdout(
1217                     fast=True,
1218                     content="",
1219                     write_back=black.WriteBack.YES,
1220                     mode=DEFAULT_MODE,
1221                 )
1222             except io.UnsupportedOperation:
1223                 pass  # StringIO does not support detach
1224             assert output.getvalue() == ""
1225
1226     def test_invalid_cli_regex(self) -> None:
1227         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1228             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1229
1230     def test_required_version_matches_version(self) -> None:
1231         self.invokeBlack(
1232             ["--required-version", black.__version__], exit_code=0, ignore_config=True
1233         )
1234
1235     def test_required_version_does_not_match_version(self) -> None:
1236         self.invokeBlack(
1237             ["--required-version", "20.99b"], exit_code=1, ignore_config=True
1238         )
1239
1240     def test_preserves_line_endings(self) -> None:
1241         with TemporaryDirectory() as workspace:
1242             test_file = Path(workspace) / "test.py"
1243             for nl in ["\n", "\r\n"]:
1244                 contents = nl.join(["def f(  ):", "    pass"])
1245                 test_file.write_bytes(contents.encode())
1246                 ff(test_file, write_back=black.WriteBack.YES)
1247                 updated_contents: bytes = test_file.read_bytes()
1248                 self.assertIn(nl.encode(), updated_contents)
1249                 if nl == "\n":
1250                     self.assertNotIn(b"\r\n", updated_contents)
1251
1252     def test_preserves_line_endings_via_stdin(self) -> None:
1253         for nl in ["\n", "\r\n"]:
1254             contents = nl.join(["def f(  ):", "    pass"])
1255             runner = BlackRunner()
1256             result = runner.invoke(
1257                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1258             )
1259             self.assertEqual(result.exit_code, 0)
1260             output = result.stdout_bytes
1261             self.assertIn(nl.encode("utf8"), output)
1262             if nl == "\n":
1263                 self.assertNotIn(b"\r\n", output)
1264
1265     def test_assert_equivalent_different_asts(self) -> None:
1266         with self.assertRaises(AssertionError):
1267             black.assert_equivalent("{}", "None")
1268
1269     def test_shhh_click(self) -> None:
1270         try:
1271             from click import _unicodefun
1272         except ModuleNotFoundError:
1273             self.skipTest("Incompatible Click version")
1274         if not hasattr(_unicodefun, "_verify_python3_env"):
1275             self.skipTest("Incompatible Click version")
1276         # First, let's see if Click is crashing with a preferred ASCII charset.
1277         with patch("locale.getpreferredencoding") as gpe:
1278             gpe.return_value = "ASCII"
1279             with self.assertRaises(RuntimeError):
1280                 _unicodefun._verify_python3_env()  # type: ignore
1281         # Now, let's silence Click...
1282         black.patch_click()
1283         # ...and confirm it's silent.
1284         with patch("locale.getpreferredencoding") as gpe:
1285             gpe.return_value = "ASCII"
1286             try:
1287                 _unicodefun._verify_python3_env()  # type: ignore
1288             except RuntimeError as re:
1289                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1290
1291     def test_root_logger_not_used_directly(self) -> None:
1292         def fail(*args: Any, **kwargs: Any) -> None:
1293             self.fail("Record created with root logger")
1294
1295         with patch.multiple(
1296             logging.root,
1297             debug=fail,
1298             info=fail,
1299             warning=fail,
1300             error=fail,
1301             critical=fail,
1302             log=fail,
1303         ):
1304             ff(THIS_DIR / "util.py")
1305
1306     def test_invalid_config_return_code(self) -> None:
1307         tmp_file = Path(black.dump_to_file())
1308         try:
1309             tmp_config = Path(black.dump_to_file())
1310             tmp_config.unlink()
1311             args = ["--config", str(tmp_config), str(tmp_file)]
1312             self.invokeBlack(args, exit_code=2, ignore_config=False)
1313         finally:
1314             tmp_file.unlink()
1315
1316     def test_parse_pyproject_toml(self) -> None:
1317         test_toml_file = THIS_DIR / "test.toml"
1318         config = black.parse_pyproject_toml(str(test_toml_file))
1319         self.assertEqual(config["verbose"], 1)
1320         self.assertEqual(config["check"], "no")
1321         self.assertEqual(config["diff"], "y")
1322         self.assertEqual(config["color"], True)
1323         self.assertEqual(config["line_length"], 79)
1324         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1325         self.assertEqual(config["python_cell_magics"], ["custom1", "custom2"])
1326         self.assertEqual(config["exclude"], r"\.pyi?$")
1327         self.assertEqual(config["include"], r"\.py?$")
1328
1329     def test_read_pyproject_toml(self) -> None:
1330         test_toml_file = THIS_DIR / "test.toml"
1331         fake_ctx = FakeContext()
1332         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1333         config = fake_ctx.default_map
1334         self.assertEqual(config["verbose"], "1")
1335         self.assertEqual(config["check"], "no")
1336         self.assertEqual(config["diff"], "y")
1337         self.assertEqual(config["color"], "True")
1338         self.assertEqual(config["line_length"], "79")
1339         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1340         self.assertEqual(config["exclude"], r"\.pyi?$")
1341         self.assertEqual(config["include"], r"\.py?$")
1342
1343     @pytest.mark.incompatible_with_mypyc
1344     def test_find_project_root(self) -> None:
1345         with TemporaryDirectory() as workspace:
1346             root = Path(workspace)
1347             test_dir = root / "test"
1348             test_dir.mkdir()
1349
1350             src_dir = root / "src"
1351             src_dir.mkdir()
1352
1353             root_pyproject = root / "pyproject.toml"
1354             root_pyproject.touch()
1355             src_pyproject = src_dir / "pyproject.toml"
1356             src_pyproject.touch()
1357             src_python = src_dir / "foo.py"
1358             src_python.touch()
1359
1360             self.assertEqual(
1361                 black.find_project_root((src_dir, test_dir)),
1362                 (root.resolve(), "pyproject.toml"),
1363             )
1364             self.assertEqual(
1365                 black.find_project_root((src_dir,)),
1366                 (src_dir.resolve(), "pyproject.toml"),
1367             )
1368             self.assertEqual(
1369                 black.find_project_root((src_python,)),
1370                 (src_dir.resolve(), "pyproject.toml"),
1371             )
1372
1373     @patch(
1374         "black.files.find_user_pyproject_toml",
1375         black.files.find_user_pyproject_toml.__wrapped__,
1376     )
1377     def test_find_user_pyproject_toml_linux(self) -> None:
1378         if system() == "Windows":
1379             return
1380
1381         # Test if XDG_CONFIG_HOME is checked
1382         with TemporaryDirectory() as workspace:
1383             tmp_user_config = Path(workspace) / "black"
1384             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1385                 self.assertEqual(
1386                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
1387                 )
1388
1389         # Test fallback for XDG_CONFIG_HOME
1390         with patch.dict("os.environ"):
1391             os.environ.pop("XDG_CONFIG_HOME", None)
1392             fallback_user_config = Path("~/.config").expanduser() / "black"
1393             self.assertEqual(
1394                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
1395             )
1396
1397     def test_find_user_pyproject_toml_windows(self) -> None:
1398         if system() != "Windows":
1399             return
1400
1401         user_config_path = Path.home() / ".black"
1402         self.assertEqual(
1403             black.files.find_user_pyproject_toml(), user_config_path.resolve()
1404         )
1405
1406     def test_bpo_33660_workaround(self) -> None:
1407         if system() == "Windows":
1408             return
1409
1410         # https://bugs.python.org/issue33660
1411         root = Path("/")
1412         with change_directory(root):
1413             path = Path("workspace") / "project"
1414             report = black.Report(verbose=True)
1415             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
1416             self.assertEqual(normalized_path, "workspace/project")
1417
1418     def test_newline_comment_interaction(self) -> None:
1419         source = "class A:\\\r\n# type: ignore\n pass\n"
1420         output = black.format_str(source, mode=DEFAULT_MODE)
1421         black.assert_stable(source, output, mode=DEFAULT_MODE)
1422
1423     def test_bpo_2142_workaround(self) -> None:
1424
1425         # https://bugs.python.org/issue2142
1426
1427         source, _ = read_data("missing_final_newline.py")
1428         # read_data adds a trailing newline
1429         source = source.rstrip()
1430         expected, _ = read_data("missing_final_newline.diff")
1431         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
1432         diff_header = re.compile(
1433             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
1434             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
1435         )
1436         try:
1437             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
1438             self.assertEqual(result.exit_code, 0)
1439         finally:
1440             os.unlink(tmp_file)
1441         actual = result.output
1442         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
1443         self.assertEqual(actual, expected)
1444
1445     @staticmethod
1446     def compare_results(
1447         result: click.testing.Result, expected_value: str, expected_exit_code: int
1448     ) -> None:
1449         """Helper method to test the value and exit code of a click Result."""
1450         assert (
1451             result.output == expected_value
1452         ), "The output did not match the expected value."
1453         assert result.exit_code == expected_exit_code, "The exit code is incorrect."
1454
1455     def test_code_option(self) -> None:
1456         """Test the code option with no changes."""
1457         code = 'print("Hello world")\n'
1458         args = ["--code", code]
1459         result = CliRunner().invoke(black.main, args)
1460
1461         self.compare_results(result, code, 0)
1462
1463     def test_code_option_changed(self) -> None:
1464         """Test the code option when changes are required."""
1465         code = "print('hello world')"
1466         formatted = black.format_str(code, mode=DEFAULT_MODE)
1467
1468         args = ["--code", code]
1469         result = CliRunner().invoke(black.main, args)
1470
1471         self.compare_results(result, formatted, 0)
1472
1473     def test_code_option_check(self) -> None:
1474         """Test the code option when check is passed."""
1475         args = ["--check", "--code", 'print("Hello world")\n']
1476         result = CliRunner().invoke(black.main, args)
1477         self.compare_results(result, "", 0)
1478
1479     def test_code_option_check_changed(self) -> None:
1480         """Test the code option when changes are required, and check is passed."""
1481         args = ["--check", "--code", "print('hello world')"]
1482         result = CliRunner().invoke(black.main, args)
1483         self.compare_results(result, "", 1)
1484
1485     def test_code_option_diff(self) -> None:
1486         """Test the code option when diff is passed."""
1487         code = "print('hello world')"
1488         formatted = black.format_str(code, mode=DEFAULT_MODE)
1489         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1490
1491         args = ["--diff", "--code", code]
1492         result = CliRunner().invoke(black.main, args)
1493
1494         # Remove time from diff
1495         output = DIFF_TIME.sub("", result.output)
1496
1497         assert output == result_diff, "The output did not match the expected value."
1498         assert result.exit_code == 0, "The exit code is incorrect."
1499
1500     def test_code_option_color_diff(self) -> None:
1501         """Test the code option when color and diff are passed."""
1502         code = "print('hello world')"
1503         formatted = black.format_str(code, mode=DEFAULT_MODE)
1504
1505         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1506         result_diff = color_diff(result_diff)
1507
1508         args = ["--diff", "--color", "--code", code]
1509         result = CliRunner().invoke(black.main, args)
1510
1511         # Remove time from diff
1512         output = DIFF_TIME.sub("", result.output)
1513
1514         assert output == result_diff, "The output did not match the expected value."
1515         assert result.exit_code == 0, "The exit code is incorrect."
1516
1517     @pytest.mark.incompatible_with_mypyc
1518     def test_code_option_safe(self) -> None:
1519         """Test that the code option throws an error when the sanity checks fail."""
1520         # Patch black.assert_equivalent to ensure the sanity checks fail
1521         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1522             code = 'print("Hello world")'
1523             error_msg = f"{code}\nerror: cannot format <string>: \n"
1524
1525             args = ["--safe", "--code", code]
1526             result = CliRunner().invoke(black.main, args)
1527
1528             self.compare_results(result, error_msg, 123)
1529
1530     def test_code_option_fast(self) -> None:
1531         """Test that the code option ignores errors when the sanity checks fail."""
1532         # Patch black.assert_equivalent to ensure the sanity checks fail
1533         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1534             code = 'print("Hello world")'
1535             formatted = black.format_str(code, mode=DEFAULT_MODE)
1536
1537             args = ["--fast", "--code", code]
1538             result = CliRunner().invoke(black.main, args)
1539
1540             self.compare_results(result, formatted, 0)
1541
1542     @pytest.mark.incompatible_with_mypyc
1543     def test_code_option_config(self) -> None:
1544         """
1545         Test that the code option finds the pyproject.toml in the current directory.
1546         """
1547         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1548             args = ["--code", "print"]
1549             # This is the only directory known to contain a pyproject.toml
1550             with change_directory(PROJECT_ROOT):
1551                 CliRunner().invoke(black.main, args)
1552                 pyproject_path = Path(Path.cwd(), "pyproject.toml").resolve()
1553
1554             assert (
1555                 len(parse.mock_calls) >= 1
1556             ), "Expected config parse to be called with the current directory."
1557
1558             _, call_args, _ = parse.mock_calls[0]
1559             assert (
1560                 call_args[0].lower() == str(pyproject_path).lower()
1561             ), "Incorrect config loaded."
1562
1563     @pytest.mark.incompatible_with_mypyc
1564     def test_code_option_parent_config(self) -> None:
1565         """
1566         Test that the code option finds the pyproject.toml in the parent directory.
1567         """
1568         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1569             with change_directory(THIS_DIR):
1570                 args = ["--code", "print"]
1571                 CliRunner().invoke(black.main, args)
1572
1573                 pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
1574                 assert (
1575                     len(parse.mock_calls) >= 1
1576                 ), "Expected config parse to be called with the current directory."
1577
1578                 _, call_args, _ = parse.mock_calls[0]
1579                 assert (
1580                     call_args[0].lower() == str(pyproject_path).lower()
1581                 ), "Incorrect config loaded."
1582
1583     def test_for_handled_unexpected_eof_error(self) -> None:
1584         """
1585         Test that an unexpected EOF SyntaxError is nicely presented.
1586         """
1587         with pytest.raises(black.parsing.InvalidInput) as exc_info:
1588             black.lib2to3_parse("print(", {})
1589
1590         exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")
1591
1592     def test_equivalency_ast_parse_failure_includes_error(self) -> None:
1593         with pytest.raises(AssertionError) as err:
1594             black.assert_equivalent("a«»a  = 1", "a«»a  = 1")
1595
1596         err.match("--safe")
1597         # Unfortunately the SyntaxError message has changed in newer versions so we
1598         # can't match it directly.
1599         err.match("invalid character")
1600         err.match(r"\(<unknown>, line 1\)")
1601
1602
class TestCaching:
    """Tests for black's formatting cache: where it lives, how it is read and
    written, and how cached entries decide which files get reformatted."""

    def test_get_cache_dir(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """``BLACK_CACHE_DIR`` overrides the ``user_cache_dir`` default."""
        # Create multiple cache directories
        workspace1 = tmp_path / "ws1"
        workspace1.mkdir()
        workspace2 = tmp_path / "ws2"
        workspace2.mkdir()

        # Force user_cache_dir to use the temporary directory for easier assertions
        patch_user_cache_dir = patch(
            target="black.cache.user_cache_dir",
            autospec=True,
            return_value=str(workspace1),
        )

        # If BLACK_CACHE_DIR is not set, use user_cache_dir
        monkeypatch.delenv("BLACK_CACHE_DIR", raising=False)
        with patch_user_cache_dir:
            assert get_cache_dir() == workspace1

        # If it is set, use the path provided in the env var.
        monkeypatch.setenv("BLACK_CACHE_DIR", str(workspace2))
        assert get_cache_dir() == workspace2

    def test_cache_broken_file(self) -> None:
        """An unreadable cache file reads as empty and is rewritten by a run."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            cache_file = get_cache_file(mode)
            cache_file.write_text("this is not a pickle")
            assert black.read_cache(mode) == {}
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            invokeBlack([str(src)])
            cache = black.read_cache(mode)
            assert str(src) in cache

    def test_cache_single_file_already_cached(self) -> None:
        """A file already recorded in the cache is not reformatted."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            black.write_cache({}, [src], mode)
            invokeBlack([str(src)])
            assert src.read_text() == "print('hello')"

    @event_loop()
    def test_cache_multiple_files(self) -> None:
        """Only uncached files are reformatted; afterwards all are cached.

        ``one.py`` is pre-cached with unformatted content and must stay as is,
        while ``two.py`` gets reformatted.
        """
        mode = DEFAULT_MODE
        # Run workers on threads instead of processes (presumably so they see
        # the patched cache directory -- confirm).
        with cache_dir() as workspace, patch(
            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            one = (workspace / "one.py").resolve()
            with one.open("w") as fobj:
                fobj.write("print('hello')")
            two = (workspace / "two.py").resolve()
            with two.open("w") as fobj:
                fobj.write("print('hello')")
            black.write_cache({}, [one], mode)
            invokeBlack([str(workspace)])
            with one.open("r") as fobj:
                assert fobj.read() == "print('hello')"
            with two.open("r") as fobj:
                assert fobj.read() == 'print("hello")\n'
            cache = black.read_cache(mode)
            assert str(one) in cache
            assert str(two) in cache

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    def test_no_cache_when_writeback_diff(self, color: bool) -> None:
        """``--diff`` (with or without ``--color``) neither reads nor writes
        the cache."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            with src.open("w") as fobj:
                fobj.write("print('hello')")
            with patch("black.read_cache") as read_cache, patch(
                "black.write_cache"
            ) as write_cache:
                cmd = [str(src), "--diff"]
                if color:
                    cmd.append("--color")
                invokeBlack(cmd)
                cache_file = get_cache_file(mode)
                assert not cache_file.exists()
                write_cache.assert_not_called()
                read_cache.assert_not_called()

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    @event_loop()
    def test_output_locking_when_writeback_diff(self, color: bool) -> None:
        """``--diff`` over several files creates a ``multiprocessing.Manager``."""
        with cache_dir() as workspace:
            for tag in range(4):
                src = (workspace / f"test{tag}.py").resolve()
                with src.open("w") as fobj:
                    fobj.write("print('hello')")
            with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
                cmd = ["--diff", str(workspace)]
                if color:
                    cmd.append("--color")
                invokeBlack(cmd, exit_code=0)
                # this isn't quite doing what we want, but if it _isn't_
                # called then we cannot be using the lock it provides
                mgr.assert_called()

    def test_no_cache_when_stdin(self) -> None:
        """Formatting from stdin (``-``) does not create a cache file."""
        mode = DEFAULT_MODE
        with cache_dir():
            result = CliRunner().invoke(
                black.main, ["-"], input=BytesIO(b"print('hello')")
            )
            assert not result.exit_code
            cache_file = get_cache_file(mode)
            assert not cache_file.exists()

    def test_read_cache_no_cachefile(self) -> None:
        """Reading a cache that does not exist yet yields an empty dict."""
        mode = DEFAULT_MODE
        with cache_dir():
            assert black.read_cache(mode) == {}

    def test_write_cache_read_cache(self) -> None:
        """A written entry reads back as the file's ``get_cache_info``."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.touch()
            black.write_cache({}, [src], mode)
            cache = black.read_cache(mode)
            assert str(src) in cache
            assert cache[str(src)] == black.get_cache_info(src)

    def test_filter_cached(self) -> None:
        """Missing or stale entries are "todo"; up-to-date ones are "done"."""
        with TemporaryDirectory() as workspace:
            path = Path(workspace)
            uncached = (path / "uncached").resolve()
            cached = (path / "cached").resolve()
            cached_but_changed = (path / "changed").resolve()
            uncached.touch()
            cached.touch()
            cached_but_changed.touch()
            cache = {
                str(cached): black.get_cache_info(cached),
                # Deliberately bogus info so the entry counts as stale.
                str(cached_but_changed): (0.0, 0),
            }
            todo, done = black.filter_cached(
                cache, {uncached, cached, cached_but_changed}
            )
            assert todo == {uncached, cached_but_changed}
            assert done == {cached}

    def test_write_cache_creates_directory_if_needed(self) -> None:
        """Writing the cache creates a missing cache directory."""
        mode = DEFAULT_MODE
        with cache_dir(exists=False) as workspace:
            assert not workspace.exists()
            black.write_cache({}, [], mode)
            assert workspace.exists()

    @event_loop()
    def test_failed_formatting_does_not_get_cached(self) -> None:
        """Files that fail to format are left out of the cache."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace, patch(
            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            failing = (workspace / "failing.py").resolve()
            with failing.open("w") as fobj:
                fobj.write("not actually python")
            clean = (workspace / "clean.py").resolve()
            with clean.open("w") as fobj:
                fobj.write('print("hello")\n')
            invokeBlack([str(workspace)], exit_code=123)
            cache = black.read_cache(mode)
            assert str(failing) not in cache
            assert str(clean) in cache

    def test_write_cache_write_fail(self) -> None:
        """An ``OSError`` while writing the cache does not propagate.

        The test passes as long as ``write_cache`` returns normally.
        """
        mode = DEFAULT_MODE
        with cache_dir(), patch.object(Path, "open") as mock:
            mock.side_effect = OSError
            black.write_cache({}, [], mode)

    def test_read_cache_line_lengths(self) -> None:
        """Entries written for one mode are not visible under a mode with a
        different line length."""
        mode = DEFAULT_MODE
        short_mode = replace(DEFAULT_MODE, line_length=1)
        with cache_dir() as workspace:
            path = (workspace / "file.py").resolve()
            path.touch()
            black.write_cache({}, [path], mode)
            one = black.read_cache(mode)
            assert str(path) in one
            two = black.read_cache(short_mode)
            assert str(path) not in two
1795
1796
def assert_collected_sources(
    src: Sequence[Union[str, Path]],
    expected: Sequence[Union[str, Path]],
    *,
    ctx: Optional[FakeContext] = None,
    exclude: Optional[str] = None,
    include: Optional[str] = None,
    extend_exclude: Optional[str] = None,
    force_exclude: Optional[str] = None,
    stdin_filename: Optional[str] = None,
) -> None:
    """Assert that ``black.get_sources`` collects exactly ``expected`` from ``src``.

    Pattern arguments are regex strings compiled with ``compile_pattern``.
    ``None`` leaves a pattern unset, except ``include``, which then falls back
    to ``DEFAULT_INCLUDE``. The comparison ignores ordering. ``ctx`` defaults
    to a fresh ``FakeContext``.
    """

    def _compiled(pattern: Optional[str]) -> Any:
        # Compile a pattern string, passing ``None`` through untouched.
        return None if pattern is None else compile_pattern(pattern)

    sources = black.get_sources(
        ctx=ctx or FakeContext(),
        src=tuple(str(Path(entry)) for entry in src),
        quiet=False,
        verbose=False,
        include=DEFAULT_INCLUDE if include is None else _compiled(include),
        exclude=_compiled(exclude),
        extend_exclude=_compiled(extend_exclude),
        force_exclude=_compiled(force_exclude),
        report=black.Report(),
        stdin_filename=stdin_filename,
    )
    wanted = [Path(entry) for entry in expected]
    assert sorted(sources) == sorted(wanted)
1829
1830
class TestFileCollection:
    """Tests for source discovery via ``get_sources`` and ``gen_python_files``.

    Covers include/exclude/extend-exclude/force-exclude regexes, root and
    nested .gitignore handling, paths resolving outside the root, and the
    stdin (``-``) pseudo-source.
    """

    def test_include_exclude(self) -> None:
        """Only ``include`` matches that avoid the ``exclude`` regex are kept."""
        path = THIS_DIR / "data" / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        assert_collected_sources(
            src,
            expected,
            include=r"\.pyi?$",
            exclude=r"/exclude/|/\.definitely_exclude/",
        )

    def test_gitignore_used_as_default(self) -> None:
        """Without an explicit ``exclude``, .gitignore rules are applied.

        ``ctx.obj["root"]`` is set to the test data directory, presumably so
        ``get_sources`` looks for the .gitignore there -- confirm.
        """
        base = DATA_DIR / "include_exclude_tests"
        expected = [
            base / "b/.definitely_exclude/a.py",
            base / "b/.definitely_exclude/a.pyi",
        ]
        src = [base / "b/"]
        ctx = FakeContext()
        ctx.obj["root"] = base
        assert_collected_sources(src, expected, ctx=ctx, extend_exclude=r"/exclude/")

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_exclude_for_issue_1572(self) -> None:
        """An explicitly named file is collected even if ``exclude`` matches it."""
        # Exclude shouldn't touch files that were explicitly given to Black through the
        # CLI. Exclude is supposed to only apply to the recursive discovery of files.
        # https://github.com/psf/black/issues/1572
        path = DATA_DIR / "include_exclude_tests"
        src = [path / "b/exclude/a.py"]
        expected = [path / "b/exclude/a.py"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    def test_gitignore_exclude(self) -> None:
        """``gen_python_files`` skips paths matched by the given gitignore spec."""
        path = THIS_DIR / "data" / "include_exclude_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        report = black.Report()
        gitignore = PathSpec.from_lines(
            "gitwildmatch", ["exclude/", ".definitely_exclude"]
        )
        sources: List[Path] = []
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        this_abs = THIS_DIR.resolve()
        sources.extend(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                gitignore,
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_nested_gitignore(self) -> None:
        """Collection honours .gitignore files found in subdirectories as well
        as the root one."""
        path = THIS_DIR / "data" / "nested_gitignore_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        root_gitignore = black.files.get_gitignore(path)
        report = black.Report()
        expected: List[Path] = [
            path / "x.py",
            path / "root/b.py",
            path / "root/c.py",
            path / "root/child/c.py",
        ]
        this_abs = THIS_DIR.resolve()
        sources = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                root_gitignore,
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_invalid_gitignore(self) -> None:
        """An unparsable root .gitignore exits 1 and names the file on stderr."""
        path = THIS_DIR / "data" / "invalid_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_invalid_nested_gitignore(self) -> None:
        """An unparsable nested .gitignore exits 1 and names the file on stderr."""
        path = THIS_DIR / "data" / "invalid_nested_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / "a" / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_empty_include(self) -> None:
        """An empty ``include`` regex matches every file, not just Python ones."""
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.pie",
            path / "b/exclude/a.py",
            path / "b/exclude/a.pyi",
            path / "b/dont_exclude/a.pie",
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
            path / "b/.definitely_exclude/a.pie",
            path / "b/.definitely_exclude/a.py",
            path / "b/.definitely_exclude/a.pyi",
            path / ".gitignore",
            path / "pyproject.toml",
        ]
        # Setting exclude explicitly to an empty string to block .gitignore usage.
        assert_collected_sources(src, expected, include="", exclude="")

    def test_extend_exclude(self) -> None:
        """``extend_exclude`` removes matches on top of ``exclude``."""
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.py",
            path / "b/dont_exclude/a.py",
        ]
        assert_collected_sources(
            src, expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude"
        )

    @pytest.mark.incompatible_with_mypyc
    def test_symlink_out_of_root_directory(self) -> None:
        """A symlink resolving outside the root is tolerated (no ValueError),
        while a non-symlink resolving outside the root raises ValueError."""
        path = MagicMock()
        root = THIS_DIR.resolve()
        child = MagicMock()
        include = re.compile(black.DEFAULT_INCLUDES)
        exclude = re.compile(black.DEFAULT_EXCLUDES)
        report = black.Report()
        gitignore = PathSpec.from_lines("gitwildmatch", [])
        # `child` should behave like a symlink which resolved path is clearly
        # outside of the `root` directory.
        path.iterdir.return_value = [child]
        child.resolve.return_value = Path("/a/b/c")
        child.as_posix.return_value = "/a/b/c"
        child.is_symlink.return_value = True
        try:
            list(
                black.gen_python_files(
                    path.iterdir(),
                    root,
                    include,
                    exclude,
                    None,
                    None,
                    report,
                    gitignore,
                    verbose=False,
                    quiet=False,
                )
            )
        except ValueError as ve:
            pytest.fail(f"`gen_python_files()` failed: {ve}")
        path.iterdir.assert_called_once()
        child.resolve.assert_called_once()
        child.is_symlink.assert_called_once()
        # `child` should behave like a strange file which resolved path is clearly
        # outside of the `root` directory.
        child.is_symlink.return_value = False
        with pytest.raises(ValueError):
            list(
                black.gen_python_files(
                    path.iterdir(),
                    root,
                    include,
                    exclude,
                    None,
                    None,
                    report,
                    gitignore,
                    verbose=False,
                    quiet=False,
                )
            )
        path.iterdir.assert_called()
        assert path.iterdir.call_count == 2
        child.resolve.assert_called()
        assert child.resolve.call_count == 2
        child.is_symlink.assert_called()
        assert child.is_symlink.call_count == 2

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin(self) -> None:
        """``-`` is passed through as a source regardless of include/exclude."""
        src = ["-"]
        expected = ["-"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename(self) -> None:
        """``-`` with ``stdin_filename`` yields the prefixed stdin pseudo-path."""
        src = ["-"]
        stdin_filename = str(THIS_DIR / "data/collections.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
        """``exclude`` does not drop an explicitly given ``stdin_filename``."""
        # Exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        path = DATA_DIR / "include_exclude_tests"
        src = ["-"]
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
        """``extend_exclude`` does not drop an explicitly given ``stdin_filename``."""
        # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        src = ["-"]
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            extend_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
        """``force_exclude`` does drop a matching ``stdin_filename``."""
        # Force exclude should exclude the file when passing it through
        # stdin_filename
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        assert_collected_sources(
            src=["-"],
            expected=[],
            force_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )
2102
2103
try:
    with open(black.__file__, "r", encoding="utf-8") as _bf:
        black_source_lines: List[str] = _bf.readlines()
except UnicodeDecodeError:
    # Tolerated only for compiled builds, where ``black.__file__`` presumably
    # points at a binary extension module rather than Python source.
    if not black.COMPILED:
        raise
    # Keep the name bound so ``tracefunc`` degrades to printing no source
    # text instead of failing with NameError.
    black_source_lines = []


def tracefunc(
    frame: types.FrameType, event: str, arg: Any
) -> Callable[[types.FrameType, str, Any], Any]:
    """Show function calls `from black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    For each "call" event whose code lives in ``black/__init__.py``, print
    ``<indent><lineno>:<signature>``. The indent grows with the current stack
    depth. ``<signature>`` is the first non-decorator source line at or after
    the call's line. Events from other files are ignored. Always returns
    itself, so it is also installed as the local trace function of each frame.
    """
    if event != "call":
        return tracefunc

    filename = frame.f_code.co_filename
    # Filter first: ``black_source_lines`` only mirrors black/__init__.py, so
    # indexing it with line numbers from other files could raise IndexError.
    # Any exception raised inside a trace function makes Python unset it.
    # Separators are normalized so the check also matches Windows paths.
    if "black/__init__.py" not in filename.replace("\\", "/"):
        return tracefunc

    # NOTE(review): 19 looks like an empirical allowance for the test-runner
    # frames below the code of interest; adjust if the indentation looks off.
    stack = len(inspect.stack()) - 19
    stack *= 2
    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1
    funcname = ""
    # Skip decorator lines so the ``def`` line is shown; stop at the end of
    # the loaded source instead of raising IndexError.
    while 0 <= func_sig_lineno < len(black_source_lines):
        funcname = black_source_lines[func_sig_lineno].strip()
        if not funcname.startswith("@"):
            break
        func_sig_lineno += 1
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc