git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Make SRC or code mandatory and mutually exclusive (#2360) (#2804)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import inspect
5 import io
6 import logging
7 import multiprocessing
8 import os
9 import sys
10 import types
11 import unittest
12 from concurrent.futures import ThreadPoolExecutor
13 from contextlib import contextmanager
14 from dataclasses import replace
15 from io import BytesIO
16 from pathlib import Path
17 from platform import system
18 from tempfile import TemporaryDirectory
19 from typing import (
20     Any,
21     Callable,
22     Dict,
23     Iterator,
24     List,
25     Optional,
26     Sequence,
27     TypeVar,
28     Union,
29 )
30 from unittest.mock import MagicMock, patch
31
32 import click
33 import pytest
34 import re
35 from click import unstyle
36 from click.testing import CliRunner
37 from pathspec import PathSpec
38
39 import black
40 import black.files
41 from black import Feature, TargetVersion
42 from black import re_compile_maybe_verbose as compile_pattern
43 from black.cache import get_cache_dir, get_cache_file
44 from black.debug import DebugVisitor
45 from black.output import color_diff, diff
46 from black.report import Report
47
48 # Import other test classes
49 from tests.util import (
50     DATA_DIR,
51     DEFAULT_MODE,
52     DETERMINISTIC_HEADER,
53     PROJECT_ROOT,
54     PY36_VERSIONS,
55     THIS_DIR,
56     BlackBaseTestCase,
57     assert_format,
58     change_directory,
59     dump_to_stderr,
60     ff,
61     fs,
62     read_data,
63 )
64
# Path of this test module.
THIS_FILE = Path(__file__)
# One ``--target-version=...`` flag per version in PY36_VERSIONS.
PY36_ARGS = ["--target-version=" + v.name.lower() for v in PY36_VERSIONS]
# Black's default exclude / include patterns, compiled once for reuse.
DEFAULT_EXCLUDE = compile_pattern(black.const.DEFAULT_EXCLUDES)
DEFAULT_INCLUDE = compile_pattern(black.const.DEFAULT_INCLUDES)
# Generic type variables (presumably used by helpers later in the file).
T = TypeVar("T")
R = TypeVar("R")

# Match the time output in a diff, but nothing else
DIFF_TIME = re.compile(r"\t[\d\-:+\. ]+")
74
75
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point Black's cache at a throwaway temporary directory.

    The yielded path is what ``black.cache.CACHE_DIR`` is patched to while the
    context is active. With ``exists=False`` it is a not-yet-created ``new``
    child of the temporary directory. Everything is removed on exit.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
84
85
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a brand-new asyncio event loop as current for the block.

    The loop is created by the active event-loop policy and is closed when the
    context exits, whether or not the body raised.
    """
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        fresh_loop.close()
96
97
class FakeContext(click.Context):
    """A fake click Context for when calling functions that need it.

    ``click.Context.__init__`` is deliberately not called; only the two
    attributes assigned below exist on an instance.
    """

    def __init__(self) -> None:
        # Dummy project root: most of the tests don't care about it.
        self.obj: Dict[str, Any] = {"root": PROJECT_ROOT}
        self.default_map: Dict[str, Any] = {}
105
106
class FakeParameter(click.Parameter):
    """Minimal click Parameter to hand to functions whose signature needs one."""

    def __init__(self) -> None:
        """Deliberately skip ``click.Parameter.__init__``; no state is set up."""
112
113
class BlackRunner(CliRunner):
    """Make sure STDOUT and STDERR are kept separate when testing Black via its CLI."""

    def __init__(self) -> None:
        # ``mix_stderr=False`` keeps ``stdout_bytes`` and ``stderr_bytes``
        # as separate captures on each result.
        # NOTE(review): newer Click releases removed ``mix_stderr``; confirm the
        # pinned Click version still accepts it.
        super(BlackRunner, self).__init__(mix_stderr=False)
119
120
def invokeBlack(
    args: List[str], exit_code: int = 0, ignore_config: bool = True
) -> None:
    """Run Black's CLI in-process and assert it exits with ``exit_code``.

    With ``ignore_config`` set, ``--verbose`` and ``--config <THIS_DIR>/empty.toml``
    are prepended to ``args`` (presumably an empty config, so ambient project
    settings do not apply). Exceptions are not caught by the runner. On an
    exit-code mismatch, the assertion message shows the args, the captured
    stdout/stderr, and ``result.exception``.
    """
    if ignore_config:
        args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
    result = BlackRunner().invoke(black.main, args, catch_exceptions=False)
    out_bytes, err_bytes = result.stdout_bytes, result.stderr_bytes
    assert out_bytes is not None
    assert err_bytes is not None
    details = "\n".join(
        [
            f"Failed with args: {args}",
            f"stdout: {out_bytes.decode()!r}",
            f"stderr: {err_bytes.decode()!r}",
            f"exception: {result.exception}",
        ]
    )
    assert result.exit_code == exit_code, details
137
138
139 class BlackTestCase(BlackBaseTestCase):
140     invokeBlack = staticmethod(invokeBlack)
141
142     def test_empty_ff(self) -> None:
143         expected = ""
144         tmp_file = Path(black.dump_to_file())
145         try:
146             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
147             with open(tmp_file, encoding="utf8") as f:
148                 actual = f.read()
149         finally:
150             os.unlink(tmp_file)
151         self.assertFormatEqual(expected, actual)
152
153     def test_experimental_string_processing_warns(self) -> None:
154         self.assertWarns(
155             black.mode.Deprecated, black.Mode, experimental_string_processing=True
156         )
157
158     def test_piping(self) -> None:
159         source, expected = read_data("src/black/__init__", data=False)
160         result = BlackRunner().invoke(
161             black.main,
162             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
163             input=BytesIO(source.encode("utf8")),
164         )
165         self.assertEqual(result.exit_code, 0)
166         self.assertFormatEqual(expected, result.output)
167         if source != result.output:
168             black.assert_equivalent(source, result.output)
169             black.assert_stable(source, result.output, DEFAULT_MODE)
170
171     def test_piping_diff(self) -> None:
172         diff_header = re.compile(
173             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
174             r"\+\d\d\d\d"
175         )
176         source, _ = read_data("expression.py")
177         expected, _ = read_data("expression.diff")
178         config = THIS_DIR / "data" / "empty_pyproject.toml"
179         args = [
180             "-",
181             "--fast",
182             f"--line-length={black.DEFAULT_LINE_LENGTH}",
183             "--diff",
184             f"--config={config}",
185         ]
186         result = BlackRunner().invoke(
187             black.main, args, input=BytesIO(source.encode("utf8"))
188         )
189         self.assertEqual(result.exit_code, 0)
190         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
191         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
192         self.assertEqual(expected, actual)
193
194     def test_piping_diff_with_color(self) -> None:
195         source, _ = read_data("expression.py")
196         config = THIS_DIR / "data" / "empty_pyproject.toml"
197         args = [
198             "-",
199             "--fast",
200             f"--line-length={black.DEFAULT_LINE_LENGTH}",
201             "--diff",
202             "--color",
203             f"--config={config}",
204         ]
205         result = BlackRunner().invoke(
206             black.main, args, input=BytesIO(source.encode("utf8"))
207         )
208         actual = result.output
209         # Again, the contents are checked in a different test, so only look for colors.
210         self.assertIn("\033[1m", actual)
211         self.assertIn("\033[36m", actual)
212         self.assertIn("\033[32m", actual)
213         self.assertIn("\033[31m", actual)
214         self.assertIn("\033[0m", actual)
215
216     @patch("black.dump_to_file", dump_to_stderr)
217     def _test_wip(self) -> None:
218         source, expected = read_data("wip")
219         sys.settrace(tracefunc)
220         mode = replace(
221             DEFAULT_MODE,
222             experimental_string_processing=False,
223             target_versions={black.TargetVersion.PY38},
224         )
225         actual = fs(source, mode=mode)
226         sys.settrace(None)
227         self.assertFormatEqual(expected, actual)
228         black.assert_equivalent(source, actual)
229         black.assert_stable(source, actual, black.FileMode())
230
231     @unittest.expectedFailure
232     @patch("black.dump_to_file", dump_to_stderr)
233     def test_trailing_comma_optional_parens_stability1(self) -> None:
234         source, _expected = read_data("trailing_comma_optional_parens1")
235         actual = fs(source)
236         black.assert_stable(source, actual, DEFAULT_MODE)
237
238     @unittest.expectedFailure
239     @patch("black.dump_to_file", dump_to_stderr)
240     def test_trailing_comma_optional_parens_stability2(self) -> None:
241         source, _expected = read_data("trailing_comma_optional_parens2")
242         actual = fs(source)
243         black.assert_stable(source, actual, DEFAULT_MODE)
244
245     @unittest.expectedFailure
246     @patch("black.dump_to_file", dump_to_stderr)
247     def test_trailing_comma_optional_parens_stability3(self) -> None:
248         source, _expected = read_data("trailing_comma_optional_parens3")
249         actual = fs(source)
250         black.assert_stable(source, actual, DEFAULT_MODE)
251
252     @patch("black.dump_to_file", dump_to_stderr)
253     def test_trailing_comma_optional_parens_stability1_pass2(self) -> None:
254         source, _expected = read_data("trailing_comma_optional_parens1")
255         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
256         black.assert_stable(source, actual, DEFAULT_MODE)
257
258     @patch("black.dump_to_file", dump_to_stderr)
259     def test_trailing_comma_optional_parens_stability2_pass2(self) -> None:
260         source, _expected = read_data("trailing_comma_optional_parens2")
261         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
262         black.assert_stable(source, actual, DEFAULT_MODE)
263
264     @patch("black.dump_to_file", dump_to_stderr)
265     def test_trailing_comma_optional_parens_stability3_pass2(self) -> None:
266         source, _expected = read_data("trailing_comma_optional_parens3")
267         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
268         black.assert_stable(source, actual, DEFAULT_MODE)
269
270     def test_pep_572_version_detection(self) -> None:
271         source, _ = read_data("pep_572")
272         root = black.lib2to3_parse(source)
273         features = black.get_features_used(root)
274         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
275         versions = black.detect_target_versions(root)
276         self.assertIn(black.TargetVersion.PY38, versions)
277
278     def test_expression_ff(self) -> None:
279         source, expected = read_data("expression")
280         tmp_file = Path(black.dump_to_file(source))
281         try:
282             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
283             with open(tmp_file, encoding="utf8") as f:
284                 actual = f.read()
285         finally:
286             os.unlink(tmp_file)
287         self.assertFormatEqual(expected, actual)
288         with patch("black.dump_to_file", dump_to_stderr):
289             black.assert_equivalent(source, actual)
290             black.assert_stable(source, actual, DEFAULT_MODE)
291
292     def test_expression_diff(self) -> None:
293         source, _ = read_data("expression.py")
294         config = THIS_DIR / "data" / "empty_pyproject.toml"
295         expected, _ = read_data("expression.diff")
296         tmp_file = Path(black.dump_to_file(source))
297         diff_header = re.compile(
298             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
299             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
300         )
301         try:
302             result = BlackRunner().invoke(
303                 black.main, ["--diff", str(tmp_file), f"--config={config}"]
304             )
305             self.assertEqual(result.exit_code, 0)
306         finally:
307             os.unlink(tmp_file)
308         actual = result.output
309         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
310         if expected != actual:
311             dump = black.dump_to_file(actual)
312             msg = (
313                 "Expected diff isn't equal to the actual. If you made changes to"
314                 " expression.py and this is an anticipated difference, overwrite"
315                 f" tests/data/expression.diff with {dump}"
316             )
317             self.assertEqual(expected, actual, msg)
318
319     def test_expression_diff_with_color(self) -> None:
320         source, _ = read_data("expression.py")
321         config = THIS_DIR / "data" / "empty_pyproject.toml"
322         expected, _ = read_data("expression.diff")
323         tmp_file = Path(black.dump_to_file(source))
324         try:
325             result = BlackRunner().invoke(
326                 black.main, ["--diff", "--color", str(tmp_file), f"--config={config}"]
327             )
328         finally:
329             os.unlink(tmp_file)
330         actual = result.output
331         # We check the contents of the diff in `test_expression_diff`. All
332         # we need to check here is that color codes exist in the result.
333         self.assertIn("\033[1m", actual)
334         self.assertIn("\033[36m", actual)
335         self.assertIn("\033[32m", actual)
336         self.assertIn("\033[31m", actual)
337         self.assertIn("\033[0m", actual)
338
339     def test_detect_pos_only_arguments(self) -> None:
340         source, _ = read_data("pep_570")
341         root = black.lib2to3_parse(source)
342         features = black.get_features_used(root)
343         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
344         versions = black.detect_target_versions(root)
345         self.assertIn(black.TargetVersion.PY38, versions)
346
347     @patch("black.dump_to_file", dump_to_stderr)
348     def test_string_quotes(self) -> None:
349         source, expected = read_data("string_quotes")
350         mode = black.Mode(preview=True)
351         assert_format(source, expected, mode)
352         mode = replace(mode, string_normalization=False)
353         not_normalized = fs(source, mode=mode)
354         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
355         black.assert_equivalent(source, not_normalized)
356         black.assert_stable(source, not_normalized, mode=mode)
357
358     def test_skip_magic_trailing_comma(self) -> None:
359         source, _ = read_data("expression.py")
360         expected, _ = read_data("expression_skip_magic_trailing_comma.diff")
361         tmp_file = Path(black.dump_to_file(source))
362         diff_header = re.compile(
363             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
364             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
365         )
366         try:
367             result = BlackRunner().invoke(black.main, ["-C", "--diff", str(tmp_file)])
368             self.assertEqual(result.exit_code, 0)
369         finally:
370             os.unlink(tmp_file)
371         actual = result.output
372         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
373         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
374         if expected != actual:
375             dump = black.dump_to_file(actual)
376             msg = (
377                 "Expected diff isn't equal to the actual. If you made changes to"
378                 " expression.py and this is an anticipated difference, overwrite"
379                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
380             )
381             self.assertEqual(expected, actual, msg)
382
383     @patch("black.dump_to_file", dump_to_stderr)
384     def test_async_as_identifier(self) -> None:
385         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
386         source, expected = read_data("async_as_identifier")
387         actual = fs(source)
388         self.assertFormatEqual(expected, actual)
389         major, minor = sys.version_info[:2]
390         if major < 3 or (major <= 3 and minor < 7):
391             black.assert_equivalent(source, actual)
392         black.assert_stable(source, actual, DEFAULT_MODE)
393         # ensure black can parse this when the target is 3.6
394         self.invokeBlack([str(source_path), "--target-version", "py36"])
395         # but not on 3.7, because async/await is no longer an identifier
396         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
397
398     @patch("black.dump_to_file", dump_to_stderr)
399     def test_python37(self) -> None:
400         source_path = (THIS_DIR / "data" / "python37.py").resolve()
401         source, expected = read_data("python37")
402         actual = fs(source)
403         self.assertFormatEqual(expected, actual)
404         major, minor = sys.version_info[:2]
405         if major > 3 or (major == 3 and minor >= 7):
406             black.assert_equivalent(source, actual)
407         black.assert_stable(source, actual, DEFAULT_MODE)
408         # ensure black can parse this when the target is 3.7
409         self.invokeBlack([str(source_path), "--target-version", "py37"])
410         # but not on 3.6, because we use async as a reserved keyword
411         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
412
413     def test_tab_comment_indentation(self) -> None:
414         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
415         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
416         self.assertFormatEqual(contents_spc, fs(contents_spc))
417         self.assertFormatEqual(contents_spc, fs(contents_tab))
418
419         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
420         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
421         self.assertFormatEqual(contents_spc, fs(contents_spc))
422         self.assertFormatEqual(contents_spc, fs(contents_tab))
423
424         # mixed tabs and spaces (valid Python 2 code)
425         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
426         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
427         self.assertFormatEqual(contents_spc, fs(contents_spc))
428         self.assertFormatEqual(contents_spc, fs(contents_tab))
429
430         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
431         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
432         self.assertFormatEqual(contents_spc, fs(contents_spc))
433         self.assertFormatEqual(contents_spc, fs(contents_tab))
434
435     def test_report_verbose(self) -> None:
436         report = Report(verbose=True)
437         out_lines = []
438         err_lines = []
439
440         def out(msg: str, **kwargs: Any) -> None:
441             out_lines.append(msg)
442
443         def err(msg: str, **kwargs: Any) -> None:
444             err_lines.append(msg)
445
446         with patch("black.output._out", out), patch("black.output._err", err):
447             report.done(Path("f1"), black.Changed.NO)
448             self.assertEqual(len(out_lines), 1)
449             self.assertEqual(len(err_lines), 0)
450             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
451             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
452             self.assertEqual(report.return_code, 0)
453             report.done(Path("f2"), black.Changed.YES)
454             self.assertEqual(len(out_lines), 2)
455             self.assertEqual(len(err_lines), 0)
456             self.assertEqual(out_lines[-1], "reformatted f2")
457             self.assertEqual(
458                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
459             )
460             report.done(Path("f3"), black.Changed.CACHED)
461             self.assertEqual(len(out_lines), 3)
462             self.assertEqual(len(err_lines), 0)
463             self.assertEqual(
464                 out_lines[-1], "f3 wasn't modified on disk since last run."
465             )
466             self.assertEqual(
467                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
468             )
469             self.assertEqual(report.return_code, 0)
470             report.check = True
471             self.assertEqual(report.return_code, 1)
472             report.check = False
473             report.failed(Path("e1"), "boom")
474             self.assertEqual(len(out_lines), 3)
475             self.assertEqual(len(err_lines), 1)
476             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
477             self.assertEqual(
478                 unstyle(str(report)),
479                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
480                 " reformat.",
481             )
482             self.assertEqual(report.return_code, 123)
483             report.done(Path("f3"), black.Changed.YES)
484             self.assertEqual(len(out_lines), 4)
485             self.assertEqual(len(err_lines), 1)
486             self.assertEqual(out_lines[-1], "reformatted f3")
487             self.assertEqual(
488                 unstyle(str(report)),
489                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
490                 " reformat.",
491             )
492             self.assertEqual(report.return_code, 123)
493             report.failed(Path("e2"), "boom")
494             self.assertEqual(len(out_lines), 4)
495             self.assertEqual(len(err_lines), 2)
496             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
497             self.assertEqual(
498                 unstyle(str(report)),
499                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
500                 " reformat.",
501             )
502             self.assertEqual(report.return_code, 123)
503             report.path_ignored(Path("wat"), "no match")
504             self.assertEqual(len(out_lines), 5)
505             self.assertEqual(len(err_lines), 2)
506             self.assertEqual(out_lines[-1], "wat ignored: no match")
507             self.assertEqual(
508                 unstyle(str(report)),
509                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
510                 " reformat.",
511             )
512             self.assertEqual(report.return_code, 123)
513             report.done(Path("f4"), black.Changed.NO)
514             self.assertEqual(len(out_lines), 6)
515             self.assertEqual(len(err_lines), 2)
516             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
517             self.assertEqual(
518                 unstyle(str(report)),
519                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
520                 " reformat.",
521             )
522             self.assertEqual(report.return_code, 123)
523             report.check = True
524             self.assertEqual(
525                 unstyle(str(report)),
526                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
527                 " would fail to reformat.",
528             )
529             report.check = False
530             report.diff = True
531             self.assertEqual(
532                 unstyle(str(report)),
533                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
534                 " would fail to reformat.",
535             )
536
537     def test_report_quiet(self) -> None:
538         report = Report(quiet=True)
539         out_lines = []
540         err_lines = []
541
542         def out(msg: str, **kwargs: Any) -> None:
543             out_lines.append(msg)
544
545         def err(msg: str, **kwargs: Any) -> None:
546             err_lines.append(msg)
547
548         with patch("black.output._out", out), patch("black.output._err", err):
549             report.done(Path("f1"), black.Changed.NO)
550             self.assertEqual(len(out_lines), 0)
551             self.assertEqual(len(err_lines), 0)
552             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
553             self.assertEqual(report.return_code, 0)
554             report.done(Path("f2"), black.Changed.YES)
555             self.assertEqual(len(out_lines), 0)
556             self.assertEqual(len(err_lines), 0)
557             self.assertEqual(
558                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
559             )
560             report.done(Path("f3"), black.Changed.CACHED)
561             self.assertEqual(len(out_lines), 0)
562             self.assertEqual(len(err_lines), 0)
563             self.assertEqual(
564                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
565             )
566             self.assertEqual(report.return_code, 0)
567             report.check = True
568             self.assertEqual(report.return_code, 1)
569             report.check = False
570             report.failed(Path("e1"), "boom")
571             self.assertEqual(len(out_lines), 0)
572             self.assertEqual(len(err_lines), 1)
573             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
574             self.assertEqual(
575                 unstyle(str(report)),
576                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
577                 " reformat.",
578             )
579             self.assertEqual(report.return_code, 123)
580             report.done(Path("f3"), black.Changed.YES)
581             self.assertEqual(len(out_lines), 0)
582             self.assertEqual(len(err_lines), 1)
583             self.assertEqual(
584                 unstyle(str(report)),
585                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
586                 " reformat.",
587             )
588             self.assertEqual(report.return_code, 123)
589             report.failed(Path("e2"), "boom")
590             self.assertEqual(len(out_lines), 0)
591             self.assertEqual(len(err_lines), 2)
592             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
593             self.assertEqual(
594                 unstyle(str(report)),
595                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
596                 " reformat.",
597             )
598             self.assertEqual(report.return_code, 123)
599             report.path_ignored(Path("wat"), "no match")
600             self.assertEqual(len(out_lines), 0)
601             self.assertEqual(len(err_lines), 2)
602             self.assertEqual(
603                 unstyle(str(report)),
604                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
605                 " reformat.",
606             )
607             self.assertEqual(report.return_code, 123)
608             report.done(Path("f4"), black.Changed.NO)
609             self.assertEqual(len(out_lines), 0)
610             self.assertEqual(len(err_lines), 2)
611             self.assertEqual(
612                 unstyle(str(report)),
613                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
614                 " reformat.",
615             )
616             self.assertEqual(report.return_code, 123)
617             report.check = True
618             self.assertEqual(
619                 unstyle(str(report)),
620                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
621                 " would fail to reformat.",
622             )
623             report.check = False
624             report.diff = True
625             self.assertEqual(
626                 unstyle(str(report)),
627                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
628                 " would fail to reformat.",
629             )
630
631     def test_report_normal(self) -> None:
632         report = black.Report()
633         out_lines = []
634         err_lines = []
635
636         def out(msg: str, **kwargs: Any) -> None:
637             out_lines.append(msg)
638
639         def err(msg: str, **kwargs: Any) -> None:
640             err_lines.append(msg)
641
642         with patch("black.output._out", out), patch("black.output._err", err):
643             report.done(Path("f1"), black.Changed.NO)
644             self.assertEqual(len(out_lines), 0)
645             self.assertEqual(len(err_lines), 0)
646             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
647             self.assertEqual(report.return_code, 0)
648             report.done(Path("f2"), black.Changed.YES)
649             self.assertEqual(len(out_lines), 1)
650             self.assertEqual(len(err_lines), 0)
651             self.assertEqual(out_lines[-1], "reformatted f2")
652             self.assertEqual(
653                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
654             )
655             report.done(Path("f3"), black.Changed.CACHED)
656             self.assertEqual(len(out_lines), 1)
657             self.assertEqual(len(err_lines), 0)
658             self.assertEqual(out_lines[-1], "reformatted f2")
659             self.assertEqual(
660                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
661             )
662             self.assertEqual(report.return_code, 0)
663             report.check = True
664             self.assertEqual(report.return_code, 1)
665             report.check = False
666             report.failed(Path("e1"), "boom")
667             self.assertEqual(len(out_lines), 1)
668             self.assertEqual(len(err_lines), 1)
669             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
670             self.assertEqual(
671                 unstyle(str(report)),
672                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
673                 " reformat.",
674             )
675             self.assertEqual(report.return_code, 123)
676             report.done(Path("f3"), black.Changed.YES)
677             self.assertEqual(len(out_lines), 2)
678             self.assertEqual(len(err_lines), 1)
679             self.assertEqual(out_lines[-1], "reformatted f3")
680             self.assertEqual(
681                 unstyle(str(report)),
682                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
683                 " reformat.",
684             )
685             self.assertEqual(report.return_code, 123)
686             report.failed(Path("e2"), "boom")
687             self.assertEqual(len(out_lines), 2)
688             self.assertEqual(len(err_lines), 2)
689             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
690             self.assertEqual(
691                 unstyle(str(report)),
692                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
693                 " reformat.",
694             )
695             self.assertEqual(report.return_code, 123)
696             report.path_ignored(Path("wat"), "no match")
697             self.assertEqual(len(out_lines), 2)
698             self.assertEqual(len(err_lines), 2)
699             self.assertEqual(
700                 unstyle(str(report)),
701                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
702                 " reformat.",
703             )
704             self.assertEqual(report.return_code, 123)
705             report.done(Path("f4"), black.Changed.NO)
706             self.assertEqual(len(out_lines), 2)
707             self.assertEqual(len(err_lines), 2)
708             self.assertEqual(
709                 unstyle(str(report)),
710                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
711                 " reformat.",
712             )
713             self.assertEqual(report.return_code, 123)
714             report.check = True
715             self.assertEqual(
716                 unstyle(str(report)),
717                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
718                 " would fail to reformat.",
719             )
720             report.check = False
721             report.diff = True
722             self.assertEqual(
723                 unstyle(str(report)),
724                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
725                 " would fail to reformat.",
726             )
727
728     def test_lib2to3_parse(self) -> None:
729         with self.assertRaises(black.InvalidInput):
730             black.lib2to3_parse("invalid syntax")
731
732         straddling = "x + y"
733         black.lib2to3_parse(straddling)
734         black.lib2to3_parse(straddling, {TargetVersion.PY36})
735
736         py2_only = "print x"
737         with self.assertRaises(black.InvalidInput):
738             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
739
740         py3_only = "exec(x, end=y)"
741         black.lib2to3_parse(py3_only)
742         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
743
744     def test_get_features_used_decorator(self) -> None:
745         # Test the feature detection of new decorator syntax
746         # since this makes some test cases of test_get_features_used()
747         # fails if it fails, this is tested first so that a useful case
748         # is identified
749         simples, relaxed = read_data("decorators")
750         # skip explanation comments at the top of the file
751         for simple_test in simples.split("##")[1:]:
752             node = black.lib2to3_parse(simple_test)
753             decorator = str(node.children[0].children[0]).strip()
754             self.assertNotIn(
755                 Feature.RELAXED_DECORATORS,
756                 black.get_features_used(node),
757                 msg=(
758                     f"decorator '{decorator}' follows python<=3.8 syntax"
759                     "but is detected as 3.9+"
760                     # f"The full node is\n{node!r}"
761                 ),
762             )
763         # skip the '# output' comment at the top of the output part
764         for relaxed_test in relaxed.split("##")[1:]:
765             node = black.lib2to3_parse(relaxed_test)
766             decorator = str(node.children[0].children[0]).strip()
767             self.assertIn(
768                 Feature.RELAXED_DECORATORS,
769                 black.get_features_used(node),
770                 msg=(
771                     f"decorator '{decorator}' uses python3.9+ syntax"
772                     "but is detected as python<=3.8"
773                     # f"The full node is\n{node!r}"
774                 ),
775             )
776
777     def test_get_features_used(self) -> None:
778         node = black.lib2to3_parse("def f(*, arg): ...\n")
779         self.assertEqual(black.get_features_used(node), set())
780         node = black.lib2to3_parse("def f(*, arg,): ...\n")
781         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
782         node = black.lib2to3_parse("f(*arg,)\n")
783         self.assertEqual(
784             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
785         )
786         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
787         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
788         node = black.lib2to3_parse("123_456\n")
789         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
790         node = black.lib2to3_parse("123456\n")
791         self.assertEqual(black.get_features_used(node), set())
792         source, expected = read_data("function")
793         node = black.lib2to3_parse(source)
794         expected_features = {
795             Feature.TRAILING_COMMA_IN_CALL,
796             Feature.TRAILING_COMMA_IN_DEF,
797             Feature.F_STRINGS,
798         }
799         self.assertEqual(black.get_features_used(node), expected_features)
800         node = black.lib2to3_parse(expected)
801         self.assertEqual(black.get_features_used(node), expected_features)
802         source, expected = read_data("expression")
803         node = black.lib2to3_parse(source)
804         self.assertEqual(black.get_features_used(node), set())
805         node = black.lib2to3_parse(expected)
806         self.assertEqual(black.get_features_used(node), set())
807         node = black.lib2to3_parse("lambda a, /, b: ...")
808         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
809         node = black.lib2to3_parse("def fn(a, /, b): ...")
810         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
811         node = black.lib2to3_parse("def fn(): yield a, b")
812         self.assertEqual(black.get_features_used(node), set())
813         node = black.lib2to3_parse("def fn(): return a, b")
814         self.assertEqual(black.get_features_used(node), set())
815         node = black.lib2to3_parse("def fn(): yield *b, c")
816         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
817         node = black.lib2to3_parse("def fn(): return a, *b, c")
818         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
819         node = black.lib2to3_parse("x = a, *b, c")
820         self.assertEqual(black.get_features_used(node), set())
821         node = black.lib2to3_parse("x: Any = regular")
822         self.assertEqual(black.get_features_used(node), set())
823         node = black.lib2to3_parse("x: Any = (regular, regular)")
824         self.assertEqual(black.get_features_used(node), set())
825         node = black.lib2to3_parse("x: Any = Complex(Type(1))[something]")
826         self.assertEqual(black.get_features_used(node), set())
827         node = black.lib2to3_parse("x: Tuple[int, ...] = a, b, c")
828         self.assertEqual(
829             black.get_features_used(node), {Feature.ANN_ASSIGN_EXTENDED_RHS}
830         )
831
832     def test_get_features_used_for_future_flags(self) -> None:
833         for src, features in [
834             ("from __future__ import annotations", {Feature.FUTURE_ANNOTATIONS}),
835             (
836                 "from __future__ import (other, annotations)",
837                 {Feature.FUTURE_ANNOTATIONS},
838             ),
839             ("a = 1 + 2\nfrom something import annotations", set()),
840             ("from __future__ import x, y", set()),
841         ]:
842             with self.subTest(src=src, features=features):
843                 node = black.lib2to3_parse(src)
844                 future_imports = black.get_future_imports(node)
845                 self.assertEqual(
846                     black.get_features_used(node, future_imports=future_imports),
847                     features,
848                 )
849
850     def test_get_future_imports(self) -> None:
851         node = black.lib2to3_parse("\n")
852         self.assertEqual(set(), black.get_future_imports(node))
853         node = black.lib2to3_parse("from __future__ import black\n")
854         self.assertEqual({"black"}, black.get_future_imports(node))
855         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
856         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
857         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
858         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
859         node = black.lib2to3_parse(
860             "from __future__ import multiple\nfrom __future__ import imports\n"
861         )
862         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
863         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
864         self.assertEqual({"black"}, black.get_future_imports(node))
865         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
866         self.assertEqual({"black"}, black.get_future_imports(node))
867         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
868         self.assertEqual(set(), black.get_future_imports(node))
869         node = black.lib2to3_parse("from some.module import black\n")
870         self.assertEqual(set(), black.get_future_imports(node))
871         node = black.lib2to3_parse(
872             "from __future__ import unicode_literals as _unicode_literals"
873         )
874         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
875         node = black.lib2to3_parse(
876             "from __future__ import unicode_literals as _lol, print"
877         )
878         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
879
880     @pytest.mark.incompatible_with_mypyc
881     def test_debug_visitor(self) -> None:
882         source, _ = read_data("debug_visitor.py")
883         expected, _ = read_data("debug_visitor.out")
884         out_lines = []
885         err_lines = []
886
887         def out(msg: str, **kwargs: Any) -> None:
888             out_lines.append(msg)
889
890         def err(msg: str, **kwargs: Any) -> None:
891             err_lines.append(msg)
892
893         with patch("black.debug.out", out):
894             DebugVisitor.show(source)
895         actual = "\n".join(out_lines) + "\n"
896         log_name = ""
897         if expected != actual:
898             log_name = black.dump_to_file(*out_lines)
899         self.assertEqual(
900             expected,
901             actual,
902             f"AST print out is different. Actual version dumped to {log_name}",
903         )
904
905     def test_format_file_contents(self) -> None:
906         empty = ""
907         mode = DEFAULT_MODE
908         with self.assertRaises(black.NothingChanged):
909             black.format_file_contents(empty, mode=mode, fast=False)
910         just_nl = "\n"
911         with self.assertRaises(black.NothingChanged):
912             black.format_file_contents(just_nl, mode=mode, fast=False)
913         same = "j = [1, 2, 3]\n"
914         with self.assertRaises(black.NothingChanged):
915             black.format_file_contents(same, mode=mode, fast=False)
916         different = "j = [1,2,3]"
917         expected = same
918         actual = black.format_file_contents(different, mode=mode, fast=False)
919         self.assertEqual(expected, actual)
920         invalid = "return if you can"
921         with self.assertRaises(black.InvalidInput) as e:
922             black.format_file_contents(invalid, mode=mode, fast=False)
923         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
924
925     def test_endmarker(self) -> None:
926         n = black.lib2to3_parse("\n")
927         self.assertEqual(n.type, black.syms.file_input)
928         self.assertEqual(len(n.children), 1)
929         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
930
931     @pytest.mark.incompatible_with_mypyc
932     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
933     def test_assertFormatEqual(self) -> None:
934         out_lines = []
935         err_lines = []
936
937         def out(msg: str, **kwargs: Any) -> None:
938             out_lines.append(msg)
939
940         def err(msg: str, **kwargs: Any) -> None:
941             err_lines.append(msg)
942
943         with patch("black.output._out", out), patch("black.output._err", err):
944             with self.assertRaises(AssertionError):
945                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
946
947         out_str = "".join(out_lines)
948         self.assertIn("Expected tree:", out_str)
949         self.assertIn("Actual tree:", out_str)
950         self.assertEqual("".join(err_lines), "")
951
952     @event_loop()
953     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
954     def test_works_in_mono_process_only_environment(self) -> None:
955         with cache_dir() as workspace:
956             for f in [
957                 (workspace / "one.py").resolve(),
958                 (workspace / "two.py").resolve(),
959             ]:
960                 f.write_text('print("hello")\n')
961             self.invokeBlack([str(workspace)])
962
963     @event_loop()
964     def test_check_diff_use_together(self) -> None:
965         with cache_dir():
966             # Files which will be reformatted.
967             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
968             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
969             # Files which will not be reformatted.
970             src2 = (THIS_DIR / "data" / "composition.py").resolve()
971             self.invokeBlack([str(src2), "--diff", "--check"])
972             # Multi file command.
973             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
974
975     def test_no_src_fails(self) -> None:
976         with cache_dir():
977             self.invokeBlack([], exit_code=1)
978
979     def test_src_and_code_fails(self) -> None:
980         with cache_dir():
981             self.invokeBlack([".", "-c", "0"], exit_code=1)
982
983     def test_broken_symlink(self) -> None:
984         with cache_dir() as workspace:
985             symlink = workspace / "broken_link.py"
986             try:
987                 symlink.symlink_to("nonexistent.py")
988             except (OSError, NotImplementedError) as e:
989                 self.skipTest(f"Can't create symlinks: {e}")
990             self.invokeBlack([str(workspace.resolve())])
991
992     def test_single_file_force_pyi(self) -> None:
993         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
994         contents, expected = read_data("force_pyi")
995         with cache_dir() as workspace:
996             path = (workspace / "file.py").resolve()
997             with open(path, "w") as fh:
998                 fh.write(contents)
999             self.invokeBlack([str(path), "--pyi"])
1000             with open(path, "r") as fh:
1001                 actual = fh.read()
1002             # verify cache with --pyi is separate
1003             pyi_cache = black.read_cache(pyi_mode)
1004             self.assertIn(str(path), pyi_cache)
1005             normal_cache = black.read_cache(DEFAULT_MODE)
1006             self.assertNotIn(str(path), normal_cache)
1007         self.assertFormatEqual(expected, actual)
1008         black.assert_equivalent(contents, actual)
1009         black.assert_stable(contents, actual, pyi_mode)
1010
1011     @event_loop()
1012     def test_multi_file_force_pyi(self) -> None:
1013         reg_mode = DEFAULT_MODE
1014         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1015         contents, expected = read_data("force_pyi")
1016         with cache_dir() as workspace:
1017             paths = [
1018                 (workspace / "file1.py").resolve(),
1019                 (workspace / "file2.py").resolve(),
1020             ]
1021             for path in paths:
1022                 with open(path, "w") as fh:
1023                     fh.write(contents)
1024             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1025             for path in paths:
1026                 with open(path, "r") as fh:
1027                     actual = fh.read()
1028                 self.assertEqual(actual, expected)
1029             # verify cache with --pyi is separate
1030             pyi_cache = black.read_cache(pyi_mode)
1031             normal_cache = black.read_cache(reg_mode)
1032             for path in paths:
1033                 self.assertIn(str(path), pyi_cache)
1034                 self.assertNotIn(str(path), normal_cache)
1035
1036     def test_pipe_force_pyi(self) -> None:
1037         source, expected = read_data("force_pyi")
1038         result = CliRunner().invoke(
1039             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1040         )
1041         self.assertEqual(result.exit_code, 0)
1042         actual = result.output
1043         self.assertFormatEqual(actual, expected)
1044
1045     def test_single_file_force_py36(self) -> None:
1046         reg_mode = DEFAULT_MODE
1047         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1048         source, expected = read_data("force_py36")
1049         with cache_dir() as workspace:
1050             path = (workspace / "file.py").resolve()
1051             with open(path, "w") as fh:
1052                 fh.write(source)
1053             self.invokeBlack([str(path), *PY36_ARGS])
1054             with open(path, "r") as fh:
1055                 actual = fh.read()
1056             # verify cache with --target-version is separate
1057             py36_cache = black.read_cache(py36_mode)
1058             self.assertIn(str(path), py36_cache)
1059             normal_cache = black.read_cache(reg_mode)
1060             self.assertNotIn(str(path), normal_cache)
1061         self.assertEqual(actual, expected)
1062
1063     @event_loop()
1064     def test_multi_file_force_py36(self) -> None:
1065         reg_mode = DEFAULT_MODE
1066         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1067         source, expected = read_data("force_py36")
1068         with cache_dir() as workspace:
1069             paths = [
1070                 (workspace / "file1.py").resolve(),
1071                 (workspace / "file2.py").resolve(),
1072             ]
1073             for path in paths:
1074                 with open(path, "w") as fh:
1075                     fh.write(source)
1076             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1077             for path in paths:
1078                 with open(path, "r") as fh:
1079                     actual = fh.read()
1080                 self.assertEqual(actual, expected)
1081             # verify cache with --target-version is separate
1082             pyi_cache = black.read_cache(py36_mode)
1083             normal_cache = black.read_cache(reg_mode)
1084             for path in paths:
1085                 self.assertIn(str(path), pyi_cache)
1086                 self.assertNotIn(str(path), normal_cache)
1087
1088     def test_pipe_force_py36(self) -> None:
1089         source, expected = read_data("force_py36")
1090         result = CliRunner().invoke(
1091             black.main,
1092             ["-", "-q", "--target-version=py36"],
1093             input=BytesIO(source.encode("utf8")),
1094         )
1095         self.assertEqual(result.exit_code, 0)
1096         actual = result.output
1097         self.assertFormatEqual(actual, expected)
1098
1099     @pytest.mark.incompatible_with_mypyc
1100     def test_reformat_one_with_stdin(self) -> None:
1101         with patch(
1102             "black.format_stdin_to_stdout",
1103             return_value=lambda *args, **kwargs: black.Changed.YES,
1104         ) as fsts:
1105             report = MagicMock()
1106             path = Path("-")
1107             black.reformat_one(
1108                 path,
1109                 fast=True,
1110                 write_back=black.WriteBack.YES,
1111                 mode=DEFAULT_MODE,
1112                 report=report,
1113             )
1114             fsts.assert_called_once()
1115             report.done.assert_called_with(path, black.Changed.YES)
1116
1117     @pytest.mark.incompatible_with_mypyc
1118     def test_reformat_one_with_stdin_filename(self) -> None:
1119         with patch(
1120             "black.format_stdin_to_stdout",
1121             return_value=lambda *args, **kwargs: black.Changed.YES,
1122         ) as fsts:
1123             report = MagicMock()
1124             p = "foo.py"
1125             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1126             expected = Path(p)
1127             black.reformat_one(
1128                 path,
1129                 fast=True,
1130                 write_back=black.WriteBack.YES,
1131                 mode=DEFAULT_MODE,
1132                 report=report,
1133             )
1134             fsts.assert_called_once_with(
1135                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1136             )
1137             # __BLACK_STDIN_FILENAME__ should have been stripped
1138             report.done.assert_called_with(expected, black.Changed.YES)
1139
1140     @pytest.mark.incompatible_with_mypyc
1141     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1142         with patch(
1143             "black.format_stdin_to_stdout",
1144             return_value=lambda *args, **kwargs: black.Changed.YES,
1145         ) as fsts:
1146             report = MagicMock()
1147             p = "foo.pyi"
1148             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1149             expected = Path(p)
1150             black.reformat_one(
1151                 path,
1152                 fast=True,
1153                 write_back=black.WriteBack.YES,
1154                 mode=DEFAULT_MODE,
1155                 report=report,
1156             )
1157             fsts.assert_called_once_with(
1158                 fast=True,
1159                 write_back=black.WriteBack.YES,
1160                 mode=replace(DEFAULT_MODE, is_pyi=True),
1161             )
1162             # __BLACK_STDIN_FILENAME__ should have been stripped
1163             report.done.assert_called_with(expected, black.Changed.YES)
1164
1165     @pytest.mark.incompatible_with_mypyc
1166     def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
1167         with patch(
1168             "black.format_stdin_to_stdout",
1169             return_value=lambda *args, **kwargs: black.Changed.YES,
1170         ) as fsts:
1171             report = MagicMock()
1172             p = "foo.ipynb"
1173             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1174             expected = Path(p)
1175             black.reformat_one(
1176                 path,
1177                 fast=True,
1178                 write_back=black.WriteBack.YES,
1179                 mode=DEFAULT_MODE,
1180                 report=report,
1181             )
1182             fsts.assert_called_once_with(
1183                 fast=True,
1184                 write_back=black.WriteBack.YES,
1185                 mode=replace(DEFAULT_MODE, is_ipynb=True),
1186             )
1187             # __BLACK_STDIN_FILENAME__ should have been stripped
1188             report.done.assert_called_with(expected, black.Changed.YES)
1189
1190     @pytest.mark.incompatible_with_mypyc
1191     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1192         with patch(
1193             "black.format_stdin_to_stdout",
1194             return_value=lambda *args, **kwargs: black.Changed.YES,
1195         ) as fsts:
1196             report = MagicMock()
1197             # Even with an existing file, since we are forcing stdin, black
1198             # should output to stdout and not modify the file inplace
1199             p = Path(str(THIS_DIR / "data/collections.py"))
1200             # Make sure is_file actually returns True
1201             self.assertTrue(p.is_file())
1202             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1203             expected = Path(p)
1204             black.reformat_one(
1205                 path,
1206                 fast=True,
1207                 write_back=black.WriteBack.YES,
1208                 mode=DEFAULT_MODE,
1209                 report=report,
1210             )
1211             fsts.assert_called_once()
1212             # __BLACK_STDIN_FILENAME__ should have been stripped
1213             report.done.assert_called_with(expected, black.Changed.YES)
1214
1215     def test_reformat_one_with_stdin_empty(self) -> None:
1216         output = io.StringIO()
1217         with patch("io.TextIOWrapper", lambda *args, **kwargs: output):
1218             try:
1219                 black.format_stdin_to_stdout(
1220                     fast=True,
1221                     content="",
1222                     write_back=black.WriteBack.YES,
1223                     mode=DEFAULT_MODE,
1224                 )
1225             except io.UnsupportedOperation:
1226                 pass  # StringIO does not support detach
1227             assert output.getvalue() == ""
1228
1229     def test_invalid_cli_regex(self) -> None:
1230         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1231             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1232
1233     def test_required_version_matches_version(self) -> None:
1234         self.invokeBlack(
1235             ["--required-version", black.__version__, "-c", "0"],
1236             exit_code=0,
1237             ignore_config=True,
1238         )
1239
1240     def test_required_version_does_not_match_version(self) -> None:
1241         result = BlackRunner().invoke(
1242             black.main,
1243             ["--required-version", "20.99b", "-c", "0"],
1244         )
1245         self.assertEqual(result.exit_code, 1)
1246         self.assertIn("required version", result.stderr)
1247
1248     def test_preserves_line_endings(self) -> None:
1249         with TemporaryDirectory() as workspace:
1250             test_file = Path(workspace) / "test.py"
1251             for nl in ["\n", "\r\n"]:
1252                 contents = nl.join(["def f(  ):", "    pass"])
1253                 test_file.write_bytes(contents.encode())
1254                 ff(test_file, write_back=black.WriteBack.YES)
1255                 updated_contents: bytes = test_file.read_bytes()
1256                 self.assertIn(nl.encode(), updated_contents)
1257                 if nl == "\n":
1258                     self.assertNotIn(b"\r\n", updated_contents)
1259
1260     def test_preserves_line_endings_via_stdin(self) -> None:
1261         for nl in ["\n", "\r\n"]:
1262             contents = nl.join(["def f(  ):", "    pass"])
1263             runner = BlackRunner()
1264             result = runner.invoke(
1265                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1266             )
1267             self.assertEqual(result.exit_code, 0)
1268             output = result.stdout_bytes
1269             self.assertIn(nl.encode("utf8"), output)
1270             if nl == "\n":
1271                 self.assertNotIn(b"\r\n", output)
1272
1273     def test_assert_equivalent_different_asts(self) -> None:
1274         with self.assertRaises(AssertionError):
1275             black.assert_equivalent("{}", "None")
1276
1277     def test_shhh_click(self) -> None:
1278         try:
1279             from click import _unicodefun
1280         except ModuleNotFoundError:
1281             self.skipTest("Incompatible Click version")
1282         if not hasattr(_unicodefun, "_verify_python3_env"):
1283             self.skipTest("Incompatible Click version")
1284         # First, let's see if Click is crashing with a preferred ASCII charset.
1285         with patch("locale.getpreferredencoding") as gpe:
1286             gpe.return_value = "ASCII"
1287             with self.assertRaises(RuntimeError):
1288                 _unicodefun._verify_python3_env()  # type: ignore
1289         # Now, let's silence Click...
1290         black.patch_click()
1291         # ...and confirm it's silent.
1292         with patch("locale.getpreferredencoding") as gpe:
1293             gpe.return_value = "ASCII"
1294             try:
1295                 _unicodefun._verify_python3_env()  # type: ignore
1296             except RuntimeError as re:
1297                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1298
1299     def test_root_logger_not_used_directly(self) -> None:
1300         def fail(*args: Any, **kwargs: Any) -> None:
1301             self.fail("Record created with root logger")
1302
1303         with patch.multiple(
1304             logging.root,
1305             debug=fail,
1306             info=fail,
1307             warning=fail,
1308             error=fail,
1309             critical=fail,
1310             log=fail,
1311         ):
1312             ff(THIS_DIR / "util.py")
1313
1314     def test_invalid_config_return_code(self) -> None:
1315         tmp_file = Path(black.dump_to_file())
1316         try:
1317             tmp_config = Path(black.dump_to_file())
1318             tmp_config.unlink()
1319             args = ["--config", str(tmp_config), str(tmp_file)]
1320             self.invokeBlack(args, exit_code=2, ignore_config=False)
1321         finally:
1322             tmp_file.unlink()
1323
1324     def test_parse_pyproject_toml(self) -> None:
1325         test_toml_file = THIS_DIR / "test.toml"
1326         config = black.parse_pyproject_toml(str(test_toml_file))
1327         self.assertEqual(config["verbose"], 1)
1328         self.assertEqual(config["check"], "no")
1329         self.assertEqual(config["diff"], "y")
1330         self.assertEqual(config["color"], True)
1331         self.assertEqual(config["line_length"], 79)
1332         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1333         self.assertEqual(config["python_cell_magics"], ["custom1", "custom2"])
1334         self.assertEqual(config["exclude"], r"\.pyi?$")
1335         self.assertEqual(config["include"], r"\.py?$")
1336
1337     def test_read_pyproject_toml(self) -> None:
1338         test_toml_file = THIS_DIR / "test.toml"
1339         fake_ctx = FakeContext()
1340         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1341         config = fake_ctx.default_map
1342         self.assertEqual(config["verbose"], "1")
1343         self.assertEqual(config["check"], "no")
1344         self.assertEqual(config["diff"], "y")
1345         self.assertEqual(config["color"], "True")
1346         self.assertEqual(config["line_length"], "79")
1347         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1348         self.assertEqual(config["exclude"], r"\.pyi?$")
1349         self.assertEqual(config["include"], r"\.py?$")
1350
1351     @pytest.mark.incompatible_with_mypyc
1352     def test_find_project_root(self) -> None:
1353         with TemporaryDirectory() as workspace:
1354             root = Path(workspace)
1355             test_dir = root / "test"
1356             test_dir.mkdir()
1357
1358             src_dir = root / "src"
1359             src_dir.mkdir()
1360
1361             root_pyproject = root / "pyproject.toml"
1362             root_pyproject.touch()
1363             src_pyproject = src_dir / "pyproject.toml"
1364             src_pyproject.touch()
1365             src_python = src_dir / "foo.py"
1366             src_python.touch()
1367
1368             self.assertEqual(
1369                 black.find_project_root((src_dir, test_dir)),
1370                 (root.resolve(), "pyproject.toml"),
1371             )
1372             self.assertEqual(
1373                 black.find_project_root((src_dir,)),
1374                 (src_dir.resolve(), "pyproject.toml"),
1375             )
1376             self.assertEqual(
1377                 black.find_project_root((src_python,)),
1378                 (src_dir.resolve(), "pyproject.toml"),
1379             )
1380
1381     @patch(
1382         "black.files.find_user_pyproject_toml",
1383         black.files.find_user_pyproject_toml.__wrapped__,
1384     )
1385     def test_find_user_pyproject_toml_linux(self) -> None:
1386         if system() == "Windows":
1387             return
1388
1389         # Test if XDG_CONFIG_HOME is checked
1390         with TemporaryDirectory() as workspace:
1391             tmp_user_config = Path(workspace) / "black"
1392             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1393                 self.assertEqual(
1394                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
1395                 )
1396
1397         # Test fallback for XDG_CONFIG_HOME
1398         with patch.dict("os.environ"):
1399             os.environ.pop("XDG_CONFIG_HOME", None)
1400             fallback_user_config = Path("~/.config").expanduser() / "black"
1401             self.assertEqual(
1402                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
1403             )
1404
1405     def test_find_user_pyproject_toml_windows(self) -> None:
1406         if system() != "Windows":
1407             return
1408
1409         user_config_path = Path.home() / ".black"
1410         self.assertEqual(
1411             black.files.find_user_pyproject_toml(), user_config_path.resolve()
1412         )
1413
1414     def test_bpo_33660_workaround(self) -> None:
1415         if system() == "Windows":
1416             return
1417
1418         # https://bugs.python.org/issue33660
1419         root = Path("/")
1420         with change_directory(root):
1421             path = Path("workspace") / "project"
1422             report = black.Report(verbose=True)
1423             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
1424             self.assertEqual(normalized_path, "workspace/project")
1425
1426     def test_newline_comment_interaction(self) -> None:
1427         source = "class A:\\\r\n# type: ignore\n pass\n"
1428         output = black.format_str(source, mode=DEFAULT_MODE)
1429         black.assert_stable(source, output, mode=DEFAULT_MODE)
1430
1431     def test_bpo_2142_workaround(self) -> None:
1432
1433         # https://bugs.python.org/issue2142
1434
1435         source, _ = read_data("missing_final_newline.py")
1436         # read_data adds a trailing newline
1437         source = source.rstrip()
1438         expected, _ = read_data("missing_final_newline.diff")
1439         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
1440         diff_header = re.compile(
1441             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
1442             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
1443         )
1444         try:
1445             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
1446             self.assertEqual(result.exit_code, 0)
1447         finally:
1448             os.unlink(tmp_file)
1449         actual = result.output
1450         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
1451         self.assertEqual(actual, expected)
1452
1453     @staticmethod
1454     def compare_results(
1455         result: click.testing.Result, expected_value: str, expected_exit_code: int
1456     ) -> None:
1457         """Helper method to test the value and exit code of a click Result."""
1458         assert (
1459             result.output == expected_value
1460         ), "The output did not match the expected value."
1461         assert result.exit_code == expected_exit_code, "The exit code is incorrect."
1462
1463     def test_code_option(self) -> None:
1464         """Test the code option with no changes."""
1465         code = 'print("Hello world")\n'
1466         args = ["--code", code]
1467         result = CliRunner().invoke(black.main, args)
1468
1469         self.compare_results(result, code, 0)
1470
1471     def test_code_option_changed(self) -> None:
1472         """Test the code option when changes are required."""
1473         code = "print('hello world')"
1474         formatted = black.format_str(code, mode=DEFAULT_MODE)
1475
1476         args = ["--code", code]
1477         result = CliRunner().invoke(black.main, args)
1478
1479         self.compare_results(result, formatted, 0)
1480
1481     def test_code_option_check(self) -> None:
1482         """Test the code option when check is passed."""
1483         args = ["--check", "--code", 'print("Hello world")\n']
1484         result = CliRunner().invoke(black.main, args)
1485         self.compare_results(result, "", 0)
1486
1487     def test_code_option_check_changed(self) -> None:
1488         """Test the code option when changes are required, and check is passed."""
1489         args = ["--check", "--code", "print('hello world')"]
1490         result = CliRunner().invoke(black.main, args)
1491         self.compare_results(result, "", 1)
1492
1493     def test_code_option_diff(self) -> None:
1494         """Test the code option when diff is passed."""
1495         code = "print('hello world')"
1496         formatted = black.format_str(code, mode=DEFAULT_MODE)
1497         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1498
1499         args = ["--diff", "--code", code]
1500         result = CliRunner().invoke(black.main, args)
1501
1502         # Remove time from diff
1503         output = DIFF_TIME.sub("", result.output)
1504
1505         assert output == result_diff, "The output did not match the expected value."
1506         assert result.exit_code == 0, "The exit code is incorrect."
1507
1508     def test_code_option_color_diff(self) -> None:
1509         """Test the code option when color and diff are passed."""
1510         code = "print('hello world')"
1511         formatted = black.format_str(code, mode=DEFAULT_MODE)
1512
1513         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1514         result_diff = color_diff(result_diff)
1515
1516         args = ["--diff", "--color", "--code", code]
1517         result = CliRunner().invoke(black.main, args)
1518
1519         # Remove time from diff
1520         output = DIFF_TIME.sub("", result.output)
1521
1522         assert output == result_diff, "The output did not match the expected value."
1523         assert result.exit_code == 0, "The exit code is incorrect."
1524
1525     @pytest.mark.incompatible_with_mypyc
1526     def test_code_option_safe(self) -> None:
1527         """Test that the code option throws an error when the sanity checks fail."""
1528         # Patch black.assert_equivalent to ensure the sanity checks fail
1529         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1530             code = 'print("Hello world")'
1531             error_msg = f"{code}\nerror: cannot format <string>: \n"
1532
1533             args = ["--safe", "--code", code]
1534             result = CliRunner().invoke(black.main, args)
1535
1536             self.compare_results(result, error_msg, 123)
1537
1538     def test_code_option_fast(self) -> None:
1539         """Test that the code option ignores errors when the sanity checks fail."""
1540         # Patch black.assert_equivalent to ensure the sanity checks fail
1541         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1542             code = 'print("Hello world")'
1543             formatted = black.format_str(code, mode=DEFAULT_MODE)
1544
1545             args = ["--fast", "--code", code]
1546             result = CliRunner().invoke(black.main, args)
1547
1548             self.compare_results(result, formatted, 0)
1549
1550     @pytest.mark.incompatible_with_mypyc
1551     def test_code_option_config(self) -> None:
1552         """
1553         Test that the code option finds the pyproject.toml in the current directory.
1554         """
1555         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1556             args = ["--code", "print"]
1557             # This is the only directory known to contain a pyproject.toml
1558             with change_directory(PROJECT_ROOT):
1559                 CliRunner().invoke(black.main, args)
1560                 pyproject_path = Path(Path.cwd(), "pyproject.toml").resolve()
1561
1562             assert (
1563                 len(parse.mock_calls) >= 1
1564             ), "Expected config parse to be called with the current directory."
1565
1566             _, call_args, _ = parse.mock_calls[0]
1567             assert (
1568                 call_args[0].lower() == str(pyproject_path).lower()
1569             ), "Incorrect config loaded."
1570
1571     @pytest.mark.incompatible_with_mypyc
1572     def test_code_option_parent_config(self) -> None:
1573         """
1574         Test that the code option finds the pyproject.toml in the parent directory.
1575         """
1576         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1577             with change_directory(THIS_DIR):
1578                 args = ["--code", "print"]
1579                 CliRunner().invoke(black.main, args)
1580
1581                 pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
1582                 assert (
1583                     len(parse.mock_calls) >= 1
1584                 ), "Expected config parse to be called with the current directory."
1585
1586                 _, call_args, _ = parse.mock_calls[0]
1587                 assert (
1588                     call_args[0].lower() == str(pyproject_path).lower()
1589                 ), "Incorrect config loaded."
1590
1591     def test_for_handled_unexpected_eof_error(self) -> None:
1592         """
1593         Test that an unexpected EOF SyntaxError is nicely presented.
1594         """
1595         with pytest.raises(black.parsing.InvalidInput) as exc_info:
1596             black.lib2to3_parse("print(", {})
1597
1598         exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")
1599
1600     def test_equivalency_ast_parse_failure_includes_error(self) -> None:
1601         with pytest.raises(AssertionError) as err:
1602             black.assert_equivalent("a«»a  = 1", "a«»a  = 1")
1603
1604         err.match("--safe")
1605         # Unfortunately the SyntaxError message has changed in newer versions so we
1606         # can't match it directly.
1607         err.match("invalid character")
1608         err.match(r"\(<unknown>, line 1\)")
1609
1610
class TestCaching:
    """Tests for black's per-mode formatting cache (read, write, filtering)."""

    def test_get_cache_dir(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """BLACK_CACHE_DIR, when set, overrides the user cache directory."""
        # Create multiple cache directories
        workspace1 = tmp_path / "ws1"
        workspace1.mkdir()
        workspace2 = tmp_path / "ws2"
        workspace2.mkdir()

        # Force user_cache_dir to use the temporary directory for easier assertions
        patch_user_cache_dir = patch(
            target="black.cache.user_cache_dir",
            autospec=True,
            return_value=str(workspace1),
        )

        # If BLACK_CACHE_DIR is not set, use user_cache_dir
        monkeypatch.delenv("BLACK_CACHE_DIR", raising=False)
        with patch_user_cache_dir:
            assert get_cache_dir() == workspace1

        # If it is set, use the path provided in the env var.
        monkeypatch.setenv("BLACK_CACHE_DIR", str(workspace2))
        assert get_cache_dir() == workspace2

    def test_cache_broken_file(self) -> None:
        """An unreadable cache file reads as empty and is replaced on next run."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            cache_file = get_cache_file(mode)
            cache_file.write_text("this is not a pickle")
            assert black.read_cache(mode) == {}
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            invokeBlack([str(src)])
            cache = black.read_cache(mode)
            assert str(src) in cache

    def test_cache_single_file_already_cached(self) -> None:
        """A file already recorded in the cache is left untouched."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            black.write_cache({}, [src], mode)
            invokeBlack([str(src)])
            assert src.read_text() == "print('hello')"

    @event_loop()
    def test_cache_multiple_files(self) -> None:
        """Only uncached files are reformatted; afterwards both are cached."""
        mode = DEFAULT_MODE
        # NOTE(review): presumably threads are swapped in for processes so the
        # patched/temporary cache state is shared in-process.
        with cache_dir() as workspace, patch(
            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            one = (workspace / "one.py").resolve()
            one.write_text("print('hello')")
            two = (workspace / "two.py").resolve()
            two.write_text("print('hello')")
            black.write_cache({}, [one], mode)
            invokeBlack([str(workspace)])
            assert one.read_text() == "print('hello')"
            assert two.read_text() == 'print("hello")\n'
            cache = black.read_cache(mode)
            assert str(one) in cache
            assert str(two) in cache

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    def test_no_cache_when_writeback_diff(self, color: bool) -> None:
        """--diff mode neither reads nor writes the cache."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            with patch("black.read_cache") as read_cache, patch(
                "black.write_cache"
            ) as write_cache:
                cmd = [str(src), "--diff"]
                if color:
                    cmd.append("--color")
                invokeBlack(cmd)
                cache_file = get_cache_file(mode)
                assert not cache_file.exists()
                write_cache.assert_not_called()
                read_cache.assert_not_called()

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    @event_loop()
    def test_output_locking_when_writeback_diff(self, color: bool) -> None:
        """--diff over several files creates a multiprocessing Manager (for a lock)."""
        with cache_dir() as workspace:
            for tag in range(4):
                src = (workspace / f"test{tag}.py").resolve()
                src.write_text("print('hello')")
            with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
                cmd = ["--diff", str(workspace)]
                if color:
                    cmd.append("--color")
                invokeBlack(cmd, exit_code=0)
                # this isn't quite doing what we want, but if it _isn't_
                # called then we cannot be using the lock it provides
                mgr.assert_called()

    def test_no_cache_when_stdin(self) -> None:
        """Formatting from stdin does not create a cache file."""
        mode = DEFAULT_MODE
        with cache_dir():
            result = CliRunner().invoke(
                black.main, ["-"], input=BytesIO(b"print('hello')")
            )
            assert not result.exit_code
            cache_file = get_cache_file(mode)
            assert not cache_file.exists()

    def test_read_cache_no_cachefile(self) -> None:
        """Reading a cache that was never written yields an empty dict."""
        mode = DEFAULT_MODE
        with cache_dir():
            assert black.read_cache(mode) == {}

    def test_write_cache_read_cache(self) -> None:
        """A written entry round-trips with its get_cache_info value."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.touch()
            black.write_cache({}, [src], mode)
            cache = black.read_cache(mode)
            assert str(src) in cache
            assert cache[str(src)] == black.get_cache_info(src)

    def test_filter_cached(self) -> None:
        """filter_cached returns (todo, done), keeping only up-to-date entries done."""
        with TemporaryDirectory() as workspace:
            path = Path(workspace)
            uncached = (path / "uncached").resolve()
            cached = (path / "cached").resolve()
            cached_but_changed = (path / "changed").resolve()
            uncached.touch()
            cached.touch()
            cached_but_changed.touch()
            cache = {
                str(cached): black.get_cache_info(cached),
                # Stale info that cannot match the real file.
                str(cached_but_changed): (0.0, 0),
            }
            todo, done = black.filter_cached(
                cache, {uncached, cached, cached_but_changed}
            )
            assert todo == {uncached, cached_but_changed}
            assert done == {cached}

    def test_write_cache_creates_directory_if_needed(self) -> None:
        """write_cache creates the cache directory when it does not exist."""
        mode = DEFAULT_MODE
        with cache_dir(exists=False) as workspace:
            assert not workspace.exists()
            black.write_cache({}, [], mode)
            assert workspace.exists()

    @event_loop()
    def test_failed_formatting_does_not_get_cached(self) -> None:
        """Files that fail to format are not recorded in the cache."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace, patch(
            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            failing = (workspace / "failing.py").resolve()
            failing.write_text("not actually python")
            clean = (workspace / "clean.py").resolve()
            clean.write_text('print("hello")\n')
            invokeBlack([str(workspace)], exit_code=123)
            cache = black.read_cache(mode)
            assert str(failing) not in cache
            assert str(clean) in cache

    def test_write_cache_write_fail(self) -> None:
        """write_cache does not raise when Path.open is patched to raise OSError."""
        mode = DEFAULT_MODE
        with cache_dir(), patch.object(Path, "open") as mock:
            mock.side_effect = OSError
            black.write_cache({}, [], mode)

    def test_read_cache_line_lengths(self) -> None:
        """Caches are per mode: an entry for one line length is absent for another."""
        mode = DEFAULT_MODE
        short_mode = replace(DEFAULT_MODE, line_length=1)
        with cache_dir() as workspace:
            path = (workspace / "file.py").resolve()
            path.touch()
            black.write_cache({}, [path], mode)
            one = black.read_cache(mode)
            assert str(path) in one
            two = black.read_cache(short_mode)
            assert str(path) not in two
1803
1804
def assert_collected_sources(
    src: Sequence[Union[str, Path]],
    expected: Sequence[Union[str, Path]],
    *,
    ctx: Optional[FakeContext] = None,
    exclude: Optional[str] = None,
    include: Optional[str] = None,
    extend_exclude: Optional[str] = None,
    force_exclude: Optional[str] = None,
    stdin_filename: Optional[str] = None,
) -> None:
    """Assert that ``black.get_sources`` collects exactly ``expected`` from ``src``.

    The pattern arguments are regex strings compiled with ``compile_pattern``.
    ``None`` means "not given"; for ``include`` it falls back to
    ``DEFAULT_INCLUDE``. A falsy ``ctx`` is replaced with a fresh ``FakeContext``.
    Order is ignored: both sides are sorted before comparison.
    """
    sources = tuple(str(Path(item)) for item in src)
    wanted = [Path(item) for item in expected]
    exclude_re = compile_pattern(exclude) if exclude is not None else None
    include_re = compile_pattern(include) if include is not None else DEFAULT_INCLUDE
    extend_exclude_re = (
        compile_pattern(extend_exclude) if extend_exclude is not None else None
    )
    force_exclude_re = (
        compile_pattern(force_exclude) if force_exclude is not None else None
    )
    collected = black.get_sources(
        ctx=ctx or FakeContext(),
        src=sources,
        quiet=False,
        verbose=False,
        include=include_re,
        exclude=exclude_re,
        extend_exclude=extend_exclude_re,
        force_exclude=force_exclude_re,
        report=black.Report(),
        stdin_filename=stdin_filename,
    )
    assert sorted(collected) == sorted(wanted)
1837
1838
class TestFileCollection:
    """Tests for discovering which files black will format."""

    def test_include_exclude(self) -> None:
        """Only files matching include and not matching exclude are collected."""
        path = THIS_DIR / "data" / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        assert_collected_sources(
            src,
            expected,
            include=r"\.pyi?$",
            exclude=r"/exclude/|/\.definitely_exclude/",
        )

    def test_gitignore_used_as_default(self) -> None:
        """With no explicit exclude, the root's .gitignore filters files."""
        base = DATA_DIR / "include_exclude_tests"
        expected = [
            base / "b/.definitely_exclude/a.py",
            base / "b/.definitely_exclude/a.pyi",
        ]
        src = [base / "b/"]
        ctx = FakeContext()
        ctx.obj["root"] = base
        assert_collected_sources(src, expected, ctx=ctx, extend_exclude=r"/exclude/")

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_exclude_for_issue_1572(self) -> None:
        """Explicitly passed files are not subject to --exclude."""
        # Exclude shouldn't touch files that were explicitly given to Black through the
        # CLI. Exclude is supposed to only apply to the recursive discovery of files.
        # https://github.com/psf/black/issues/1572
        path = DATA_DIR / "include_exclude_tests"
        src = [path / "b/exclude/a.py"]
        expected = [path / "b/exclude/a.py"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    def test_gitignore_exclude(self) -> None:
        """gen_python_files skips paths matched by the given gitignore spec."""
        path = THIS_DIR / "data" / "include_exclude_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        report = black.Report()
        gitignore = PathSpec.from_lines(
            "gitwildmatch", ["exclude/", ".definitely_exclude"]
        )
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        this_abs = THIS_DIR.resolve()
        sources = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                gitignore,
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_nested_gitignore(self) -> None:
        """.gitignore files in subdirectories are honored during discovery."""
        path = THIS_DIR / "data" / "nested_gitignore_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        root_gitignore = black.files.get_gitignore(path)
        report = black.Report()
        expected: List[Path] = [
            path / "x.py",
            path / "root/b.py",
            path / "root/c.py",
            path / "root/child/c.py",
        ]
        this_abs = THIS_DIR.resolve()
        sources = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                root_gitignore,
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_invalid_gitignore(self) -> None:
        """An unparsable root .gitignore is reported and exits with code 1."""
        path = THIS_DIR / "data" / "invalid_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_invalid_nested_gitignore(self) -> None:
        """An unparsable nested .gitignore is reported and exits with code 1."""
        path = THIS_DIR / "data" / "invalid_nested_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / "a" / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_empty_include(self) -> None:
        """An empty include pattern (with empty exclude) collects every file."""
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.pie",
            path / "b/exclude/a.py",
            path / "b/exclude/a.pyi",
            path / "b/dont_exclude/a.pie",
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
            path / "b/.definitely_exclude/a.pie",
            path / "b/.definitely_exclude/a.py",
            path / "b/.definitely_exclude/a.pyi",
            path / ".gitignore",
            path / "pyproject.toml",
        ]
        # Setting exclude explicitly to an empty string to block .gitignore usage.
        assert_collected_sources(src, expected, include="", exclude="")

    def test_extend_exclude(self) -> None:
        """extend_exclude filters in addition to exclude."""
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.py",
            path / "b/dont_exclude/a.py",
        ]
        assert_collected_sources(
            src, expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude"
        )

    @pytest.mark.incompatible_with_mypyc
    def test_symlink_out_of_root_directory(self) -> None:
        """Out-of-root entries: symlinks are tolerated, other files raise ValueError."""
        path = MagicMock()
        root = THIS_DIR.resolve()
        child = MagicMock()
        include = re.compile(black.DEFAULT_INCLUDES)
        exclude = re.compile(black.DEFAULT_EXCLUDES)
        report = black.Report()
        gitignore = PathSpec.from_lines("gitwildmatch", [])
        # `child` should behave like a symlink which resolved path is clearly
        # outside of the `root` directory.
        path.iterdir.return_value = [child]
        child.resolve.return_value = Path("/a/b/c")
        child.as_posix.return_value = "/a/b/c"
        child.is_symlink.return_value = True
        try:
            list(
                black.gen_python_files(
                    path.iterdir(),
                    root,
                    include,
                    exclude,
                    None,
                    None,
                    report,
                    gitignore,
                    verbose=False,
                    quiet=False,
                )
            )
        except ValueError as ve:
            # Name the function actually under test in the failure message.
            pytest.fail(f"`gen_python_files()` failed: {ve}")
        path.iterdir.assert_called_once()
        child.resolve.assert_called_once()
        child.is_symlink.assert_called_once()
        # `child` should behave like a strange file which resolved path is clearly
        # outside of the `root` directory.
        child.is_symlink.return_value = False
        with pytest.raises(ValueError):
            list(
                black.gen_python_files(
                    path.iterdir(),
                    root,
                    include,
                    exclude,
                    None,
                    None,
                    report,
                    gitignore,
                    verbose=False,
                    quiet=False,
                )
            )
        path.iterdir.assert_called()
        assert path.iterdir.call_count == 2
        child.resolve.assert_called()
        assert child.resolve.call_count == 2
        child.is_symlink.assert_called()
        assert child.is_symlink.call_count == 2

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin(self) -> None:
        """A bare "-" source is collected as-is."""
        src = ["-"]
        expected = ["-"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename(self) -> None:
        """With --stdin-filename, "-" is collected under a prefixed marker name."""
        src = ["-"]
        stdin_filename = str(THIS_DIR / "data/collections.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
        """--exclude does not drop a file passed via --stdin-filename."""
        # Exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        path = DATA_DIR / "include_exclude_tests"
        src = ["-"]
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
        """--extend-exclude does not drop a file passed via --stdin-filename."""
        # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        src = ["-"]
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            extend_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
        """--force-exclude does drop a file passed via --stdin-filename."""
        # Force exclude should exclude the file when passing it through
        # stdin_filename
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        assert_collected_sources(
            src=["-"],
            expected=[],
            force_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )
2110
2111
# Source lines of black/__init__.py, used by tracefunc() to show which function
# is being entered. Stays empty when the source cannot be decoded as UTF-8 and
# black.COMPILED is set (NOTE(review): presumably black.__file__ is then a
# binary extension module rather than Python source).
black_source_lines: List[str] = []
try:
    with open(black.__file__, "r", encoding="utf-8") as _bf:
        black_source_lines = _bf.readlines()
except UnicodeDecodeError:
    if not black.COMPILED:
        raise
2118
2119
def tracefunc(
    frame: types.FrameType, event: str, arg: Any
) -> Callable[[types.FrameType, str, Any], Any]:
    """Show function calls from `black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    For "call" events whose code lives in ``black/__init__.py``, prints
    ``<indent><lineno>:<signature line>``. The signature is read from the
    module-level ``black_source_lines``, skipping decorator lines. Frames from
    any other file, and line numbers outside the loaded source, are ignored.
    Always returns itself so tracing continues into nested scopes.
    """
    if event != "call":
        return tracefunc

    filename = frame.f_code.co_filename
    # Normalize separators so the check also matches Windows paths.
    if "black/__init__.py" not in filename.replace(os.sep, "/"):
        # black_source_lines only describes black/__init__.py; line numbers from
        # other files would index it meaninglessly or out of range.
        return tracefunc

    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1  # f_lineno is 1-based, the list is 0-based.
    # Skip decorator lines to reach the actual `def` line, without running past
    # the end of the buffer (it is empty when no source could be read).
    while (
        0 <= func_sig_lineno < len(black_source_lines)
        and black_source_lines[func_sig_lineno].strip().startswith("@")
    ):
        func_sig_lineno += 1
    if not 0 <= func_sig_lineno < len(black_source_lines):
        return tracefunc
    funcname = black_source_lines[func_sig_lineno].strip()

    # Indent by call depth. NOTE(review): 19 is presumably the baseline number
    # of test-harness frames; confirm for your runner.
    stack = (len(inspect.stack()) - 19) * 2
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc