git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Set `click` lower bound to `8.0.0` (#2791)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import inspect
5 import io
6 import logging
7 import multiprocessing
8 import os
9 import sys
10 import types
11 import unittest
12 from concurrent.futures import ThreadPoolExecutor
13 from contextlib import contextmanager
14 from dataclasses import replace
15 from io import BytesIO
16 from pathlib import Path
17 from platform import system
18 from tempfile import TemporaryDirectory
19 from typing import (
20     Any,
21     Callable,
22     Dict,
23     Iterator,
24     List,
25     Optional,
26     Sequence,
27     TypeVar,
28     Union,
29 )
30 from unittest.mock import MagicMock, patch
31
32 import click
33 import pytest
34 import re
35 from click import unstyle
36 from click.testing import CliRunner
37 from pathspec import PathSpec
38
39 import black
40 import black.files
41 from black import Feature, TargetVersion
42 from black import re_compile_maybe_verbose as compile_pattern
43 from black.cache import get_cache_file
44 from black.debug import DebugVisitor
45 from black.output import color_diff, diff
46 from black.report import Report
47
48 # Import other test classes
49 from tests.util import (
50     DATA_DIR,
51     DEFAULT_MODE,
52     DETERMINISTIC_HEADER,
53     PROJECT_ROOT,
54     PY36_VERSIONS,
55     THIS_DIR,
56     BlackBaseTestCase,
57     assert_format,
58     change_directory,
59     dump_to_stderr,
60     ff,
61     fs,
62     read_data,
63 )
64
# Path to this test module.
THIS_FILE = Path(__file__)
# One ``--target-version=<name>`` CLI flag for each entry of PY36_VERSIONS.
PY36_ARGS = [f"--target-version={version.name.lower()}" for version in PY36_VERSIONS]
# Black's default --exclude / --include patterns, compiled once for reuse.
DEFAULT_EXCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_EXCLUDES)
DEFAULT_INCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_INCLUDES)
# Generic type variables; presumably used by typed helpers further down this
# module (not visible here).
T = TypeVar("T")
R = TypeVar("R")

# Match the time output in a diff (a tab followed by timestamp characters),
# but nothing else
DIFF_TIME = re.compile(r"\t[\d\-:+\. ]+")
74
75
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point ``black.cache.CACHE_DIR`` at a throwaway directory.

    Yields the patched cache path.  With ``exists=False`` the yielded path is a
    ``new`` child of a fresh temporary workspace and has not been created yet.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        location = root if exists else root / "new"
        with patch("black.cache.CACHE_DIR", location):
            yield location
84
85
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a brand-new asyncio event loop as the current loop.

    The loop comes from the active event-loop policy and is closed when the
    ``with`` block exits, whether normally or via an exception.
    """
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        fresh_loop.close()
96
97
class FakeContext(click.Context):
    """Minimal stand-in for ``click.Context`` for functions that require one.

    ``click.Context.__init__`` is intentionally bypassed; only the two
    attributes below are populated.
    """

    def __init__(self) -> None:
        # Placeholder root object; most tests never inspect it.
        self.obj: Dict[str, Any] = {"root": PROJECT_ROOT}
        # No configuration-provided option defaults.
        self.default_map: Dict[str, Any] = {}
105
106
class FakeParameter(click.Parameter):
    """Minimal stand-in for ``click.Parameter`` for functions that require one."""

    def __init__(self) -> None:
        """Skip ``click.Parameter.__init__`` on purpose; no state is needed."""
112
113
class BlackRunner(CliRunner):
    """Make sure STDOUT and STDERR are kept separate when testing Black via its CLI."""

    def __init__(self) -> None:
        # Click < 8.2 merges stderr into stdout unless ``mix_stderr=False`` is
        # passed.  Click 8.2 removed that parameter (stdout and stderr are
        # always captured separately there), so passing it unconditionally
        # raises TypeError.  Only forward it when the installed Click accepts it.
        runner_kwargs: Dict[str, Any] = {}
        if "mix_stderr" in inspect.signature(CliRunner.__init__).parameters:
            runner_kwargs["mix_stderr"] = False
        super().__init__(**runner_kwargs)
119
120
def invokeBlack(
    args: List[str], exit_code: int = 0, ignore_config: bool = True
) -> None:
    """Run Black's CLI with *args* and assert the resulting exit status.

    :param args: command-line arguments handed to ``black.main``.
    :param exit_code: the exit status the run is expected to produce.
    :param ignore_config: when true, ``--verbose`` and an explicit
        ``--config`` pointing at ``empty.toml`` are prepended so the run does
        not depend on other configuration files (assumes ``empty.toml`` holds
        no settings).
    """
    if ignore_config:
        args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
    runner = BlackRunner()
    result = runner.invoke(black.main, args, catch_exceptions=False)
    stdout_raw = result.stdout_bytes
    stderr_raw = result.stderr_bytes
    assert stdout_raw is not None
    assert stderr_raw is not None
    # Build the diagnostic eagerly so a failing assertion shows everything.
    failure_msg = (
        f"Failed with args: {args}\n"
        f"stdout: {stdout_raw.decode()!r}\n"
        f"stderr: {stderr_raw.decode()!r}\n"
        f"exception: {result.exception}"
    )
    assert result.exit_code == exit_code, failure_msg
137
138
139 class BlackTestCase(BlackBaseTestCase):
140     invokeBlack = staticmethod(invokeBlack)
141
142     def test_empty_ff(self) -> None:
143         expected = ""
144         tmp_file = Path(black.dump_to_file())
145         try:
146             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
147             with open(tmp_file, encoding="utf8") as f:
148                 actual = f.read()
149         finally:
150             os.unlink(tmp_file)
151         self.assertFormatEqual(expected, actual)
152
153     def test_experimental_string_processing_warns(self) -> None:
154         self.assertWarns(
155             black.mode.Deprecated, black.Mode, experimental_string_processing=True
156         )
157
158     def test_piping(self) -> None:
159         source, expected = read_data("src/black/__init__", data=False)
160         result = BlackRunner().invoke(
161             black.main,
162             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
163             input=BytesIO(source.encode("utf8")),
164         )
165         self.assertEqual(result.exit_code, 0)
166         self.assertFormatEqual(expected, result.output)
167         if source != result.output:
168             black.assert_equivalent(source, result.output)
169             black.assert_stable(source, result.output, DEFAULT_MODE)
170
171     def test_piping_diff(self) -> None:
172         diff_header = re.compile(
173             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
174             r"\+\d\d\d\d"
175         )
176         source, _ = read_data("expression.py")
177         expected, _ = read_data("expression.diff")
178         config = THIS_DIR / "data" / "empty_pyproject.toml"
179         args = [
180             "-",
181             "--fast",
182             f"--line-length={black.DEFAULT_LINE_LENGTH}",
183             "--diff",
184             f"--config={config}",
185         ]
186         result = BlackRunner().invoke(
187             black.main, args, input=BytesIO(source.encode("utf8"))
188         )
189         self.assertEqual(result.exit_code, 0)
190         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
191         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
192         self.assertEqual(expected, actual)
193
194     def test_piping_diff_with_color(self) -> None:
195         source, _ = read_data("expression.py")
196         config = THIS_DIR / "data" / "empty_pyproject.toml"
197         args = [
198             "-",
199             "--fast",
200             f"--line-length={black.DEFAULT_LINE_LENGTH}",
201             "--diff",
202             "--color",
203             f"--config={config}",
204         ]
205         result = BlackRunner().invoke(
206             black.main, args, input=BytesIO(source.encode("utf8"))
207         )
208         actual = result.output
209         # Again, the contents are checked in a different test, so only look for colors.
210         self.assertIn("\033[1m", actual)
211         self.assertIn("\033[36m", actual)
212         self.assertIn("\033[32m", actual)
213         self.assertIn("\033[31m", actual)
214         self.assertIn("\033[0m", actual)
215
216     @patch("black.dump_to_file", dump_to_stderr)
217     def _test_wip(self) -> None:
218         source, expected = read_data("wip")
219         sys.settrace(tracefunc)
220         mode = replace(
221             DEFAULT_MODE,
222             experimental_string_processing=False,
223             target_versions={black.TargetVersion.PY38},
224         )
225         actual = fs(source, mode=mode)
226         sys.settrace(None)
227         self.assertFormatEqual(expected, actual)
228         black.assert_equivalent(source, actual)
229         black.assert_stable(source, actual, black.FileMode())
230
231     @unittest.expectedFailure
232     @patch("black.dump_to_file", dump_to_stderr)
233     def test_trailing_comma_optional_parens_stability1(self) -> None:
234         source, _expected = read_data("trailing_comma_optional_parens1")
235         actual = fs(source)
236         black.assert_stable(source, actual, DEFAULT_MODE)
237
238     @unittest.expectedFailure
239     @patch("black.dump_to_file", dump_to_stderr)
240     def test_trailing_comma_optional_parens_stability2(self) -> None:
241         source, _expected = read_data("trailing_comma_optional_parens2")
242         actual = fs(source)
243         black.assert_stable(source, actual, DEFAULT_MODE)
244
245     @unittest.expectedFailure
246     @patch("black.dump_to_file", dump_to_stderr)
247     def test_trailing_comma_optional_parens_stability3(self) -> None:
248         source, _expected = read_data("trailing_comma_optional_parens3")
249         actual = fs(source)
250         black.assert_stable(source, actual, DEFAULT_MODE)
251
252     @patch("black.dump_to_file", dump_to_stderr)
253     def test_trailing_comma_optional_parens_stability1_pass2(self) -> None:
254         source, _expected = read_data("trailing_comma_optional_parens1")
255         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
256         black.assert_stable(source, actual, DEFAULT_MODE)
257
258     @patch("black.dump_to_file", dump_to_stderr)
259     def test_trailing_comma_optional_parens_stability2_pass2(self) -> None:
260         source, _expected = read_data("trailing_comma_optional_parens2")
261         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
262         black.assert_stable(source, actual, DEFAULT_MODE)
263
264     @patch("black.dump_to_file", dump_to_stderr)
265     def test_trailing_comma_optional_parens_stability3_pass2(self) -> None:
266         source, _expected = read_data("trailing_comma_optional_parens3")
267         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
268         black.assert_stable(source, actual, DEFAULT_MODE)
269
270     def test_pep_572_version_detection(self) -> None:
271         source, _ = read_data("pep_572")
272         root = black.lib2to3_parse(source)
273         features = black.get_features_used(root)
274         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
275         versions = black.detect_target_versions(root)
276         self.assertIn(black.TargetVersion.PY38, versions)
277
278     def test_expression_ff(self) -> None:
279         source, expected = read_data("expression")
280         tmp_file = Path(black.dump_to_file(source))
281         try:
282             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
283             with open(tmp_file, encoding="utf8") as f:
284                 actual = f.read()
285         finally:
286             os.unlink(tmp_file)
287         self.assertFormatEqual(expected, actual)
288         with patch("black.dump_to_file", dump_to_stderr):
289             black.assert_equivalent(source, actual)
290             black.assert_stable(source, actual, DEFAULT_MODE)
291
292     def test_expression_diff(self) -> None:
293         source, _ = read_data("expression.py")
294         config = THIS_DIR / "data" / "empty_pyproject.toml"
295         expected, _ = read_data("expression.diff")
296         tmp_file = Path(black.dump_to_file(source))
297         diff_header = re.compile(
298             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
299             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
300         )
301         try:
302             result = BlackRunner().invoke(
303                 black.main, ["--diff", str(tmp_file), f"--config={config}"]
304             )
305             self.assertEqual(result.exit_code, 0)
306         finally:
307             os.unlink(tmp_file)
308         actual = result.output
309         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
310         if expected != actual:
311             dump = black.dump_to_file(actual)
312             msg = (
313                 "Expected diff isn't equal to the actual. If you made changes to"
314                 " expression.py and this is an anticipated difference, overwrite"
315                 f" tests/data/expression.diff with {dump}"
316             )
317             self.assertEqual(expected, actual, msg)
318
319     def test_expression_diff_with_color(self) -> None:
320         source, _ = read_data("expression.py")
321         config = THIS_DIR / "data" / "empty_pyproject.toml"
322         expected, _ = read_data("expression.diff")
323         tmp_file = Path(black.dump_to_file(source))
324         try:
325             result = BlackRunner().invoke(
326                 black.main, ["--diff", "--color", str(tmp_file), f"--config={config}"]
327             )
328         finally:
329             os.unlink(tmp_file)
330         actual = result.output
331         # We check the contents of the diff in `test_expression_diff`. All
332         # we need to check here is that color codes exist in the result.
333         self.assertIn("\033[1m", actual)
334         self.assertIn("\033[36m", actual)
335         self.assertIn("\033[32m", actual)
336         self.assertIn("\033[31m", actual)
337         self.assertIn("\033[0m", actual)
338
339     def test_detect_pos_only_arguments(self) -> None:
340         source, _ = read_data("pep_570")
341         root = black.lib2to3_parse(source)
342         features = black.get_features_used(root)
343         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
344         versions = black.detect_target_versions(root)
345         self.assertIn(black.TargetVersion.PY38, versions)
346
347     @patch("black.dump_to_file", dump_to_stderr)
348     def test_string_quotes(self) -> None:
349         source, expected = read_data("string_quotes")
350         mode = black.Mode(preview=True)
351         assert_format(source, expected, mode)
352         mode = replace(mode, string_normalization=False)
353         not_normalized = fs(source, mode=mode)
354         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
355         black.assert_equivalent(source, not_normalized)
356         black.assert_stable(source, not_normalized, mode=mode)
357
358     def test_skip_magic_trailing_comma(self) -> None:
359         source, _ = read_data("expression.py")
360         expected, _ = read_data("expression_skip_magic_trailing_comma.diff")
361         tmp_file = Path(black.dump_to_file(source))
362         diff_header = re.compile(
363             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
364             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
365         )
366         try:
367             result = BlackRunner().invoke(black.main, ["-C", "--diff", str(tmp_file)])
368             self.assertEqual(result.exit_code, 0)
369         finally:
370             os.unlink(tmp_file)
371         actual = result.output
372         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
373         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
374         if expected != actual:
375             dump = black.dump_to_file(actual)
376             msg = (
377                 "Expected diff isn't equal to the actual. If you made changes to"
378                 " expression.py and this is an anticipated difference, overwrite"
379                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
380             )
381             self.assertEqual(expected, actual, msg)
382
383     @patch("black.dump_to_file", dump_to_stderr)
384     def test_async_as_identifier(self) -> None:
385         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
386         source, expected = read_data("async_as_identifier")
387         actual = fs(source)
388         self.assertFormatEqual(expected, actual)
389         major, minor = sys.version_info[:2]
390         if major < 3 or (major <= 3 and minor < 7):
391             black.assert_equivalent(source, actual)
392         black.assert_stable(source, actual, DEFAULT_MODE)
393         # ensure black can parse this when the target is 3.6
394         self.invokeBlack([str(source_path), "--target-version", "py36"])
395         # but not on 3.7, because async/await is no longer an identifier
396         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
397
398     @patch("black.dump_to_file", dump_to_stderr)
399     def test_python37(self) -> None:
400         source_path = (THIS_DIR / "data" / "python37.py").resolve()
401         source, expected = read_data("python37")
402         actual = fs(source)
403         self.assertFormatEqual(expected, actual)
404         major, minor = sys.version_info[:2]
405         if major > 3 or (major == 3 and minor >= 7):
406             black.assert_equivalent(source, actual)
407         black.assert_stable(source, actual, DEFAULT_MODE)
408         # ensure black can parse this when the target is 3.7
409         self.invokeBlack([str(source_path), "--target-version", "py37"])
410         # but not on 3.6, because we use async as a reserved keyword
411         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
412
413     def test_tab_comment_indentation(self) -> None:
414         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
415         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
416         self.assertFormatEqual(contents_spc, fs(contents_spc))
417         self.assertFormatEqual(contents_spc, fs(contents_tab))
418
419         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
420         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
421         self.assertFormatEqual(contents_spc, fs(contents_spc))
422         self.assertFormatEqual(contents_spc, fs(contents_tab))
423
424         # mixed tabs and spaces (valid Python 2 code)
425         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
426         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
427         self.assertFormatEqual(contents_spc, fs(contents_spc))
428         self.assertFormatEqual(contents_spc, fs(contents_tab))
429
430         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
431         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
432         self.assertFormatEqual(contents_spc, fs(contents_spc))
433         self.assertFormatEqual(contents_spc, fs(contents_tab))
434
    def test_report_verbose(self) -> None:
        """Verbose ``Report``: every outcome is echoed and the summary tracks counts.

        ``black.output._out`` / ``_err`` are patched with collectors so the test
        can count and inspect each message the report emits.
        """
        report = Report(verbose=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            # Unchanged file: echoed in verbose mode; exit code 0.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # Reformatted file.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file is counted as "left unchanged".
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check mode on, the pending reformat yields exit code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure is reported via _err; afterwards the exit code is 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are echoed in verbose mode but do not change the totals.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Check mode and diff mode both switch to the "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
536
    def test_report_quiet(self) -> None:
        """Quiet ``Report``: only failures are echoed; the summary still tracks counts.

        ``black.output._out`` / ``_err`` are patched with collectors so the test
        can count and inspect each message the report emits.
        """
        report = Report(quiet=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            # Successful outcomes produce no output at all in quiet mode.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file is counted as "left unchanged".
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check mode on, the pending reformat yields exit code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures are still reported via _err; afterwards the exit code is 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are silent and do not change the totals.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Check mode and diff mode both switch to the "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
630
    def test_report_normal(self) -> None:
        """Default ``Report``: only reformats and failures are echoed.

        ``black.output._out`` / ``_err`` are patched with collectors so the test
        can count and inspect each message the report emits.
        """
        report = black.Report()
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            # Unchanged files are not echoed at the default verbosity.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # Reformatted files are echoed.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # Cached files are silent (last message is still the f2 one) and
            # are counted as "left unchanged".
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check mode on, the pending reformat yields exit code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure is reported via _err; afterwards the exit code is 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are silent and do not change the totals.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Check mode and diff mode both switch to the "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
727
728     def test_lib2to3_parse(self) -> None:
729         with self.assertRaises(black.InvalidInput):
730             black.lib2to3_parse("invalid syntax")
731
732         straddling = "x + y"
733         black.lib2to3_parse(straddling)
734         black.lib2to3_parse(straddling, {TargetVersion.PY36})
735
736         py2_only = "print x"
737         with self.assertRaises(black.InvalidInput):
738             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
739
740         py3_only = "exec(x, end=y)"
741         black.lib2to3_parse(py3_only)
742         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
743
744     def test_get_features_used_decorator(self) -> None:
745         # Test the feature detection of new decorator syntax
746         # since this makes some test cases of test_get_features_used()
747         # fails if it fails, this is tested first so that a useful case
748         # is identified
749         simples, relaxed = read_data("decorators")
750         # skip explanation comments at the top of the file
751         for simple_test in simples.split("##")[1:]:
752             node = black.lib2to3_parse(simple_test)
753             decorator = str(node.children[0].children[0]).strip()
754             self.assertNotIn(
755                 Feature.RELAXED_DECORATORS,
756                 black.get_features_used(node),
757                 msg=(
758                     f"decorator '{decorator}' follows python<=3.8 syntax"
759                     "but is detected as 3.9+"
760                     # f"The full node is\n{node!r}"
761                 ),
762             )
763         # skip the '# output' comment at the top of the output part
764         for relaxed_test in relaxed.split("##")[1:]:
765             node = black.lib2to3_parse(relaxed_test)
766             decorator = str(node.children[0].children[0]).strip()
767             self.assertIn(
768                 Feature.RELAXED_DECORATORS,
769                 black.get_features_used(node),
770                 msg=(
771                     f"decorator '{decorator}' uses python3.9+ syntax"
772                     "but is detected as python<=3.8"
773                     # f"The full node is\n{node!r}"
774                 ),
775             )
776
777     def test_get_features_used(self) -> None:
778         node = black.lib2to3_parse("def f(*, arg): ...\n")
779         self.assertEqual(black.get_features_used(node), set())
780         node = black.lib2to3_parse("def f(*, arg,): ...\n")
781         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
782         node = black.lib2to3_parse("f(*arg,)\n")
783         self.assertEqual(
784             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
785         )
786         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
787         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
788         node = black.lib2to3_parse("123_456\n")
789         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
790         node = black.lib2to3_parse("123456\n")
791         self.assertEqual(black.get_features_used(node), set())
792         source, expected = read_data("function")
793         node = black.lib2to3_parse(source)
794         expected_features = {
795             Feature.TRAILING_COMMA_IN_CALL,
796             Feature.TRAILING_COMMA_IN_DEF,
797             Feature.F_STRINGS,
798         }
799         self.assertEqual(black.get_features_used(node), expected_features)
800         node = black.lib2to3_parse(expected)
801         self.assertEqual(black.get_features_used(node), expected_features)
802         source, expected = read_data("expression")
803         node = black.lib2to3_parse(source)
804         self.assertEqual(black.get_features_used(node), set())
805         node = black.lib2to3_parse(expected)
806         self.assertEqual(black.get_features_used(node), set())
807         node = black.lib2to3_parse("lambda a, /, b: ...")
808         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
809         node = black.lib2to3_parse("def fn(a, /, b): ...")
810         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
811         node = black.lib2to3_parse("def fn(): yield a, b")
812         self.assertEqual(black.get_features_used(node), set())
813         node = black.lib2to3_parse("def fn(): return a, b")
814         self.assertEqual(black.get_features_used(node), set())
815         node = black.lib2to3_parse("def fn(): yield *b, c")
816         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
817         node = black.lib2to3_parse("def fn(): return a, *b, c")
818         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
819         node = black.lib2to3_parse("x = a, *b, c")
820         self.assertEqual(black.get_features_used(node), set())
821         node = black.lib2to3_parse("x: Any = regular")
822         self.assertEqual(black.get_features_used(node), set())
823         node = black.lib2to3_parse("x: Any = (regular, regular)")
824         self.assertEqual(black.get_features_used(node), set())
825         node = black.lib2to3_parse("x: Any = Complex(Type(1))[something]")
826         self.assertEqual(black.get_features_used(node), set())
827         node = black.lib2to3_parse("x: Tuple[int, ...] = a, b, c")
828         self.assertEqual(
829             black.get_features_used(node), {Feature.ANN_ASSIGN_EXTENDED_RHS}
830         )
831
832     def test_get_features_used_for_future_flags(self) -> None:
833         for src, features in [
834             ("from __future__ import annotations", {Feature.FUTURE_ANNOTATIONS}),
835             (
836                 "from __future__ import (other, annotations)",
837                 {Feature.FUTURE_ANNOTATIONS},
838             ),
839             ("a = 1 + 2\nfrom something import annotations", set()),
840             ("from __future__ import x, y", set()),
841         ]:
842             with self.subTest(src=src, features=features):
843                 node = black.lib2to3_parse(src)
844                 future_imports = black.get_future_imports(node)
845                 self.assertEqual(
846                     black.get_features_used(node, future_imports=future_imports),
847                     features,
848                 )
849
850     def test_get_future_imports(self) -> None:
851         node = black.lib2to3_parse("\n")
852         self.assertEqual(set(), black.get_future_imports(node))
853         node = black.lib2to3_parse("from __future__ import black\n")
854         self.assertEqual({"black"}, black.get_future_imports(node))
855         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
856         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
857         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
858         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
859         node = black.lib2to3_parse(
860             "from __future__ import multiple\nfrom __future__ import imports\n"
861         )
862         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
863         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
864         self.assertEqual({"black"}, black.get_future_imports(node))
865         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
866         self.assertEqual({"black"}, black.get_future_imports(node))
867         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
868         self.assertEqual(set(), black.get_future_imports(node))
869         node = black.lib2to3_parse("from some.module import black\n")
870         self.assertEqual(set(), black.get_future_imports(node))
871         node = black.lib2to3_parse(
872             "from __future__ import unicode_literals as _unicode_literals"
873         )
874         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
875         node = black.lib2to3_parse(
876             "from __future__ import unicode_literals as _lol, print"
877         )
878         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
879
880     @pytest.mark.incompatible_with_mypyc
881     def test_debug_visitor(self) -> None:
882         source, _ = read_data("debug_visitor.py")
883         expected, _ = read_data("debug_visitor.out")
884         out_lines = []
885         err_lines = []
886
887         def out(msg: str, **kwargs: Any) -> None:
888             out_lines.append(msg)
889
890         def err(msg: str, **kwargs: Any) -> None:
891             err_lines.append(msg)
892
893         with patch("black.debug.out", out):
894             DebugVisitor.show(source)
895         actual = "\n".join(out_lines) + "\n"
896         log_name = ""
897         if expected != actual:
898             log_name = black.dump_to_file(*out_lines)
899         self.assertEqual(
900             expected,
901             actual,
902             f"AST print out is different. Actual version dumped to {log_name}",
903         )
904
905     def test_format_file_contents(self) -> None:
906         empty = ""
907         mode = DEFAULT_MODE
908         with self.assertRaises(black.NothingChanged):
909             black.format_file_contents(empty, mode=mode, fast=False)
910         just_nl = "\n"
911         with self.assertRaises(black.NothingChanged):
912             black.format_file_contents(just_nl, mode=mode, fast=False)
913         same = "j = [1, 2, 3]\n"
914         with self.assertRaises(black.NothingChanged):
915             black.format_file_contents(same, mode=mode, fast=False)
916         different = "j = [1,2,3]"
917         expected = same
918         actual = black.format_file_contents(different, mode=mode, fast=False)
919         self.assertEqual(expected, actual)
920         invalid = "return if you can"
921         with self.assertRaises(black.InvalidInput) as e:
922             black.format_file_contents(invalid, mode=mode, fast=False)
923         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
924
925     def test_endmarker(self) -> None:
926         n = black.lib2to3_parse("\n")
927         self.assertEqual(n.type, black.syms.file_input)
928         self.assertEqual(len(n.children), 1)
929         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
930
931     @pytest.mark.incompatible_with_mypyc
932     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
933     def test_assertFormatEqual(self) -> None:
934         out_lines = []
935         err_lines = []
936
937         def out(msg: str, **kwargs: Any) -> None:
938             out_lines.append(msg)
939
940         def err(msg: str, **kwargs: Any) -> None:
941             err_lines.append(msg)
942
943         with patch("black.output._out", out), patch("black.output._err", err):
944             with self.assertRaises(AssertionError):
945                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
946
947         out_str = "".join(out_lines)
948         self.assertIn("Expected tree:", out_str)
949         self.assertIn("Actual tree:", out_str)
950         self.assertEqual("".join(err_lines), "")
951
952     @event_loop()
953     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
954     def test_works_in_mono_process_only_environment(self) -> None:
955         with cache_dir() as workspace:
956             for f in [
957                 (workspace / "one.py").resolve(),
958                 (workspace / "two.py").resolve(),
959             ]:
960                 f.write_text('print("hello")\n')
961             self.invokeBlack([str(workspace)])
962
963     @event_loop()
964     def test_check_diff_use_together(self) -> None:
965         with cache_dir():
966             # Files which will be reformatted.
967             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
968             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
969             # Files which will not be reformatted.
970             src2 = (THIS_DIR / "data" / "composition.py").resolve()
971             self.invokeBlack([str(src2), "--diff", "--check"])
972             # Multi file command.
973             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
974
975     def test_no_files(self) -> None:
976         with cache_dir():
977             # Without an argument, black exits with error code 0.
978             self.invokeBlack([])
979
980     def test_broken_symlink(self) -> None:
981         with cache_dir() as workspace:
982             symlink = workspace / "broken_link.py"
983             try:
984                 symlink.symlink_to("nonexistent.py")
985             except (OSError, NotImplementedError) as e:
986                 self.skipTest(f"Can't create symlinks: {e}")
987             self.invokeBlack([str(workspace.resolve())])
988
989     def test_single_file_force_pyi(self) -> None:
990         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
991         contents, expected = read_data("force_pyi")
992         with cache_dir() as workspace:
993             path = (workspace / "file.py").resolve()
994             with open(path, "w") as fh:
995                 fh.write(contents)
996             self.invokeBlack([str(path), "--pyi"])
997             with open(path, "r") as fh:
998                 actual = fh.read()
999             # verify cache with --pyi is separate
1000             pyi_cache = black.read_cache(pyi_mode)
1001             self.assertIn(str(path), pyi_cache)
1002             normal_cache = black.read_cache(DEFAULT_MODE)
1003             self.assertNotIn(str(path), normal_cache)
1004         self.assertFormatEqual(expected, actual)
1005         black.assert_equivalent(contents, actual)
1006         black.assert_stable(contents, actual, pyi_mode)
1007
1008     @event_loop()
1009     def test_multi_file_force_pyi(self) -> None:
1010         reg_mode = DEFAULT_MODE
1011         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1012         contents, expected = read_data("force_pyi")
1013         with cache_dir() as workspace:
1014             paths = [
1015                 (workspace / "file1.py").resolve(),
1016                 (workspace / "file2.py").resolve(),
1017             ]
1018             for path in paths:
1019                 with open(path, "w") as fh:
1020                     fh.write(contents)
1021             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1022             for path in paths:
1023                 with open(path, "r") as fh:
1024                     actual = fh.read()
1025                 self.assertEqual(actual, expected)
1026             # verify cache with --pyi is separate
1027             pyi_cache = black.read_cache(pyi_mode)
1028             normal_cache = black.read_cache(reg_mode)
1029             for path in paths:
1030                 self.assertIn(str(path), pyi_cache)
1031                 self.assertNotIn(str(path), normal_cache)
1032
1033     def test_pipe_force_pyi(self) -> None:
1034         source, expected = read_data("force_pyi")
1035         result = CliRunner().invoke(
1036             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1037         )
1038         self.assertEqual(result.exit_code, 0)
1039         actual = result.output
1040         self.assertFormatEqual(actual, expected)
1041
1042     def test_single_file_force_py36(self) -> None:
1043         reg_mode = DEFAULT_MODE
1044         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1045         source, expected = read_data("force_py36")
1046         with cache_dir() as workspace:
1047             path = (workspace / "file.py").resolve()
1048             with open(path, "w") as fh:
1049                 fh.write(source)
1050             self.invokeBlack([str(path), *PY36_ARGS])
1051             with open(path, "r") as fh:
1052                 actual = fh.read()
1053             # verify cache with --target-version is separate
1054             py36_cache = black.read_cache(py36_mode)
1055             self.assertIn(str(path), py36_cache)
1056             normal_cache = black.read_cache(reg_mode)
1057             self.assertNotIn(str(path), normal_cache)
1058         self.assertEqual(actual, expected)
1059
1060     @event_loop()
1061     def test_multi_file_force_py36(self) -> None:
1062         reg_mode = DEFAULT_MODE
1063         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1064         source, expected = read_data("force_py36")
1065         with cache_dir() as workspace:
1066             paths = [
1067                 (workspace / "file1.py").resolve(),
1068                 (workspace / "file2.py").resolve(),
1069             ]
1070             for path in paths:
1071                 with open(path, "w") as fh:
1072                     fh.write(source)
1073             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1074             for path in paths:
1075                 with open(path, "r") as fh:
1076                     actual = fh.read()
1077                 self.assertEqual(actual, expected)
1078             # verify cache with --target-version is separate
1079             pyi_cache = black.read_cache(py36_mode)
1080             normal_cache = black.read_cache(reg_mode)
1081             for path in paths:
1082                 self.assertIn(str(path), pyi_cache)
1083                 self.assertNotIn(str(path), normal_cache)
1084
1085     def test_pipe_force_py36(self) -> None:
1086         source, expected = read_data("force_py36")
1087         result = CliRunner().invoke(
1088             black.main,
1089             ["-", "-q", "--target-version=py36"],
1090             input=BytesIO(source.encode("utf8")),
1091         )
1092         self.assertEqual(result.exit_code, 0)
1093         actual = result.output
1094         self.assertFormatEqual(actual, expected)
1095
1096     @pytest.mark.incompatible_with_mypyc
1097     def test_reformat_one_with_stdin(self) -> None:
1098         with patch(
1099             "black.format_stdin_to_stdout",
1100             return_value=lambda *args, **kwargs: black.Changed.YES,
1101         ) as fsts:
1102             report = MagicMock()
1103             path = Path("-")
1104             black.reformat_one(
1105                 path,
1106                 fast=True,
1107                 write_back=black.WriteBack.YES,
1108                 mode=DEFAULT_MODE,
1109                 report=report,
1110             )
1111             fsts.assert_called_once()
1112             report.done.assert_called_with(path, black.Changed.YES)
1113
1114     @pytest.mark.incompatible_with_mypyc
1115     def test_reformat_one_with_stdin_filename(self) -> None:
1116         with patch(
1117             "black.format_stdin_to_stdout",
1118             return_value=lambda *args, **kwargs: black.Changed.YES,
1119         ) as fsts:
1120             report = MagicMock()
1121             p = "foo.py"
1122             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1123             expected = Path(p)
1124             black.reformat_one(
1125                 path,
1126                 fast=True,
1127                 write_back=black.WriteBack.YES,
1128                 mode=DEFAULT_MODE,
1129                 report=report,
1130             )
1131             fsts.assert_called_once_with(
1132                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1133             )
1134             # __BLACK_STDIN_FILENAME__ should have been stripped
1135             report.done.assert_called_with(expected, black.Changed.YES)
1136
1137     @pytest.mark.incompatible_with_mypyc
1138     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1139         with patch(
1140             "black.format_stdin_to_stdout",
1141             return_value=lambda *args, **kwargs: black.Changed.YES,
1142         ) as fsts:
1143             report = MagicMock()
1144             p = "foo.pyi"
1145             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1146             expected = Path(p)
1147             black.reformat_one(
1148                 path,
1149                 fast=True,
1150                 write_back=black.WriteBack.YES,
1151                 mode=DEFAULT_MODE,
1152                 report=report,
1153             )
1154             fsts.assert_called_once_with(
1155                 fast=True,
1156                 write_back=black.WriteBack.YES,
1157                 mode=replace(DEFAULT_MODE, is_pyi=True),
1158             )
1159             # __BLACK_STDIN_FILENAME__ should have been stripped
1160             report.done.assert_called_with(expected, black.Changed.YES)
1161
1162     @pytest.mark.incompatible_with_mypyc
1163     def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
1164         with patch(
1165             "black.format_stdin_to_stdout",
1166             return_value=lambda *args, **kwargs: black.Changed.YES,
1167         ) as fsts:
1168             report = MagicMock()
1169             p = "foo.ipynb"
1170             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1171             expected = Path(p)
1172             black.reformat_one(
1173                 path,
1174                 fast=True,
1175                 write_back=black.WriteBack.YES,
1176                 mode=DEFAULT_MODE,
1177                 report=report,
1178             )
1179             fsts.assert_called_once_with(
1180                 fast=True,
1181                 write_back=black.WriteBack.YES,
1182                 mode=replace(DEFAULT_MODE, is_ipynb=True),
1183             )
1184             # __BLACK_STDIN_FILENAME__ should have been stripped
1185             report.done.assert_called_with(expected, black.Changed.YES)
1186
1187     @pytest.mark.incompatible_with_mypyc
1188     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1189         with patch(
1190             "black.format_stdin_to_stdout",
1191             return_value=lambda *args, **kwargs: black.Changed.YES,
1192         ) as fsts:
1193             report = MagicMock()
1194             # Even with an existing file, since we are forcing stdin, black
1195             # should output to stdout and not modify the file inplace
1196             p = Path(str(THIS_DIR / "data/collections.py"))
1197             # Make sure is_file actually returns True
1198             self.assertTrue(p.is_file())
1199             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1200             expected = Path(p)
1201             black.reformat_one(
1202                 path,
1203                 fast=True,
1204                 write_back=black.WriteBack.YES,
1205                 mode=DEFAULT_MODE,
1206                 report=report,
1207             )
1208             fsts.assert_called_once()
1209             # __BLACK_STDIN_FILENAME__ should have been stripped
1210             report.done.assert_called_with(expected, black.Changed.YES)
1211
1212     def test_reformat_one_with_stdin_empty(self) -> None:
1213         output = io.StringIO()
1214         with patch("io.TextIOWrapper", lambda *args, **kwargs: output):
1215             try:
1216                 black.format_stdin_to_stdout(
1217                     fast=True,
1218                     content="",
1219                     write_back=black.WriteBack.YES,
1220                     mode=DEFAULT_MODE,
1221                 )
1222             except io.UnsupportedOperation:
1223                 pass  # StringIO does not support detach
1224             assert output.getvalue() == ""
1225
1226     def test_invalid_cli_regex(self) -> None:
1227         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1228             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1229
1230     def test_required_version_matches_version(self) -> None:
1231         self.invokeBlack(
1232             ["--required-version", black.__version__], exit_code=0, ignore_config=True
1233         )
1234
1235     def test_required_version_does_not_match_version(self) -> None:
1236         self.invokeBlack(
1237             ["--required-version", "20.99b"], exit_code=1, ignore_config=True
1238         )
1239
1240     def test_preserves_line_endings(self) -> None:
1241         with TemporaryDirectory() as workspace:
1242             test_file = Path(workspace) / "test.py"
1243             for nl in ["\n", "\r\n"]:
1244                 contents = nl.join(["def f(  ):", "    pass"])
1245                 test_file.write_bytes(contents.encode())
1246                 ff(test_file, write_back=black.WriteBack.YES)
1247                 updated_contents: bytes = test_file.read_bytes()
1248                 self.assertIn(nl.encode(), updated_contents)
1249                 if nl == "\n":
1250                     self.assertNotIn(b"\r\n", updated_contents)
1251
1252     def test_preserves_line_endings_via_stdin(self) -> None:
1253         for nl in ["\n", "\r\n"]:
1254             contents = nl.join(["def f(  ):", "    pass"])
1255             runner = BlackRunner()
1256             result = runner.invoke(
1257                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1258             )
1259             self.assertEqual(result.exit_code, 0)
1260             output = result.stdout_bytes
1261             self.assertIn(nl.encode("utf8"), output)
1262             if nl == "\n":
1263                 self.assertNotIn(b"\r\n", output)
1264
1265     def test_assert_equivalent_different_asts(self) -> None:
1266         with self.assertRaises(AssertionError):
1267             black.assert_equivalent("{}", "None")
1268
1269     def test_shhh_click(self) -> None:
1270         try:
1271             from click import _unicodefun
1272         except ModuleNotFoundError:
1273             self.skipTest("Incompatible Click version")
1274         if not hasattr(_unicodefun, "_verify_python3_env"):
1275             self.skipTest("Incompatible Click version")
1276         # First, let's see if Click is crashing with a preferred ASCII charset.
1277         with patch("locale.getpreferredencoding") as gpe:
1278             gpe.return_value = "ASCII"
1279             with self.assertRaises(RuntimeError):
1280                 _unicodefun._verify_python3_env()  # type: ignore
1281         # Now, let's silence Click...
1282         black.patch_click()
1283         # ...and confirm it's silent.
1284         with patch("locale.getpreferredencoding") as gpe:
1285             gpe.return_value = "ASCII"
1286             try:
1287                 _unicodefun._verify_python3_env()  # type: ignore
1288             except RuntimeError as re:
1289                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1290
1291     def test_root_logger_not_used_directly(self) -> None:
1292         def fail(*args: Any, **kwargs: Any) -> None:
1293             self.fail("Record created with root logger")
1294
1295         with patch.multiple(
1296             logging.root,
1297             debug=fail,
1298             info=fail,
1299             warning=fail,
1300             error=fail,
1301             critical=fail,
1302             log=fail,
1303         ):
1304             ff(THIS_DIR / "util.py")
1305
1306     def test_invalid_config_return_code(self) -> None:
1307         tmp_file = Path(black.dump_to_file())
1308         try:
1309             tmp_config = Path(black.dump_to_file())
1310             tmp_config.unlink()
1311             args = ["--config", str(tmp_config), str(tmp_file)]
1312             self.invokeBlack(args, exit_code=2, ignore_config=False)
1313         finally:
1314             tmp_file.unlink()
1315
1316     def test_parse_pyproject_toml(self) -> None:
1317         test_toml_file = THIS_DIR / "test.toml"
1318         config = black.parse_pyproject_toml(str(test_toml_file))
1319         self.assertEqual(config["verbose"], 1)
1320         self.assertEqual(config["check"], "no")
1321         self.assertEqual(config["diff"], "y")
1322         self.assertEqual(config["color"], True)
1323         self.assertEqual(config["line_length"], 79)
1324         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1325         self.assertEqual(config["python_cell_magics"], ["custom1", "custom2"])
1326         self.assertEqual(config["exclude"], r"\.pyi?$")
1327         self.assertEqual(config["include"], r"\.py?$")
1328
1329     def test_read_pyproject_toml(self) -> None:
1330         test_toml_file = THIS_DIR / "test.toml"
1331         fake_ctx = FakeContext()
1332         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1333         config = fake_ctx.default_map
1334         self.assertEqual(config["verbose"], "1")
1335         self.assertEqual(config["check"], "no")
1336         self.assertEqual(config["diff"], "y")
1337         self.assertEqual(config["color"], "True")
1338         self.assertEqual(config["line_length"], "79")
1339         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1340         self.assertEqual(config["exclude"], r"\.pyi?$")
1341         self.assertEqual(config["include"], r"\.py?$")
1342
1343     @pytest.mark.incompatible_with_mypyc
1344     def test_find_project_root(self) -> None:
1345         with TemporaryDirectory() as workspace:
1346             root = Path(workspace)
1347             test_dir = root / "test"
1348             test_dir.mkdir()
1349
1350             src_dir = root / "src"
1351             src_dir.mkdir()
1352
1353             root_pyproject = root / "pyproject.toml"
1354             root_pyproject.touch()
1355             src_pyproject = src_dir / "pyproject.toml"
1356             src_pyproject.touch()
1357             src_python = src_dir / "foo.py"
1358             src_python.touch()
1359
1360             self.assertEqual(
1361                 black.find_project_root((src_dir, test_dir)),
1362                 (root.resolve(), "pyproject.toml"),
1363             )
1364             self.assertEqual(
1365                 black.find_project_root((src_dir,)),
1366                 (src_dir.resolve(), "pyproject.toml"),
1367             )
1368             self.assertEqual(
1369                 black.find_project_root((src_python,)),
1370                 (src_dir.resolve(), "pyproject.toml"),
1371             )
1372
1373     @patch(
1374         "black.files.find_user_pyproject_toml",
1375         black.files.find_user_pyproject_toml.__wrapped__,
1376     )
1377     def test_find_user_pyproject_toml_linux(self) -> None:
1378         if system() == "Windows":
1379             return
1380
1381         # Test if XDG_CONFIG_HOME is checked
1382         with TemporaryDirectory() as workspace:
1383             tmp_user_config = Path(workspace) / "black"
1384             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1385                 self.assertEqual(
1386                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
1387                 )
1388
1389         # Test fallback for XDG_CONFIG_HOME
1390         with patch.dict("os.environ"):
1391             os.environ.pop("XDG_CONFIG_HOME", None)
1392             fallback_user_config = Path("~/.config").expanduser() / "black"
1393             self.assertEqual(
1394                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
1395             )
1396
1397     def test_find_user_pyproject_toml_windows(self) -> None:
1398         if system() != "Windows":
1399             return
1400
1401         user_config_path = Path.home() / ".black"
1402         self.assertEqual(
1403             black.files.find_user_pyproject_toml(), user_config_path.resolve()
1404         )
1405
1406     def test_bpo_33660_workaround(self) -> None:
1407         if system() == "Windows":
1408             return
1409
1410         # https://bugs.python.org/issue33660
1411         root = Path("/")
1412         with change_directory(root):
1413             path = Path("workspace") / "project"
1414             report = black.Report(verbose=True)
1415             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
1416             self.assertEqual(normalized_path, "workspace/project")
1417
1418     def test_newline_comment_interaction(self) -> None:
1419         source = "class A:\\\r\n# type: ignore\n pass\n"
1420         output = black.format_str(source, mode=DEFAULT_MODE)
1421         black.assert_stable(source, output, mode=DEFAULT_MODE)
1422
1423     def test_bpo_2142_workaround(self) -> None:
1424
1425         # https://bugs.python.org/issue2142
1426
1427         source, _ = read_data("missing_final_newline.py")
1428         # read_data adds a trailing newline
1429         source = source.rstrip()
1430         expected, _ = read_data("missing_final_newline.diff")
1431         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
1432         diff_header = re.compile(
1433             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
1434             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
1435         )
1436         try:
1437             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
1438             self.assertEqual(result.exit_code, 0)
1439         finally:
1440             os.unlink(tmp_file)
1441         actual = result.output
1442         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
1443         self.assertEqual(actual, expected)
1444
1445     @staticmethod
1446     def compare_results(
1447         result: click.testing.Result, expected_value: str, expected_exit_code: int
1448     ) -> None:
1449         """Helper method to test the value and exit code of a click Result."""
1450         assert (
1451             result.output == expected_value
1452         ), "The output did not match the expected value."
1453         assert result.exit_code == expected_exit_code, "The exit code is incorrect."
1454
1455     def test_code_option(self) -> None:
1456         """Test the code option with no changes."""
1457         code = 'print("Hello world")\n'
1458         args = ["--code", code]
1459         result = CliRunner().invoke(black.main, args)
1460
1461         self.compare_results(result, code, 0)
1462
1463     def test_code_option_changed(self) -> None:
1464         """Test the code option when changes are required."""
1465         code = "print('hello world')"
1466         formatted = black.format_str(code, mode=DEFAULT_MODE)
1467
1468         args = ["--code", code]
1469         result = CliRunner().invoke(black.main, args)
1470
1471         self.compare_results(result, formatted, 0)
1472
1473     def test_code_option_check(self) -> None:
1474         """Test the code option when check is passed."""
1475         args = ["--check", "--code", 'print("Hello world")\n']
1476         result = CliRunner().invoke(black.main, args)
1477         self.compare_results(result, "", 0)
1478
1479     def test_code_option_check_changed(self) -> None:
1480         """Test the code option when changes are required, and check is passed."""
1481         args = ["--check", "--code", "print('hello world')"]
1482         result = CliRunner().invoke(black.main, args)
1483         self.compare_results(result, "", 1)
1484
1485     def test_code_option_diff(self) -> None:
1486         """Test the code option when diff is passed."""
1487         code = "print('hello world')"
1488         formatted = black.format_str(code, mode=DEFAULT_MODE)
1489         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1490
1491         args = ["--diff", "--code", code]
1492         result = CliRunner().invoke(black.main, args)
1493
1494         # Remove time from diff
1495         output = DIFF_TIME.sub("", result.output)
1496
1497         assert output == result_diff, "The output did not match the expected value."
1498         assert result.exit_code == 0, "The exit code is incorrect."
1499
1500     def test_code_option_color_diff(self) -> None:
1501         """Test the code option when color and diff are passed."""
1502         code = "print('hello world')"
1503         formatted = black.format_str(code, mode=DEFAULT_MODE)
1504
1505         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1506         result_diff = color_diff(result_diff)
1507
1508         args = ["--diff", "--color", "--code", code]
1509         result = CliRunner().invoke(black.main, args)
1510
1511         # Remove time from diff
1512         output = DIFF_TIME.sub("", result.output)
1513
1514         assert output == result_diff, "The output did not match the expected value."
1515         assert result.exit_code == 0, "The exit code is incorrect."
1516
1517     @pytest.mark.incompatible_with_mypyc
1518     def test_code_option_safe(self) -> None:
1519         """Test that the code option throws an error when the sanity checks fail."""
1520         # Patch black.assert_equivalent to ensure the sanity checks fail
1521         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1522             code = 'print("Hello world")'
1523             error_msg = f"{code}\nerror: cannot format <string>: \n"
1524
1525             args = ["--safe", "--code", code]
1526             result = CliRunner().invoke(black.main, args)
1527
1528             self.compare_results(result, error_msg, 123)
1529
1530     def test_code_option_fast(self) -> None:
1531         """Test that the code option ignores errors when the sanity checks fail."""
1532         # Patch black.assert_equivalent to ensure the sanity checks fail
1533         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1534             code = 'print("Hello world")'
1535             formatted = black.format_str(code, mode=DEFAULT_MODE)
1536
1537             args = ["--fast", "--code", code]
1538             result = CliRunner().invoke(black.main, args)
1539
1540             self.compare_results(result, formatted, 0)
1541
1542     @pytest.mark.incompatible_with_mypyc
1543     def test_code_option_config(self) -> None:
1544         """
1545         Test that the code option finds the pyproject.toml in the current directory.
1546         """
1547         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1548             args = ["--code", "print"]
1549             # This is the only directory known to contain a pyproject.toml
1550             with change_directory(PROJECT_ROOT):
1551                 CliRunner().invoke(black.main, args)
1552                 pyproject_path = Path(Path.cwd(), "pyproject.toml").resolve()
1553
1554             assert (
1555                 len(parse.mock_calls) >= 1
1556             ), "Expected config parse to be called with the current directory."
1557
1558             _, call_args, _ = parse.mock_calls[0]
1559             assert (
1560                 call_args[0].lower() == str(pyproject_path).lower()
1561             ), "Incorrect config loaded."
1562
1563     @pytest.mark.incompatible_with_mypyc
1564     def test_code_option_parent_config(self) -> None:
1565         """
1566         Test that the code option finds the pyproject.toml in the parent directory.
1567         """
1568         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1569             with change_directory(THIS_DIR):
1570                 args = ["--code", "print"]
1571                 CliRunner().invoke(black.main, args)
1572
1573                 pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
1574                 assert (
1575                     len(parse.mock_calls) >= 1
1576                 ), "Expected config parse to be called with the current directory."
1577
1578                 _, call_args, _ = parse.mock_calls[0]
1579                 assert (
1580                     call_args[0].lower() == str(pyproject_path).lower()
1581                 ), "Incorrect config loaded."
1582
1583     def test_for_handled_unexpected_eof_error(self) -> None:
1584         """
1585         Test that an unexpected EOF SyntaxError is nicely presented.
1586         """
1587         with pytest.raises(black.parsing.InvalidInput) as exc_info:
1588             black.lib2to3_parse("print(", {})
1589
1590         exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")
1591
1592     def test_equivalency_ast_parse_failure_includes_error(self) -> None:
1593         with pytest.raises(AssertionError) as err:
1594             black.assert_equivalent("a«»a  = 1", "a«»a  = 1")
1595
1596         err.match("--safe")
1597         # Unfortunately the SyntaxError message has changed in newer versions so we
1598         # can't match it directly.
1599         err.match("invalid character")
1600         err.match(r"\(<unknown>, line 1\)")
1601
1602
class TestCaching:
    """Tests for reading, writing and honouring black's per-mode cache."""

    def test_cache_broken_file(self) -> None:
        """A corrupt cache reads as empty and is repopulated by the next run."""
        with cache_dir() as workspace:
            get_cache_file(DEFAULT_MODE).write_text("this is not a pickle")
            assert black.read_cache(DEFAULT_MODE) == {}
            target = (workspace / "test.py").resolve()
            target.write_text("print('hello')")
            invokeBlack([str(target)])
            assert str(target) in black.read_cache(DEFAULT_MODE)

    def test_cache_single_file_already_cached(self) -> None:
        """A file already recorded in the cache is left untouched."""
        with cache_dir() as workspace:
            target = (workspace / "test.py").resolve()
            target.write_text("print('hello')")
            black.write_cache({}, [target], DEFAULT_MODE)
            invokeBlack([str(target)])
            assert target.read_text() == "print('hello')"

    @event_loop()
    def test_cache_multiple_files(self) -> None:
        """Only the uncached file is reformatted, and both files end up cached."""
        with cache_dir() as workspace, patch(
            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            cached_file = (workspace / "one.py").resolve()
            cached_file.write_text("print('hello')")
            fresh_file = (workspace / "two.py").resolve()
            fresh_file.write_text("print('hello')")
            black.write_cache({}, [cached_file], DEFAULT_MODE)
            invokeBlack([str(workspace)])
            assert cached_file.read_text() == "print('hello')"
            assert fresh_file.read_text() == 'print("hello")\n'
            cache = black.read_cache(DEFAULT_MODE)
            assert str(cached_file) in cache
            assert str(fresh_file) in cache

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    def test_no_cache_when_writeback_diff(self, color: bool) -> None:
        """``--diff`` (with or without ``--color``) never touches the cache."""
        with cache_dir() as workspace:
            target = (workspace / "test.py").resolve()
            target.write_text("print('hello')")
            with patch("black.read_cache") as read_cache, patch(
                "black.write_cache"
            ) as write_cache:
                cmd = [str(target), "--diff"] + (["--color"] if color else [])
                invokeBlack(cmd)
                assert not get_cache_file(DEFAULT_MODE).exists()
                write_cache.assert_not_called()
                read_cache.assert_not_called()

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    @event_loop()
    def test_output_locking_when_writeback_diff(self, color: bool) -> None:
        """Diffing several files goes through ``black.Manager``."""
        with cache_dir() as workspace:
            for idx in range(4):
                (workspace / f"test{idx}.py").resolve().write_text("print('hello')")
            with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
                cmd = ["--diff", str(workspace)]
                if color:
                    cmd.append("--color")
                invokeBlack(cmd, exit_code=0)
                # this isn't quite doing what we want, but if it _isn't_
                # called then we cannot be using the lock it provides
                mgr.assert_called()

    def test_no_cache_when_stdin(self) -> None:
        """Formatting from stdin does not create a cache file."""
        with cache_dir():
            outcome = CliRunner().invoke(
                black.main, ["-"], input=BytesIO(b"print('hello')")
            )
            assert not outcome.exit_code
            assert not get_cache_file(DEFAULT_MODE).exists()

    def test_read_cache_no_cachefile(self) -> None:
        """Reading a cache that was never written yields an empty mapping."""
        with cache_dir():
            assert black.read_cache(DEFAULT_MODE) == {}

    def test_write_cache_read_cache(self) -> None:
        """A written entry round-trips and equals ``black.get_cache_info``."""
        with cache_dir() as workspace:
            target = (workspace / "test.py").resolve()
            target.touch()
            black.write_cache({}, [target], DEFAULT_MODE)
            cache = black.read_cache(DEFAULT_MODE)
            key = str(target)
            assert key in cache
            assert cache[key] == black.get_cache_info(target)

    def test_filter_cached(self) -> None:
        """filter_cached splits sources into (todo, done) using the cache info."""
        with TemporaryDirectory() as tmp:
            base = Path(tmp)
            fresh, stored, stale = (
                (base / name).resolve() for name in ("uncached", "cached", "changed")
            )
            for entry in (fresh, stored, stale):
                entry.touch()
            cache = {
                str(stored): black.get_cache_info(stored),
                str(stale): (0.0, 0),
            }
            todo, done = black.filter_cached(cache, {fresh, stored, stale})
            assert todo == {fresh, stale}
            assert done == {stored}

    def test_write_cache_creates_directory_if_needed(self) -> None:
        """write_cache creates the cache directory when it does not exist."""
        with cache_dir(exists=False) as workspace:
            assert not workspace.exists()
            black.write_cache({}, [], DEFAULT_MODE)
            assert workspace.exists()

    @event_loop()
    def test_failed_formatting_does_not_get_cached(self) -> None:
        """Files that fail to format are left out of the cache; clean ones stay."""
        with cache_dir() as workspace, patch(
            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            broken = (workspace / "failing.py").resolve()
            broken.write_text("not actually python")
            ok = (workspace / "clean.py").resolve()
            ok.write_text('print("hello")\n')
            invokeBlack([str(workspace)], exit_code=123)
            cache = black.read_cache(DEFAULT_MODE)
            assert str(broken) not in cache
            assert str(ok) in cache

    def test_write_cache_write_fail(self) -> None:
        """write_cache must not raise when opening files fails with OSError."""
        with cache_dir(), patch.object(Path, "open") as mock_open:
            mock_open.side_effect = OSError
            black.write_cache({}, [], DEFAULT_MODE)

    def test_read_cache_line_lengths(self) -> None:
        """Cache entries are per mode: a different line length does not see them."""
        short_mode = replace(DEFAULT_MODE, line_length=1)
        with cache_dir() as workspace:
            target = (workspace / "file.py").resolve()
            target.touch()
            black.write_cache({}, [target], DEFAULT_MODE)
            assert str(target) in black.read_cache(DEFAULT_MODE)
            assert str(target) not in black.read_cache(short_mode)
1768
1769
def assert_collected_sources(
    src: Sequence[Union[str, Path]],
    expected: Sequence[Union[str, Path]],
    *,
    ctx: Optional[FakeContext] = None,
    exclude: Optional[str] = None,
    include: Optional[str] = None,
    extend_exclude: Optional[str] = None,
    force_exclude: Optional[str] = None,
    stdin_filename: Optional[str] = None,
) -> None:
    """Assert that ``black.get_sources`` collects exactly ``expected`` from ``src``.

    Regex options left as ``None`` are passed on as ``None``, except ``include``,
    which falls back to ``DEFAULT_INCLUDE``. The comparison ignores order.
    """

    def _compiled(pattern: Optional[str]) -> Any:
        # Compile only when a pattern was given.
        return None if pattern is None else compile_pattern(pattern)

    sources = tuple(str(Path(item)) for item in src)
    wanted = sorted(Path(item) for item in expected)
    exclude_re = _compiled(exclude)
    include_re = DEFAULT_INCLUDE if include is None else compile_pattern(include)
    extend_exclude_re = _compiled(extend_exclude)
    force_exclude_re = _compiled(force_exclude)

    collected = black.get_sources(
        ctx=ctx or FakeContext(),
        src=sources,
        quiet=False,
        verbose=False,
        include=include_re,
        exclude=exclude_re,
        extend_exclude=extend_exclude_re,
        force_exclude=force_exclude_re,
        report=black.Report(),
        stdin_filename=stdin_filename,
    )
    assert sorted(collected) == wanted
1802
1803
class TestFileCollection:
    """Tests for which paths ``get_sources`` / ``gen_python_files`` collect."""

    def test_include_exclude(self) -> None:
        """Only files matching ``include`` and not ``exclude`` are collected."""
        path = THIS_DIR / "data" / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        assert_collected_sources(
            src,
            expected,
            include=r"\.pyi?$",
            exclude=r"/exclude/|/\.definitely_exclude/",
        )

    def test_gitignore_used_as_default(self) -> None:
        """With no explicit ``exclude``, .gitignore is applied by default."""
        base = Path(DATA_DIR / "include_exclude_tests")
        expected = [
            base / "b/.definitely_exclude/a.py",
            base / "b/.definitely_exclude/a.pyi",
        ]
        src = [base / "b/"]
        ctx = FakeContext()
        ctx.obj["root"] = base
        assert_collected_sources(src, expected, ctx=ctx, extend_exclude=r"/exclude/")

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_exclude_for_issue_1572(self) -> None:
        # Exclude shouldn't touch files that were explicitly given to Black through the
        # CLI. Exclude is supposed to only apply to the recursive discovery of files.
        # https://github.com/psf/black/issues/1572
        path = DATA_DIR / "include_exclude_tests"
        src = [path / "b/exclude/a.py"]
        expected = [path / "b/exclude/a.py"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    def test_gitignore_exclude(self) -> None:
        """gen_python_files honours a .gitignore PathSpec passed in explicitly."""
        path = THIS_DIR / "data" / "include_exclude_tests"
        include = re.compile(r"\.pyi?$")
        # Empty pattern; per the expected result it excludes none of these files.
        exclude = re.compile(r"")
        report = black.Report()
        gitignore = PathSpec.from_lines(
            "gitwildmatch", ["exclude/", ".definitely_exclude"]
        )
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        this_abs = THIS_DIR.resolve()
        sources = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                gitignore,
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_nested_gitignore(self) -> None:
        """.gitignore files in subdirectories are honoured as well as the root one."""
        path = THIS_DIR / "data" / "nested_gitignore_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        root_gitignore = black.files.get_gitignore(path)
        report = black.Report()
        expected: List[Path] = [
            path / "x.py",
            path / "root/b.py",
            path / "root/c.py",
            path / "root/child/c.py",
        ]
        this_abs = THIS_DIR.resolve()
        sources = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                root_gitignore,
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_invalid_gitignore(self) -> None:
        """An unparsable .gitignore makes black exit 1 and name it on stderr."""
        path = THIS_DIR / "data" / "invalid_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_invalid_nested_gitignore(self) -> None:
        """An unparsable nested .gitignore is reported the same way as a root one."""
        path = THIS_DIR / "data" / "invalid_nested_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / "a" / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_empty_include(self) -> None:
        """An empty ``include`` pattern collects every file in the tree."""
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.pie",
            path / "b/exclude/a.py",
            path / "b/exclude/a.pyi",
            path / "b/dont_exclude/a.pie",
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
            path / "b/.definitely_exclude/a.pie",
            path / "b/.definitely_exclude/a.py",
            path / "b/.definitely_exclude/a.pyi",
            path / ".gitignore",
            path / "pyproject.toml",
        ]
        # Setting exclude explicitly to an empty string to block .gitignore usage.
        assert_collected_sources(src, expected, include="", exclude="")

    def test_extend_exclude(self) -> None:
        """``extend_exclude`` filters on top of ``exclude``."""
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.py",
            path / "b/dont_exclude/a.py",
        ]
        assert_collected_sources(
            src, expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude"
        )

    @pytest.mark.incompatible_with_mypyc
    def test_symlink_out_of_root_directory(self) -> None:
        """A symlink resolving outside the root is tolerated; a non-symlink raises."""
        path = MagicMock()
        root = THIS_DIR.resolve()
        child = MagicMock()
        include = re.compile(black.DEFAULT_INCLUDES)
        exclude = re.compile(black.DEFAULT_EXCLUDES)
        report = black.Report()
        gitignore = PathSpec.from_lines("gitwildmatch", [])

        def collect() -> None:
            # Each call consumes a fresh `path.iterdir()`; the call-count
            # assertions at the end rely on that.
            list(
                black.gen_python_files(
                    path.iterdir(),
                    root,
                    include,
                    exclude,
                    None,
                    None,
                    report,
                    gitignore,
                    verbose=False,
                    quiet=False,
                )
            )

        # `child` behaves like a symlink whose resolved path is clearly outside
        # of the `root` directory: collecting must not raise.
        path.iterdir.return_value = [child]
        child.resolve.return_value = Path("/a/b/c")
        child.as_posix.return_value = "/a/b/c"
        child.is_symlink.return_value = True
        try:
            collect()
        except ValueError as ve:
            pytest.fail(f"`gen_python_files()` failed: {ve}")
        path.iterdir.assert_called_once()
        child.resolve.assert_called_once()
        child.is_symlink.assert_called_once()
        # `child` now behaves like a strange (non-symlink) file whose resolved
        # path is clearly outside of the `root` directory: this must raise.
        child.is_symlink.return_value = False
        with pytest.raises(ValueError):
            collect()
        path.iterdir.assert_called()
        assert path.iterdir.call_count == 2
        child.resolve.assert_called()
        assert child.resolve.call_count == 2
        child.is_symlink.assert_called()
        assert child.is_symlink.call_count == 2

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin(self) -> None:
        """``-`` is collected as-is, regardless of include/exclude."""
        src = ["-"]
        expected = ["-"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename(self) -> None:
        """``-`` with a stdin filename is collected under the marker prefix."""
        src = ["-"]
        stdin_filename = str(THIS_DIR / "data/collections.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
        # Exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        path = DATA_DIR / "include_exclude_tests"
        src = ["-"]
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
        # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        src = ["-"]
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            extend_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
        # Force exclude should exclude the file when passing it through
        # stdin_filename
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        assert_collected_sources(
            src=["-"],
            expected=[],
            force_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )
2075
2076
# Source lines of black's package `__init__` (read via `black.__file__`), used by
# `tracefunc` to print the signature of traced calls. A UnicodeDecodeError is
# only tolerated for a compiled build (presumably `__file__` is then a binary
# extension); fall back to an empty list so the name is always defined.
black_source_lines: List[str]
try:
    with open(black.__file__, encoding="utf-8") as _bf:
        black_source_lines = _bf.readlines()
except UnicodeDecodeError:
    if not black.COMPILED:
        raise
    black_source_lines = []
2083
2084
def tracefunc(
    frame: types.FrameType, event: str, arg: Any
) -> Callable[[types.FrameType, str, Any], Any]:
    """Show function calls `from black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    On each ``"call"`` event whose code lives in ``black/__init__.py``, print
    ``<indent><lineno>:<line>``. Here ``<line>`` is the first non-decorator
    line at or after ``lineno`` in ``black_source_lines``. All other events
    and frames are ignored. The function always returns itself, so tracing
    stays active.
    """
    if event != "call":
        return tracefunc

    filename = frame.f_code.co_filename
    # Filter frames first. Line numbers from any other file are meaningless as
    # indices into `black_source_lines` and could run past its end, and
    # `inspect.stack()` below is expensive. Backslashes are normalised so the
    # check also matches Windows paths.
    if "black/__init__.py" not in filename.replace("\\", "/"):
        return tracefunc

    # Indent by call depth; 19 is presumably the number of frames belonging to
    # the test harness when tracing starts -- TODO confirm.
    stack = len(inspect.stack()) - 19
    stack *= 2
    lineno = frame.f_lineno
    lines = black_source_lines
    func_sig_lineno = lineno - 1  # 0-based index into `lines`
    # The reported line may be a decorator; walk forward to the first
    # non-decorator line without running off the end of the list.
    while (
        0 <= func_sig_lineno < len(lines) - 1
        and lines[func_sig_lineno].strip().startswith("@")
    ):
        func_sig_lineno += 1
    if 0 <= func_sig_lineno < len(lines):
        funcname = lines[func_sig_lineno].strip()
    else:
        funcname = "<unknown>"
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc