git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Change git url for pip installation in README (#2761)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import inspect
5 import io
6 import logging
7 import multiprocessing
8 import os
9 import sys
10 import types
11 import unittest
12 from concurrent.futures import ThreadPoolExecutor
13 from contextlib import contextmanager
14 from dataclasses import replace
15 from io import BytesIO
16 from pathlib import Path
17 from platform import system
18 from tempfile import TemporaryDirectory
19 from typing import (
20     Any,
21     Callable,
22     Dict,
23     Iterator,
24     List,
25     Optional,
26     Sequence,
27     TypeVar,
28     Union,
29 )
30 from unittest.mock import MagicMock, patch
31
32 import click
33 import pytest
34 import re
35 from click import unstyle
36 from click.testing import CliRunner
37 from pathspec import PathSpec
38
39 import black
40 import black.files
41 from black import Feature, TargetVersion
42 from black import re_compile_maybe_verbose as compile_pattern
43 from black.cache import get_cache_file
44 from black.debug import DebugVisitor
45 from black.output import color_diff, diff
46 from black.report import Report
47
48 # Import other test classes
49 from tests.util import (
50     DATA_DIR,
51     DEFAULT_MODE,
52     DETERMINISTIC_HEADER,
53     PROJECT_ROOT,
54     PY36_VERSIONS,
55     THIS_DIR,
56     BlackBaseTestCase,
57     assert_format,
58     change_directory,
59     dump_to_stderr,
60     ff,
61     fs,
62     read_data,
63 )
64
THIS_FILE = Path(__file__)
# One ``--target-version=<name>`` CLI flag for each entry in PY36_VERSIONS.
PY36_ARGS = [f"--target-version={target.name.lower()}" for target in PY36_VERSIONS]
# Black's default include/exclude patterns, compiled for use in file-discovery tests.
DEFAULT_EXCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_EXCLUDES)
DEFAULT_INCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_INCLUDES)
# Generic type variables for helper annotations.
T = TypeVar("T")
R = TypeVar("R")

# Match the time output in a diff, but nothing else
DIFF_TIME = re.compile(r"\t[\d\-:+\. ]+")
74
75
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point ``black.cache.CACHE_DIR`` at a throwaway directory while active.

    With ``exists=True`` the patched path is a freshly created temporary
    directory. With ``exists=False`` it is a not-yet-created ``new``
    subdirectory inside that temporary directory.

    Yields:
        The path that ``black.cache.CACHE_DIR`` is patched to. The temporary
        directory (and anything written into it) is removed on exit.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
84
85
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a brand-new asyncio event loop as current for the block.

    The loop comes from the active event loop policy and is closed on exit,
    even if the body raises.

    NOTE(review): the closed loop is left set as the current event loop after
    exit; nothing here restores the previous one.
    """
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        fresh_loop.close()
96
97
class FakeContext(click.Context):
    """Minimal click Context for calling functions that expect one.

    ``click.Context.__init__`` is deliberately not called. Only ``default_map``
    (empty) and ``obj`` (holding a dummy ``"root"`` entry) are set.
    """

    def __init__(self) -> None:
        self.default_map: Dict[str, Any] = dict()
        # Most tests don't care about the root, so PROJECT_ROOT serves as a dummy.
        self.obj: Dict[str, Any] = dict(root=PROJECT_ROOT)
105
106
class FakeParameter(click.Parameter):
    """Stand-in click Parameter for calling functions that require one.

    ``click.Parameter.__init__`` is intentionally skipped, so instances carry
    no parameter configuration at all.
    """

    def __init__(self) -> None:
        """Construct the fake without running ``click.Parameter.__init__``."""
112
113
class BlackRunner(CliRunner):
    """CliRunner that captures Black's STDOUT and STDERR as separate streams."""

    def __init__(self) -> None:
        # mix_stderr=False keeps stderr out of `result.output`, which the tests
        # compare against expected formatter output.
        super().__init__(mix_stderr=False)
119
120
def invokeBlack(
    args: List[str], exit_code: int = 0, ignore_config: bool = True
) -> None:
    """Run Black's CLI in-process with *args* and assert on the exit code.

    When *ignore_config* is true, ``--verbose`` plus ``--config`` pointing at
    ``empty.toml`` next to this file are prepended, so that no project
    configuration affects the run. Exceptions raised by Black propagate
    (``catch_exceptions=False``). On a mismatched exit code, the assertion
    message includes the arguments, captured stdout/stderr and any exception.
    """
    if ignore_config:
        args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
    result = BlackRunner().invoke(black.main, args, catch_exceptions=False)
    out_bytes = result.stdout_bytes
    err_bytes = result.stderr_bytes
    assert out_bytes is not None
    assert err_bytes is not None
    failure_message = (
        f"Failed with args: {args}\n"
        f"stdout: {out_bytes.decode()!r}\n"
        f"stderr: {err_bytes.decode()!r}\n"
        f"exception: {result.exception}"
    )
    assert result.exit_code == exit_code, failure_message
137
138
139 class BlackTestCase(BlackBaseTestCase):
140     invokeBlack = staticmethod(invokeBlack)
141
142     def test_empty_ff(self) -> None:
143         expected = ""
144         tmp_file = Path(black.dump_to_file())
145         try:
146             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
147             with open(tmp_file, encoding="utf8") as f:
148                 actual = f.read()
149         finally:
150             os.unlink(tmp_file)
151         self.assertFormatEqual(expected, actual)
152
153     def test_piping(self) -> None:
154         source, expected = read_data("src/black/__init__", data=False)
155         result = BlackRunner().invoke(
156             black.main,
157             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
158             input=BytesIO(source.encode("utf8")),
159         )
160         self.assertEqual(result.exit_code, 0)
161         self.assertFormatEqual(expected, result.output)
162         if source != result.output:
163             black.assert_equivalent(source, result.output)
164             black.assert_stable(source, result.output, DEFAULT_MODE)
165
166     def test_piping_diff(self) -> None:
167         diff_header = re.compile(
168             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
169             r"\+\d\d\d\d"
170         )
171         source, _ = read_data("expression.py")
172         expected, _ = read_data("expression.diff")
173         config = THIS_DIR / "data" / "empty_pyproject.toml"
174         args = [
175             "-",
176             "--fast",
177             f"--line-length={black.DEFAULT_LINE_LENGTH}",
178             "--diff",
179             f"--config={config}",
180         ]
181         result = BlackRunner().invoke(
182             black.main, args, input=BytesIO(source.encode("utf8"))
183         )
184         self.assertEqual(result.exit_code, 0)
185         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
186         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
187         self.assertEqual(expected, actual)
188
189     def test_piping_diff_with_color(self) -> None:
190         source, _ = read_data("expression.py")
191         config = THIS_DIR / "data" / "empty_pyproject.toml"
192         args = [
193             "-",
194             "--fast",
195             f"--line-length={black.DEFAULT_LINE_LENGTH}",
196             "--diff",
197             "--color",
198             f"--config={config}",
199         ]
200         result = BlackRunner().invoke(
201             black.main, args, input=BytesIO(source.encode("utf8"))
202         )
203         actual = result.output
204         # Again, the contents are checked in a different test, so only look for colors.
205         self.assertIn("\033[1m", actual)
206         self.assertIn("\033[36m", actual)
207         self.assertIn("\033[32m", actual)
208         self.assertIn("\033[31m", actual)
209         self.assertIn("\033[0m", actual)
210
211     @patch("black.dump_to_file", dump_to_stderr)
212     def _test_wip(self) -> None:
213         source, expected = read_data("wip")
214         sys.settrace(tracefunc)
215         mode = replace(
216             DEFAULT_MODE,
217             experimental_string_processing=False,
218             target_versions={black.TargetVersion.PY38},
219         )
220         actual = fs(source, mode=mode)
221         sys.settrace(None)
222         self.assertFormatEqual(expected, actual)
223         black.assert_equivalent(source, actual)
224         black.assert_stable(source, actual, black.FileMode())
225
226     @unittest.expectedFailure
227     @patch("black.dump_to_file", dump_to_stderr)
228     def test_trailing_comma_optional_parens_stability1(self) -> None:
229         source, _expected = read_data("trailing_comma_optional_parens1")
230         actual = fs(source)
231         black.assert_stable(source, actual, DEFAULT_MODE)
232
233     @unittest.expectedFailure
234     @patch("black.dump_to_file", dump_to_stderr)
235     def test_trailing_comma_optional_parens_stability2(self) -> None:
236         source, _expected = read_data("trailing_comma_optional_parens2")
237         actual = fs(source)
238         black.assert_stable(source, actual, DEFAULT_MODE)
239
240     @unittest.expectedFailure
241     @patch("black.dump_to_file", dump_to_stderr)
242     def test_trailing_comma_optional_parens_stability3(self) -> None:
243         source, _expected = read_data("trailing_comma_optional_parens3")
244         actual = fs(source)
245         black.assert_stable(source, actual, DEFAULT_MODE)
246
247     @patch("black.dump_to_file", dump_to_stderr)
248     def test_trailing_comma_optional_parens_stability1_pass2(self) -> None:
249         source, _expected = read_data("trailing_comma_optional_parens1")
250         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
251         black.assert_stable(source, actual, DEFAULT_MODE)
252
253     @patch("black.dump_to_file", dump_to_stderr)
254     def test_trailing_comma_optional_parens_stability2_pass2(self) -> None:
255         source, _expected = read_data("trailing_comma_optional_parens2")
256         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
257         black.assert_stable(source, actual, DEFAULT_MODE)
258
259     @patch("black.dump_to_file", dump_to_stderr)
260     def test_trailing_comma_optional_parens_stability3_pass2(self) -> None:
261         source, _expected = read_data("trailing_comma_optional_parens3")
262         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
263         black.assert_stable(source, actual, DEFAULT_MODE)
264
265     def test_pep_572_version_detection(self) -> None:
266         source, _ = read_data("pep_572")
267         root = black.lib2to3_parse(source)
268         features = black.get_features_used(root)
269         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
270         versions = black.detect_target_versions(root)
271         self.assertIn(black.TargetVersion.PY38, versions)
272
273     def test_expression_ff(self) -> None:
274         source, expected = read_data("expression")
275         tmp_file = Path(black.dump_to_file(source))
276         try:
277             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
278             with open(tmp_file, encoding="utf8") as f:
279                 actual = f.read()
280         finally:
281             os.unlink(tmp_file)
282         self.assertFormatEqual(expected, actual)
283         with patch("black.dump_to_file", dump_to_stderr):
284             black.assert_equivalent(source, actual)
285             black.assert_stable(source, actual, DEFAULT_MODE)
286
287     def test_expression_diff(self) -> None:
288         source, _ = read_data("expression.py")
289         config = THIS_DIR / "data" / "empty_pyproject.toml"
290         expected, _ = read_data("expression.diff")
291         tmp_file = Path(black.dump_to_file(source))
292         diff_header = re.compile(
293             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
294             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
295         )
296         try:
297             result = BlackRunner().invoke(
298                 black.main, ["--diff", str(tmp_file), f"--config={config}"]
299             )
300             self.assertEqual(result.exit_code, 0)
301         finally:
302             os.unlink(tmp_file)
303         actual = result.output
304         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
305         if expected != actual:
306             dump = black.dump_to_file(actual)
307             msg = (
308                 "Expected diff isn't equal to the actual. If you made changes to"
309                 " expression.py and this is an anticipated difference, overwrite"
310                 f" tests/data/expression.diff with {dump}"
311             )
312             self.assertEqual(expected, actual, msg)
313
314     def test_expression_diff_with_color(self) -> None:
315         source, _ = read_data("expression.py")
316         config = THIS_DIR / "data" / "empty_pyproject.toml"
317         expected, _ = read_data("expression.diff")
318         tmp_file = Path(black.dump_to_file(source))
319         try:
320             result = BlackRunner().invoke(
321                 black.main, ["--diff", "--color", str(tmp_file), f"--config={config}"]
322             )
323         finally:
324             os.unlink(tmp_file)
325         actual = result.output
326         # We check the contents of the diff in `test_expression_diff`. All
327         # we need to check here is that color codes exist in the result.
328         self.assertIn("\033[1m", actual)
329         self.assertIn("\033[36m", actual)
330         self.assertIn("\033[32m", actual)
331         self.assertIn("\033[31m", actual)
332         self.assertIn("\033[0m", actual)
333
334     def test_detect_pos_only_arguments(self) -> None:
335         source, _ = read_data("pep_570")
336         root = black.lib2to3_parse(source)
337         features = black.get_features_used(root)
338         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
339         versions = black.detect_target_versions(root)
340         self.assertIn(black.TargetVersion.PY38, versions)
341
342     @patch("black.dump_to_file", dump_to_stderr)
343     def test_string_quotes(self) -> None:
344         source, expected = read_data("string_quotes")
345         mode = black.Mode(experimental_string_processing=True)
346         assert_format(source, expected, mode)
347         mode = replace(mode, string_normalization=False)
348         not_normalized = fs(source, mode=mode)
349         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
350         black.assert_equivalent(source, not_normalized)
351         black.assert_stable(source, not_normalized, mode=mode)
352
353     def test_skip_magic_trailing_comma(self) -> None:
354         source, _ = read_data("expression.py")
355         expected, _ = read_data("expression_skip_magic_trailing_comma.diff")
356         tmp_file = Path(black.dump_to_file(source))
357         diff_header = re.compile(
358             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
359             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
360         )
361         try:
362             result = BlackRunner().invoke(black.main, ["-C", "--diff", str(tmp_file)])
363             self.assertEqual(result.exit_code, 0)
364         finally:
365             os.unlink(tmp_file)
366         actual = result.output
367         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
368         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
369         if expected != actual:
370             dump = black.dump_to_file(actual)
371             msg = (
372                 "Expected diff isn't equal to the actual. If you made changes to"
373                 " expression.py and this is an anticipated difference, overwrite"
374                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
375             )
376             self.assertEqual(expected, actual, msg)
377
378     @patch("black.dump_to_file", dump_to_stderr)
379     def test_async_as_identifier(self) -> None:
380         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
381         source, expected = read_data("async_as_identifier")
382         actual = fs(source)
383         self.assertFormatEqual(expected, actual)
384         major, minor = sys.version_info[:2]
385         if major < 3 or (major <= 3 and minor < 7):
386             black.assert_equivalent(source, actual)
387         black.assert_stable(source, actual, DEFAULT_MODE)
388         # ensure black can parse this when the target is 3.6
389         self.invokeBlack([str(source_path), "--target-version", "py36"])
390         # but not on 3.7, because async/await is no longer an identifier
391         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
392
393     @patch("black.dump_to_file", dump_to_stderr)
394     def test_python37(self) -> None:
395         source_path = (THIS_DIR / "data" / "python37.py").resolve()
396         source, expected = read_data("python37")
397         actual = fs(source)
398         self.assertFormatEqual(expected, actual)
399         major, minor = sys.version_info[:2]
400         if major > 3 or (major == 3 and minor >= 7):
401             black.assert_equivalent(source, actual)
402         black.assert_stable(source, actual, DEFAULT_MODE)
403         # ensure black can parse this when the target is 3.7
404         self.invokeBlack([str(source_path), "--target-version", "py37"])
405         # but not on 3.6, because we use async as a reserved keyword
406         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
407
408     def test_tab_comment_indentation(self) -> None:
409         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
410         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
411         self.assertFormatEqual(contents_spc, fs(contents_spc))
412         self.assertFormatEqual(contents_spc, fs(contents_tab))
413
414         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
415         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
416         self.assertFormatEqual(contents_spc, fs(contents_spc))
417         self.assertFormatEqual(contents_spc, fs(contents_tab))
418
419         # mixed tabs and spaces (valid Python 2 code)
420         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
421         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
422         self.assertFormatEqual(contents_spc, fs(contents_spc))
423         self.assertFormatEqual(contents_spc, fs(contents_tab))
424
425         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
426         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
427         self.assertFormatEqual(contents_spc, fs(contents_spc))
428         self.assertFormatEqual(contents_spc, fs(contents_tab))
429
    def test_report_verbose(self) -> None:
        """Drive a verbose ``Report`` through every outcome and check each step.

        ``black.output._out`` / ``black.output._err`` are patched so every
        emitted message is captured. After each event the test checks the
        captured lines, the summary (``str(report)``) and ``return_code``.
        """
        report = Report(verbose=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            # Capture messages routed through black.output._out.
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            # Capture messages routed through black.output._err.
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            # Verbose mode reports unchanged files as well.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file gets its own message and counts as left unchanged.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, having reformatted files turns the return code to 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures go to the error stream and make the return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are reported in verbose mode but not counted.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff mode switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
531
    def test_report_quiet(self) -> None:
        """Drive a quiet ``Report`` through every outcome and check each step.

        Quiet mode is expected to emit nothing through ``black.output._out``,
        while failures are still emitted through ``black.output._err``. The
        summary and return codes match the other report modes.
        """
        report = Report(quiet=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            # Capture messages routed through black.output._out.
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            # Capture messages routed through black.output._err.
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file counts as left unchanged.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, having reformatted files turns the return code to 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Errors are still emitted in quiet mode.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are neither reported nor counted.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff mode switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
625
    def test_report_normal(self) -> None:
        """Drive a default ``Report`` through every outcome and check each step.

        With default verbosity, only reformatted files produce messages through
        ``black.output._out``. Unchanged, cached and ignored files are silent,
        and failures are emitted through ``black.output._err``.
        """
        report = black.Report()
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            # Capture messages routed through black.output._out.
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            # Capture messages routed through black.output._err.
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file is silent here and counts as left unchanged.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, having reformatted files turns the return code to 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are neither reported nor counted at this verbosity.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff mode switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
722
723     def test_lib2to3_parse(self) -> None:
724         with self.assertRaises(black.InvalidInput):
725             black.lib2to3_parse("invalid syntax")
726
727         straddling = "x + y"
728         black.lib2to3_parse(straddling)
729         black.lib2to3_parse(straddling, {TargetVersion.PY36})
730
731         py2_only = "print x"
732         with self.assertRaises(black.InvalidInput):
733             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
734
735         py3_only = "exec(x, end=y)"
736         black.lib2to3_parse(py3_only)
737         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
738
739     def test_get_features_used_decorator(self) -> None:
740         # Test the feature detection of new decorator syntax
741         # since this makes some test cases of test_get_features_used()
742         # fails if it fails, this is tested first so that a useful case
743         # is identified
744         simples, relaxed = read_data("decorators")
745         # skip explanation comments at the top of the file
746         for simple_test in simples.split("##")[1:]:
747             node = black.lib2to3_parse(simple_test)
748             decorator = str(node.children[0].children[0]).strip()
749             self.assertNotIn(
750                 Feature.RELAXED_DECORATORS,
751                 black.get_features_used(node),
752                 msg=(
753                     f"decorator '{decorator}' follows python<=3.8 syntax"
754                     "but is detected as 3.9+"
755                     # f"The full node is\n{node!r}"
756                 ),
757             )
758         # skip the '# output' comment at the top of the output part
759         for relaxed_test in relaxed.split("##")[1:]:
760             node = black.lib2to3_parse(relaxed_test)
761             decorator = str(node.children[0].children[0]).strip()
762             self.assertIn(
763                 Feature.RELAXED_DECORATORS,
764                 black.get_features_used(node),
765                 msg=(
766                     f"decorator '{decorator}' uses python3.9+ syntax"
767                     "but is detected as python<=3.8"
768                     # f"The full node is\n{node!r}"
769                 ),
770             )
771
772     def test_get_features_used(self) -> None:
773         node = black.lib2to3_parse("def f(*, arg): ...\n")
774         self.assertEqual(black.get_features_used(node), set())
775         node = black.lib2to3_parse("def f(*, arg,): ...\n")
776         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
777         node = black.lib2to3_parse("f(*arg,)\n")
778         self.assertEqual(
779             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
780         )
781         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
782         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
783         node = black.lib2to3_parse("123_456\n")
784         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
785         node = black.lib2to3_parse("123456\n")
786         self.assertEqual(black.get_features_used(node), set())
787         source, expected = read_data("function")
788         node = black.lib2to3_parse(source)
789         expected_features = {
790             Feature.TRAILING_COMMA_IN_CALL,
791             Feature.TRAILING_COMMA_IN_DEF,
792             Feature.F_STRINGS,
793         }
794         self.assertEqual(black.get_features_used(node), expected_features)
795         node = black.lib2to3_parse(expected)
796         self.assertEqual(black.get_features_used(node), expected_features)
797         source, expected = read_data("expression")
798         node = black.lib2to3_parse(source)
799         self.assertEqual(black.get_features_used(node), set())
800         node = black.lib2to3_parse(expected)
801         self.assertEqual(black.get_features_used(node), set())
802         node = black.lib2to3_parse("lambda a, /, b: ...")
803         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
804         node = black.lib2to3_parse("def fn(a, /, b): ...")
805         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
806         node = black.lib2to3_parse("def fn(): yield a, b")
807         self.assertEqual(black.get_features_used(node), set())
808         node = black.lib2to3_parse("def fn(): return a, b")
809         self.assertEqual(black.get_features_used(node), set())
810         node = black.lib2to3_parse("def fn(): yield *b, c")
811         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
812         node = black.lib2to3_parse("def fn(): return a, *b, c")
813         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
814         node = black.lib2to3_parse("x = a, *b, c")
815         self.assertEqual(black.get_features_used(node), set())
816         node = black.lib2to3_parse("x: Any = regular")
817         self.assertEqual(black.get_features_used(node), set())
818         node = black.lib2to3_parse("x: Any = (regular, regular)")
819         self.assertEqual(black.get_features_used(node), set())
820         node = black.lib2to3_parse("x: Any = Complex(Type(1))[something]")
821         self.assertEqual(black.get_features_used(node), set())
822         node = black.lib2to3_parse("x: Tuple[int, ...] = a, b, c")
823         self.assertEqual(
824             black.get_features_used(node), {Feature.ANN_ASSIGN_EXTENDED_RHS}
825         )
826
827     def test_get_features_used_for_future_flags(self) -> None:
828         for src, features in [
829             ("from __future__ import annotations", {Feature.FUTURE_ANNOTATIONS}),
830             (
831                 "from __future__ import (other, annotations)",
832                 {Feature.FUTURE_ANNOTATIONS},
833             ),
834             ("a = 1 + 2\nfrom something import annotations", set()),
835             ("from __future__ import x, y", set()),
836         ]:
837             with self.subTest(src=src, features=features):
838                 node = black.lib2to3_parse(src)
839                 future_imports = black.get_future_imports(node)
840                 self.assertEqual(
841                     black.get_features_used(node, future_imports=future_imports),
842                     features,
843                 )
844
845     def test_get_future_imports(self) -> None:
846         node = black.lib2to3_parse("\n")
847         self.assertEqual(set(), black.get_future_imports(node))
848         node = black.lib2to3_parse("from __future__ import black\n")
849         self.assertEqual({"black"}, black.get_future_imports(node))
850         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
851         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
852         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
853         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
854         node = black.lib2to3_parse(
855             "from __future__ import multiple\nfrom __future__ import imports\n"
856         )
857         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
858         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
859         self.assertEqual({"black"}, black.get_future_imports(node))
860         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
861         self.assertEqual({"black"}, black.get_future_imports(node))
862         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
863         self.assertEqual(set(), black.get_future_imports(node))
864         node = black.lib2to3_parse("from some.module import black\n")
865         self.assertEqual(set(), black.get_future_imports(node))
866         node = black.lib2to3_parse(
867             "from __future__ import unicode_literals as _unicode_literals"
868         )
869         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
870         node = black.lib2to3_parse(
871             "from __future__ import unicode_literals as _lol, print"
872         )
873         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
874
875     @pytest.mark.incompatible_with_mypyc
876     def test_debug_visitor(self) -> None:
877         source, _ = read_data("debug_visitor.py")
878         expected, _ = read_data("debug_visitor.out")
879         out_lines = []
880         err_lines = []
881
882         def out(msg: str, **kwargs: Any) -> None:
883             out_lines.append(msg)
884
885         def err(msg: str, **kwargs: Any) -> None:
886             err_lines.append(msg)
887
888         with patch("black.debug.out", out):
889             DebugVisitor.show(source)
890         actual = "\n".join(out_lines) + "\n"
891         log_name = ""
892         if expected != actual:
893             log_name = black.dump_to_file(*out_lines)
894         self.assertEqual(
895             expected,
896             actual,
897             f"AST print out is different. Actual version dumped to {log_name}",
898         )
899
900     def test_format_file_contents(self) -> None:
901         empty = ""
902         mode = DEFAULT_MODE
903         with self.assertRaises(black.NothingChanged):
904             black.format_file_contents(empty, mode=mode, fast=False)
905         just_nl = "\n"
906         with self.assertRaises(black.NothingChanged):
907             black.format_file_contents(just_nl, mode=mode, fast=False)
908         same = "j = [1, 2, 3]\n"
909         with self.assertRaises(black.NothingChanged):
910             black.format_file_contents(same, mode=mode, fast=False)
911         different = "j = [1,2,3]"
912         expected = same
913         actual = black.format_file_contents(different, mode=mode, fast=False)
914         self.assertEqual(expected, actual)
915         invalid = "return if you can"
916         with self.assertRaises(black.InvalidInput) as e:
917             black.format_file_contents(invalid, mode=mode, fast=False)
918         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
919
920     def test_endmarker(self) -> None:
921         n = black.lib2to3_parse("\n")
922         self.assertEqual(n.type, black.syms.file_input)
923         self.assertEqual(len(n.children), 1)
924         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
925
926     @pytest.mark.incompatible_with_mypyc
927     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
928     def test_assertFormatEqual(self) -> None:
929         out_lines = []
930         err_lines = []
931
932         def out(msg: str, **kwargs: Any) -> None:
933             out_lines.append(msg)
934
935         def err(msg: str, **kwargs: Any) -> None:
936             err_lines.append(msg)
937
938         with patch("black.output._out", out), patch("black.output._err", err):
939             with self.assertRaises(AssertionError):
940                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
941
942         out_str = "".join(out_lines)
943         self.assertTrue("Expected tree:" in out_str)
944         self.assertTrue("Actual tree:" in out_str)
945         self.assertEqual("".join(err_lines), "")
946
947     @event_loop()
948     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
949     def test_works_in_mono_process_only_environment(self) -> None:
950         with cache_dir() as workspace:
951             for f in [
952                 (workspace / "one.py").resolve(),
953                 (workspace / "two.py").resolve(),
954             ]:
955                 f.write_text('print("hello")\n')
956             self.invokeBlack([str(workspace)])
957
958     @event_loop()
959     def test_check_diff_use_together(self) -> None:
960         with cache_dir():
961             # Files which will be reformatted.
962             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
963             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
964             # Files which will not be reformatted.
965             src2 = (THIS_DIR / "data" / "composition.py").resolve()
966             self.invokeBlack([str(src2), "--diff", "--check"])
967             # Multi file command.
968             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
969
970     def test_no_files(self) -> None:
971         with cache_dir():
972             # Without an argument, black exits with error code 0.
973             self.invokeBlack([])
974
975     def test_broken_symlink(self) -> None:
976         with cache_dir() as workspace:
977             symlink = workspace / "broken_link.py"
978             try:
979                 symlink.symlink_to("nonexistent.py")
980             except (OSError, NotImplementedError) as e:
981                 self.skipTest(f"Can't create symlinks: {e}")
982             self.invokeBlack([str(workspace.resolve())])
983
984     def test_single_file_force_pyi(self) -> None:
985         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
986         contents, expected = read_data("force_pyi")
987         with cache_dir() as workspace:
988             path = (workspace / "file.py").resolve()
989             with open(path, "w") as fh:
990                 fh.write(contents)
991             self.invokeBlack([str(path), "--pyi"])
992             with open(path, "r") as fh:
993                 actual = fh.read()
994             # verify cache with --pyi is separate
995             pyi_cache = black.read_cache(pyi_mode)
996             self.assertIn(str(path), pyi_cache)
997             normal_cache = black.read_cache(DEFAULT_MODE)
998             self.assertNotIn(str(path), normal_cache)
999         self.assertFormatEqual(expected, actual)
1000         black.assert_equivalent(contents, actual)
1001         black.assert_stable(contents, actual, pyi_mode)
1002
1003     @event_loop()
1004     def test_multi_file_force_pyi(self) -> None:
1005         reg_mode = DEFAULT_MODE
1006         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1007         contents, expected = read_data("force_pyi")
1008         with cache_dir() as workspace:
1009             paths = [
1010                 (workspace / "file1.py").resolve(),
1011                 (workspace / "file2.py").resolve(),
1012             ]
1013             for path in paths:
1014                 with open(path, "w") as fh:
1015                     fh.write(contents)
1016             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1017             for path in paths:
1018                 with open(path, "r") as fh:
1019                     actual = fh.read()
1020                 self.assertEqual(actual, expected)
1021             # verify cache with --pyi is separate
1022             pyi_cache = black.read_cache(pyi_mode)
1023             normal_cache = black.read_cache(reg_mode)
1024             for path in paths:
1025                 self.assertIn(str(path), pyi_cache)
1026                 self.assertNotIn(str(path), normal_cache)
1027
1028     def test_pipe_force_pyi(self) -> None:
1029         source, expected = read_data("force_pyi")
1030         result = CliRunner().invoke(
1031             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1032         )
1033         self.assertEqual(result.exit_code, 0)
1034         actual = result.output
1035         self.assertFormatEqual(actual, expected)
1036
1037     def test_single_file_force_py36(self) -> None:
1038         reg_mode = DEFAULT_MODE
1039         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1040         source, expected = read_data("force_py36")
1041         with cache_dir() as workspace:
1042             path = (workspace / "file.py").resolve()
1043             with open(path, "w") as fh:
1044                 fh.write(source)
1045             self.invokeBlack([str(path), *PY36_ARGS])
1046             with open(path, "r") as fh:
1047                 actual = fh.read()
1048             # verify cache with --target-version is separate
1049             py36_cache = black.read_cache(py36_mode)
1050             self.assertIn(str(path), py36_cache)
1051             normal_cache = black.read_cache(reg_mode)
1052             self.assertNotIn(str(path), normal_cache)
1053         self.assertEqual(actual, expected)
1054
1055     @event_loop()
1056     def test_multi_file_force_py36(self) -> None:
1057         reg_mode = DEFAULT_MODE
1058         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1059         source, expected = read_data("force_py36")
1060         with cache_dir() as workspace:
1061             paths = [
1062                 (workspace / "file1.py").resolve(),
1063                 (workspace / "file2.py").resolve(),
1064             ]
1065             for path in paths:
1066                 with open(path, "w") as fh:
1067                     fh.write(source)
1068             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1069             for path in paths:
1070                 with open(path, "r") as fh:
1071                     actual = fh.read()
1072                 self.assertEqual(actual, expected)
1073             # verify cache with --target-version is separate
1074             pyi_cache = black.read_cache(py36_mode)
1075             normal_cache = black.read_cache(reg_mode)
1076             for path in paths:
1077                 self.assertIn(str(path), pyi_cache)
1078                 self.assertNotIn(str(path), normal_cache)
1079
1080     def test_pipe_force_py36(self) -> None:
1081         source, expected = read_data("force_py36")
1082         result = CliRunner().invoke(
1083             black.main,
1084             ["-", "-q", "--target-version=py36"],
1085             input=BytesIO(source.encode("utf8")),
1086         )
1087         self.assertEqual(result.exit_code, 0)
1088         actual = result.output
1089         self.assertFormatEqual(actual, expected)
1090
1091     @pytest.mark.incompatible_with_mypyc
1092     def test_reformat_one_with_stdin(self) -> None:
1093         with patch(
1094             "black.format_stdin_to_stdout",
1095             return_value=lambda *args, **kwargs: black.Changed.YES,
1096         ) as fsts:
1097             report = MagicMock()
1098             path = Path("-")
1099             black.reformat_one(
1100                 path,
1101                 fast=True,
1102                 write_back=black.WriteBack.YES,
1103                 mode=DEFAULT_MODE,
1104                 report=report,
1105             )
1106             fsts.assert_called_once()
1107             report.done.assert_called_with(path, black.Changed.YES)
1108
1109     @pytest.mark.incompatible_with_mypyc
1110     def test_reformat_one_with_stdin_filename(self) -> None:
1111         with patch(
1112             "black.format_stdin_to_stdout",
1113             return_value=lambda *args, **kwargs: black.Changed.YES,
1114         ) as fsts:
1115             report = MagicMock()
1116             p = "foo.py"
1117             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1118             expected = Path(p)
1119             black.reformat_one(
1120                 path,
1121                 fast=True,
1122                 write_back=black.WriteBack.YES,
1123                 mode=DEFAULT_MODE,
1124                 report=report,
1125             )
1126             fsts.assert_called_once_with(
1127                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1128             )
1129             # __BLACK_STDIN_FILENAME__ should have been stripped
1130             report.done.assert_called_with(expected, black.Changed.YES)
1131
1132     @pytest.mark.incompatible_with_mypyc
1133     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1134         with patch(
1135             "black.format_stdin_to_stdout",
1136             return_value=lambda *args, **kwargs: black.Changed.YES,
1137         ) as fsts:
1138             report = MagicMock()
1139             p = "foo.pyi"
1140             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1141             expected = Path(p)
1142             black.reformat_one(
1143                 path,
1144                 fast=True,
1145                 write_back=black.WriteBack.YES,
1146                 mode=DEFAULT_MODE,
1147                 report=report,
1148             )
1149             fsts.assert_called_once_with(
1150                 fast=True,
1151                 write_back=black.WriteBack.YES,
1152                 mode=replace(DEFAULT_MODE, is_pyi=True),
1153             )
1154             # __BLACK_STDIN_FILENAME__ should have been stripped
1155             report.done.assert_called_with(expected, black.Changed.YES)
1156
1157     @pytest.mark.incompatible_with_mypyc
1158     def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
1159         with patch(
1160             "black.format_stdin_to_stdout",
1161             return_value=lambda *args, **kwargs: black.Changed.YES,
1162         ) as fsts:
1163             report = MagicMock()
1164             p = "foo.ipynb"
1165             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1166             expected = Path(p)
1167             black.reformat_one(
1168                 path,
1169                 fast=True,
1170                 write_back=black.WriteBack.YES,
1171                 mode=DEFAULT_MODE,
1172                 report=report,
1173             )
1174             fsts.assert_called_once_with(
1175                 fast=True,
1176                 write_back=black.WriteBack.YES,
1177                 mode=replace(DEFAULT_MODE, is_ipynb=True),
1178             )
1179             # __BLACK_STDIN_FILENAME__ should have been stripped
1180             report.done.assert_called_with(expected, black.Changed.YES)
1181
1182     @pytest.mark.incompatible_with_mypyc
1183     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1184         with patch(
1185             "black.format_stdin_to_stdout",
1186             return_value=lambda *args, **kwargs: black.Changed.YES,
1187         ) as fsts:
1188             report = MagicMock()
1189             # Even with an existing file, since we are forcing stdin, black
1190             # should output to stdout and not modify the file inplace
1191             p = Path(str(THIS_DIR / "data/collections.py"))
1192             # Make sure is_file actually returns True
1193             self.assertTrue(p.is_file())
1194             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1195             expected = Path(p)
1196             black.reformat_one(
1197                 path,
1198                 fast=True,
1199                 write_back=black.WriteBack.YES,
1200                 mode=DEFAULT_MODE,
1201                 report=report,
1202             )
1203             fsts.assert_called_once()
1204             # __BLACK_STDIN_FILENAME__ should have been stripped
1205             report.done.assert_called_with(expected, black.Changed.YES)
1206
1207     def test_reformat_one_with_stdin_empty(self) -> None:
1208         output = io.StringIO()
1209         with patch("io.TextIOWrapper", lambda *args, **kwargs: output):
1210             try:
1211                 black.format_stdin_to_stdout(
1212                     fast=True,
1213                     content="",
1214                     write_back=black.WriteBack.YES,
1215                     mode=DEFAULT_MODE,
1216                 )
1217             except io.UnsupportedOperation:
1218                 pass  # StringIO does not support detach
1219             assert output.getvalue() == ""
1220
1221     def test_invalid_cli_regex(self) -> None:
1222         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1223             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1224
1225     def test_required_version_matches_version(self) -> None:
1226         self.invokeBlack(
1227             ["--required-version", black.__version__], exit_code=0, ignore_config=True
1228         )
1229
1230     def test_required_version_does_not_match_version(self) -> None:
1231         self.invokeBlack(
1232             ["--required-version", "20.99b"], exit_code=1, ignore_config=True
1233         )
1234
1235     def test_preserves_line_endings(self) -> None:
1236         with TemporaryDirectory() as workspace:
1237             test_file = Path(workspace) / "test.py"
1238             for nl in ["\n", "\r\n"]:
1239                 contents = nl.join(["def f(  ):", "    pass"])
1240                 test_file.write_bytes(contents.encode())
1241                 ff(test_file, write_back=black.WriteBack.YES)
1242                 updated_contents: bytes = test_file.read_bytes()
1243                 self.assertIn(nl.encode(), updated_contents)
1244                 if nl == "\n":
1245                     self.assertNotIn(b"\r\n", updated_contents)
1246
1247     def test_preserves_line_endings_via_stdin(self) -> None:
1248         for nl in ["\n", "\r\n"]:
1249             contents = nl.join(["def f(  ):", "    pass"])
1250             runner = BlackRunner()
1251             result = runner.invoke(
1252                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1253             )
1254             self.assertEqual(result.exit_code, 0)
1255             output = result.stdout_bytes
1256             self.assertIn(nl.encode("utf8"), output)
1257             if nl == "\n":
1258                 self.assertNotIn(b"\r\n", output)
1259
1260     def test_assert_equivalent_different_asts(self) -> None:
1261         with self.assertRaises(AssertionError):
1262             black.assert_equivalent("{}", "None")
1263
1264     def test_shhh_click(self) -> None:
1265         try:
1266             from click import _unicodefun
1267         except ModuleNotFoundError:
1268             self.skipTest("Incompatible Click version")
1269         if not hasattr(_unicodefun, "_verify_python3_env"):
1270             self.skipTest("Incompatible Click version")
1271         # First, let's see if Click is crashing with a preferred ASCII charset.
1272         with patch("locale.getpreferredencoding") as gpe:
1273             gpe.return_value = "ASCII"
1274             with self.assertRaises(RuntimeError):
1275                 _unicodefun._verify_python3_env()  # type: ignore
1276         # Now, let's silence Click...
1277         black.patch_click()
1278         # ...and confirm it's silent.
1279         with patch("locale.getpreferredencoding") as gpe:
1280             gpe.return_value = "ASCII"
1281             try:
1282                 _unicodefun._verify_python3_env()  # type: ignore
1283             except RuntimeError as re:
1284                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1285
1286     def test_root_logger_not_used_directly(self) -> None:
1287         def fail(*args: Any, **kwargs: Any) -> None:
1288             self.fail("Record created with root logger")
1289
1290         with patch.multiple(
1291             logging.root,
1292             debug=fail,
1293             info=fail,
1294             warning=fail,
1295             error=fail,
1296             critical=fail,
1297             log=fail,
1298         ):
1299             ff(THIS_DIR / "util.py")
1300
1301     def test_invalid_config_return_code(self) -> None:
1302         tmp_file = Path(black.dump_to_file())
1303         try:
1304             tmp_config = Path(black.dump_to_file())
1305             tmp_config.unlink()
1306             args = ["--config", str(tmp_config), str(tmp_file)]
1307             self.invokeBlack(args, exit_code=2, ignore_config=False)
1308         finally:
1309             tmp_file.unlink()
1310
1311     def test_parse_pyproject_toml(self) -> None:
1312         test_toml_file = THIS_DIR / "test.toml"
1313         config = black.parse_pyproject_toml(str(test_toml_file))
1314         self.assertEqual(config["verbose"], 1)
1315         self.assertEqual(config["check"], "no")
1316         self.assertEqual(config["diff"], "y")
1317         self.assertEqual(config["color"], True)
1318         self.assertEqual(config["line_length"], 79)
1319         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1320         self.assertEqual(config["exclude"], r"\.pyi?$")
1321         self.assertEqual(config["include"], r"\.py?$")
1322
1323     def test_read_pyproject_toml(self) -> None:
1324         test_toml_file = THIS_DIR / "test.toml"
1325         fake_ctx = FakeContext()
1326         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1327         config = fake_ctx.default_map
1328         self.assertEqual(config["verbose"], "1")
1329         self.assertEqual(config["check"], "no")
1330         self.assertEqual(config["diff"], "y")
1331         self.assertEqual(config["color"], "True")
1332         self.assertEqual(config["line_length"], "79")
1333         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1334         self.assertEqual(config["exclude"], r"\.pyi?$")
1335         self.assertEqual(config["include"], r"\.py?$")
1336
1337     @pytest.mark.incompatible_with_mypyc
1338     def test_find_project_root(self) -> None:
1339         with TemporaryDirectory() as workspace:
1340             root = Path(workspace)
1341             test_dir = root / "test"
1342             test_dir.mkdir()
1343
1344             src_dir = root / "src"
1345             src_dir.mkdir()
1346
1347             root_pyproject = root / "pyproject.toml"
1348             root_pyproject.touch()
1349             src_pyproject = src_dir / "pyproject.toml"
1350             src_pyproject.touch()
1351             src_python = src_dir / "foo.py"
1352             src_python.touch()
1353
1354             self.assertEqual(
1355                 black.find_project_root((src_dir, test_dir)),
1356                 (root.resolve(), "pyproject.toml"),
1357             )
1358             self.assertEqual(
1359                 black.find_project_root((src_dir,)),
1360                 (src_dir.resolve(), "pyproject.toml"),
1361             )
1362             self.assertEqual(
1363                 black.find_project_root((src_python,)),
1364                 (src_dir.resolve(), "pyproject.toml"),
1365             )
1366
1367     @patch(
1368         "black.files.find_user_pyproject_toml",
1369         black.files.find_user_pyproject_toml.__wrapped__,
1370     )
1371     def test_find_user_pyproject_toml_linux(self) -> None:
1372         if system() == "Windows":
1373             return
1374
1375         # Test if XDG_CONFIG_HOME is checked
1376         with TemporaryDirectory() as workspace:
1377             tmp_user_config = Path(workspace) / "black"
1378             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1379                 self.assertEqual(
1380                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
1381                 )
1382
1383         # Test fallback for XDG_CONFIG_HOME
1384         with patch.dict("os.environ"):
1385             os.environ.pop("XDG_CONFIG_HOME", None)
1386             fallback_user_config = Path("~/.config").expanduser() / "black"
1387             self.assertEqual(
1388                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
1389             )
1390
1391     def test_find_user_pyproject_toml_windows(self) -> None:
1392         if system() != "Windows":
1393             return
1394
1395         user_config_path = Path.home() / ".black"
1396         self.assertEqual(
1397             black.files.find_user_pyproject_toml(), user_config_path.resolve()
1398         )
1399
1400     def test_bpo_33660_workaround(self) -> None:
1401         if system() == "Windows":
1402             return
1403
1404         # https://bugs.python.org/issue33660
1405         root = Path("/")
1406         with change_directory(root):
1407             path = Path("workspace") / "project"
1408             report = black.Report(verbose=True)
1409             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
1410             self.assertEqual(normalized_path, "workspace/project")
1411
1412     def test_newline_comment_interaction(self) -> None:
1413         source = "class A:\\\r\n# type: ignore\n pass\n"
1414         output = black.format_str(source, mode=DEFAULT_MODE)
1415         black.assert_stable(source, output, mode=DEFAULT_MODE)
1416
1417     def test_bpo_2142_workaround(self) -> None:
1418
1419         # https://bugs.python.org/issue2142
1420
1421         source, _ = read_data("missing_final_newline.py")
1422         # read_data adds a trailing newline
1423         source = source.rstrip()
1424         expected, _ = read_data("missing_final_newline.diff")
1425         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
1426         diff_header = re.compile(
1427             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
1428             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
1429         )
1430         try:
1431             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
1432             self.assertEqual(result.exit_code, 0)
1433         finally:
1434             os.unlink(tmp_file)
1435         actual = result.output
1436         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
1437         self.assertEqual(actual, expected)
1438
1439     @staticmethod
1440     def compare_results(
1441         result: click.testing.Result, expected_value: str, expected_exit_code: int
1442     ) -> None:
1443         """Helper method to test the value and exit code of a click Result."""
1444         assert (
1445             result.output == expected_value
1446         ), "The output did not match the expected value."
1447         assert result.exit_code == expected_exit_code, "The exit code is incorrect."
1448
1449     def test_code_option(self) -> None:
1450         """Test the code option with no changes."""
1451         code = 'print("Hello world")\n'
1452         args = ["--code", code]
1453         result = CliRunner().invoke(black.main, args)
1454
1455         self.compare_results(result, code, 0)
1456
1457     def test_code_option_changed(self) -> None:
1458         """Test the code option when changes are required."""
1459         code = "print('hello world')"
1460         formatted = black.format_str(code, mode=DEFAULT_MODE)
1461
1462         args = ["--code", code]
1463         result = CliRunner().invoke(black.main, args)
1464
1465         self.compare_results(result, formatted, 0)
1466
1467     def test_code_option_check(self) -> None:
1468         """Test the code option when check is passed."""
1469         args = ["--check", "--code", 'print("Hello world")\n']
1470         result = CliRunner().invoke(black.main, args)
1471         self.compare_results(result, "", 0)
1472
1473     def test_code_option_check_changed(self) -> None:
1474         """Test the code option when changes are required, and check is passed."""
1475         args = ["--check", "--code", "print('hello world')"]
1476         result = CliRunner().invoke(black.main, args)
1477         self.compare_results(result, "", 1)
1478
1479     def test_code_option_diff(self) -> None:
1480         """Test the code option when diff is passed."""
1481         code = "print('hello world')"
1482         formatted = black.format_str(code, mode=DEFAULT_MODE)
1483         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1484
1485         args = ["--diff", "--code", code]
1486         result = CliRunner().invoke(black.main, args)
1487
1488         # Remove time from diff
1489         output = DIFF_TIME.sub("", result.output)
1490
1491         assert output == result_diff, "The output did not match the expected value."
1492         assert result.exit_code == 0, "The exit code is incorrect."
1493
1494     def test_code_option_color_diff(self) -> None:
1495         """Test the code option when color and diff are passed."""
1496         code = "print('hello world')"
1497         formatted = black.format_str(code, mode=DEFAULT_MODE)
1498
1499         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1500         result_diff = color_diff(result_diff)
1501
1502         args = ["--diff", "--color", "--code", code]
1503         result = CliRunner().invoke(black.main, args)
1504
1505         # Remove time from diff
1506         output = DIFF_TIME.sub("", result.output)
1507
1508         assert output == result_diff, "The output did not match the expected value."
1509         assert result.exit_code == 0, "The exit code is incorrect."
1510
1511     @pytest.mark.incompatible_with_mypyc
1512     def test_code_option_safe(self) -> None:
1513         """Test that the code option throws an error when the sanity checks fail."""
1514         # Patch black.assert_equivalent to ensure the sanity checks fail
1515         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1516             code = 'print("Hello world")'
1517             error_msg = f"{code}\nerror: cannot format <string>: \n"
1518
1519             args = ["--safe", "--code", code]
1520             result = CliRunner().invoke(black.main, args)
1521
1522             self.compare_results(result, error_msg, 123)
1523
1524     def test_code_option_fast(self) -> None:
1525         """Test that the code option ignores errors when the sanity checks fail."""
1526         # Patch black.assert_equivalent to ensure the sanity checks fail
1527         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1528             code = 'print("Hello world")'
1529             formatted = black.format_str(code, mode=DEFAULT_MODE)
1530
1531             args = ["--fast", "--code", code]
1532             result = CliRunner().invoke(black.main, args)
1533
1534             self.compare_results(result, formatted, 0)
1535
1536     @pytest.mark.incompatible_with_mypyc
1537     def test_code_option_config(self) -> None:
1538         """
1539         Test that the code option finds the pyproject.toml in the current directory.
1540         """
1541         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1542             args = ["--code", "print"]
1543             # This is the only directory known to contain a pyproject.toml
1544             with change_directory(PROJECT_ROOT):
1545                 CliRunner().invoke(black.main, args)
1546                 pyproject_path = Path(Path.cwd(), "pyproject.toml").resolve()
1547
1548             assert (
1549                 len(parse.mock_calls) >= 1
1550             ), "Expected config parse to be called with the current directory."
1551
1552             _, call_args, _ = parse.mock_calls[0]
1553             assert (
1554                 call_args[0].lower() == str(pyproject_path).lower()
1555             ), "Incorrect config loaded."
1556
1557     @pytest.mark.incompatible_with_mypyc
1558     def test_code_option_parent_config(self) -> None:
1559         """
1560         Test that the code option finds the pyproject.toml in the parent directory.
1561         """
1562         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1563             with change_directory(THIS_DIR):
1564                 args = ["--code", "print"]
1565                 CliRunner().invoke(black.main, args)
1566
1567                 pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
1568                 assert (
1569                     len(parse.mock_calls) >= 1
1570                 ), "Expected config parse to be called with the current directory."
1571
1572                 _, call_args, _ = parse.mock_calls[0]
1573                 assert (
1574                     call_args[0].lower() == str(pyproject_path).lower()
1575                 ), "Incorrect config loaded."
1576
1577     def test_for_handled_unexpected_eof_error(self) -> None:
1578         """
1579         Test that an unexpected EOF SyntaxError is nicely presented.
1580         """
1581         with pytest.raises(black.parsing.InvalidInput) as exc_info:
1582             black.lib2to3_parse("print(", {})
1583
1584         exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")
1585
1586     def test_equivalency_ast_parse_failure_includes_error(self) -> None:
1587         with pytest.raises(AssertionError) as err:
1588             black.assert_equivalent("a«»a  = 1", "a«»a  = 1")
1589
1590         err.match("--safe")
1591         # Unfortunately the SyntaxError message has changed in newer versions so we
1592         # can't match it directly.
1593         err.match("invalid character")
1594         err.match(r"\(<unknown>, line 1\)")
1595
1596
class TestCaching:
    """Tests for black's on-disk formatting cache."""

    def test_cache_broken_file(self) -> None:
        """An unreadable cache file reads as empty and is repopulated afterwards."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            cache_file = get_cache_file(mode)
            cache_file.write_text("this is not a pickle")
            assert black.read_cache(mode) == {}
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            invokeBlack([str(src)])
            cache = black.read_cache(mode)
            assert str(src) in cache

    def test_cache_single_file_already_cached(self) -> None:
        """A file already recorded in the cache is not reformatted."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            black.write_cache({}, [src], mode)
            invokeBlack([str(src)])
            assert src.read_text() == "print('hello')"

    @event_loop()
    def test_cache_multiple_files(self) -> None:
        """Only uncached files are reformatted; all of them end up cached."""
        mode = DEFAULT_MODE
        # Run the worker pool as threads (in-process) instead of subprocesses.
        with cache_dir() as workspace, patch(
            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            one = (workspace / "one.py").resolve()
            one.write_text("print('hello')")
            two = (workspace / "two.py").resolve()
            two.write_text("print('hello')")
            black.write_cache({}, [one], mode)
            invokeBlack([str(workspace)])
            assert one.read_text() == "print('hello')"
            assert two.read_text() == 'print("hello")\n'
            cache = black.read_cache(mode)
            assert str(one) in cache
            assert str(two) in cache

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    def test_no_cache_when_writeback_diff(self, color: bool) -> None:
        """A ``--diff`` run (colored or not) neither reads nor writes the cache."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            with patch("black.read_cache") as read_cache, patch(
                "black.write_cache"
            ) as write_cache:
                cmd = [str(src), "--diff"]
                if color:
                    cmd.append("--color")
                invokeBlack(cmd)
                cache_file = get_cache_file(mode)
                assert not cache_file.exists()
                write_cache.assert_not_called()
                read_cache.assert_not_called()

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    @event_loop()
    def test_output_locking_when_writeback_diff(self, color: bool) -> None:
        """Multi-file ``--diff`` runs create a ``multiprocessing.Manager``."""
        with cache_dir() as workspace:
            for tag in range(4):
                src = (workspace / f"test{tag}.py").resolve()
                src.write_text("print('hello')")
            with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
                cmd = ["--diff", str(workspace)]
                if color:
                    cmd.append("--color")
                invokeBlack(cmd, exit_code=0)
                # this isn't quite doing what we want, but if it _isn't_
                # called then we cannot be using the lock it provides
                mgr.assert_called()

    def test_no_cache_when_stdin(self) -> None:
        """Formatting from stdin does not create a cache file."""
        mode = DEFAULT_MODE
        with cache_dir():
            result = CliRunner().invoke(
                black.main, ["-"], input=BytesIO(b"print('hello')")
            )
            assert not result.exit_code
            cache_file = get_cache_file(mode)
            assert not cache_file.exists()

    def test_read_cache_no_cachefile(self) -> None:
        """Reading a cache that was never written yields an empty mapping."""
        mode = DEFAULT_MODE
        with cache_dir():
            assert black.read_cache(mode) == {}

    def test_write_cache_read_cache(self) -> None:
        """A written entry reads back with the file's current cache info."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.touch()
            black.write_cache({}, [src], mode)
            cache = black.read_cache(mode)
            assert str(src) in cache
            assert cache[str(src)] == black.get_cache_info(src)

    def test_filter_cached(self) -> None:
        """filter_cached splits paths into (todo, done) using the cache info."""
        with TemporaryDirectory() as workspace:
            path = Path(workspace)
            uncached = (path / "uncached").resolve()
            cached = (path / "cached").resolve()
            cached_but_changed = (path / "changed").resolve()
            uncached.touch()
            cached.touch()
            cached_but_changed.touch()
            cache = {
                str(cached): black.get_cache_info(cached),
                # Deliberately stale info, so this file must be treated as changed.
                str(cached_but_changed): (0.0, 0),
            }
            todo, done = black.filter_cached(
                cache, {uncached, cached, cached_but_changed}
            )
            assert todo == {uncached, cached_but_changed}
            assert done == {cached}

    def test_write_cache_creates_directory_if_needed(self) -> None:
        """write_cache creates the cache directory when it does not exist."""
        mode = DEFAULT_MODE
        with cache_dir(exists=False) as workspace:
            assert not workspace.exists()
            black.write_cache({}, [], mode)
            assert workspace.exists()

    @event_loop()
    def test_failed_formatting_does_not_get_cached(self) -> None:
        """Files that fail to format stay out of the cache; clean ones go in."""
        mode = DEFAULT_MODE
        # Run the worker pool as threads (in-process) instead of subprocesses.
        with cache_dir() as workspace, patch(
            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            failing = (workspace / "failing.py").resolve()
            failing.write_text("not actually python")
            clean = (workspace / "clean.py").resolve()
            clean.write_text('print("hello")\n')
            invokeBlack([str(workspace)], exit_code=123)
            cache = black.read_cache(mode)
            assert str(failing) not in cache
            assert str(clean) in cache

    def test_write_cache_write_fail(self) -> None:
        """An OSError while writing the cache must not propagate."""
        mode = DEFAULT_MODE
        with cache_dir(), patch.object(Path, "open") as mock:
            mock.side_effect = OSError
            # Passing means no exception escaped write_cache.
            black.write_cache({}, [], mode)

    def test_read_cache_line_lengths(self) -> None:
        """Cache entries are kept separately per mode (here: line length)."""
        mode = DEFAULT_MODE
        short_mode = replace(DEFAULT_MODE, line_length=1)
        with cache_dir() as workspace:
            path = (workspace / "file.py").resolve()
            path.touch()
            black.write_cache({}, [path], mode)
            one = black.read_cache(mode)
            assert str(path) in one
            two = black.read_cache(short_mode)
            assert str(path) not in two
1763
def assert_collected_sources(
    src: Sequence[Union[str, Path]],
    expected: Sequence[Union[str, Path]],
    *,
    ctx: Optional[FakeContext] = None,
    exclude: Optional[str] = None,
    include: Optional[str] = None,
    extend_exclude: Optional[str] = None,
    force_exclude: Optional[str] = None,
    stdin_filename: Optional[str] = None,
) -> None:
    """Assert that ``black.get_sources`` collects exactly ``expected`` from ``src``.

    Pattern arguments are regex strings compiled with ``compile_pattern``. A
    ``None`` pattern is passed on as ``None``, except ``include``, which falls
    back to ``DEFAULT_INCLUDE``. Ordering of the collected paths is ignored.
    """

    def compiled(pattern: Optional[str]) -> Any:
        # Compile an optional regex string, keeping None as None.
        return None if pattern is None else compile_pattern(pattern)

    sources = tuple(str(Path(s)) for s in src)
    wanted = [Path(s) for s in expected]
    exclude_re = compiled(exclude)
    include_re = DEFAULT_INCLUDE if include is None else compile_pattern(include)
    extend_exclude_re = compiled(extend_exclude)
    force_exclude_re = compiled(force_exclude)

    collected = black.get_sources(
        ctx=ctx or FakeContext(),
        src=sources,
        quiet=False,
        verbose=False,
        include=include_re,
        exclude=exclude_re,
        extend_exclude=extend_exclude_re,
        force_exclude=force_exclude_re,
        report=black.Report(),
        stdin_filename=stdin_filename,
    )
    assert sorted(collected) == sorted(wanted)
1796
1797
class TestFileCollection:
    """Tests for source-file discovery (include/exclude, .gitignore, stdin)."""

    def test_include_exclude(self) -> None:
        path = THIS_DIR / "data" / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        assert_collected_sources(
            src,
            expected,
            include=r"\.pyi?$",
            exclude=r"/exclude/|/\.definitely_exclude/",
        )

    def test_gitignore_used_as_default(self) -> None:
        # No explicit `exclude` is given, so (per this test's name) the root's
        # .gitignore is expected to apply; extend_exclude then drops b/exclude/.
        base = DATA_DIR / "include_exclude_tests"
        expected = [
            base / "b/.definitely_exclude/a.py",
            base / "b/.definitely_exclude/a.pyi",
        ]
        src = [base / "b/"]
        ctx = FakeContext()
        ctx.obj["root"] = base
        assert_collected_sources(src, expected, ctx=ctx, extend_exclude=r"/exclude/")

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_exclude_for_issue_1572(self) -> None:
        # Exclude shouldn't touch files that were explicitly given to Black through the
        # CLI. Exclude is supposed to only apply to the recursive discovery of files.
        # https://github.com/psf/black/issues/1572
        path = DATA_DIR / "include_exclude_tests"
        src = [path / "b/exclude/a.py"]
        expected = [path / "b/exclude/a.py"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    def test_gitignore_exclude(self) -> None:
        path = THIS_DIR / "data" / "include_exclude_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        report = black.Report()
        gitignore = PathSpec.from_lines(
            "gitwildmatch", ["exclude/", ".definitely_exclude"]
        )
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        this_abs = THIS_DIR.resolve()
        sources = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                gitignore,
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_nested_gitignore(self) -> None:
        path = THIS_DIR / "data" / "nested_gitignore_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        root_gitignore = black.files.get_gitignore(path)
        report = black.Report()
        expected: List[Path] = [
            path / "x.py",
            path / "root/b.py",
            path / "root/c.py",
            path / "root/child/c.py",
        ]
        this_abs = THIS_DIR.resolve()
        sources = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                root_gitignore,
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_invalid_gitignore(self) -> None:
        path = THIS_DIR / "data" / "invalid_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_invalid_nested_gitignore(self) -> None:
        path = THIS_DIR / "data" / "invalid_nested_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / "a" / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_empty_include(self) -> None:
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.pie",
            path / "b/exclude/a.py",
            path / "b/exclude/a.pyi",
            path / "b/dont_exclude/a.pie",
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
            path / "b/.definitely_exclude/a.pie",
            path / "b/.definitely_exclude/a.py",
            path / "b/.definitely_exclude/a.pyi",
            path / ".gitignore",
            path / "pyproject.toml",
        ]
        # Setting exclude explicitly to an empty string to block .gitignore usage.
        assert_collected_sources(src, expected, include="", exclude="")

    def test_extend_exclude(self) -> None:
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.py",
            path / "b/dont_exclude/a.py",
        ]
        assert_collected_sources(
            src, expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude"
        )

    @pytest.mark.incompatible_with_mypyc
    def test_symlink_out_of_root_directory(self) -> None:
        """Out-of-root symlinks are tolerated; other out-of-root paths raise."""
        path = MagicMock()
        root = THIS_DIR.resolve()
        child = MagicMock()
        include = re.compile(black.DEFAULT_INCLUDES)
        exclude = re.compile(black.DEFAULT_EXCLUDES)
        report = black.Report()
        gitignore = PathSpec.from_lines("gitwildmatch", [])
        # `child` should behave like a symlink whose resolved path is clearly
        # outside of the `root` directory.
        path.iterdir.return_value = [child]
        child.resolve.return_value = Path("/a/b/c")
        child.as_posix.return_value = "/a/b/c"
        child.is_symlink.return_value = True
        try:
            list(
                black.gen_python_files(
                    path.iterdir(),
                    root,
                    include,
                    exclude,
                    None,
                    None,
                    report,
                    gitignore,
                    verbose=False,
                    quiet=False,
                )
            )
        except ValueError as ve:
            pytest.fail(f"`gen_python_files()` failed: {ve}")
        path.iterdir.assert_called_once()
        child.resolve.assert_called_once()
        child.is_symlink.assert_called_once()
        # `child` should behave like a strange file whose resolved path is clearly
        # outside of the `root` directory.
        child.is_symlink.return_value = False
        with pytest.raises(ValueError):
            list(
                black.gen_python_files(
                    path.iterdir(),
                    root,
                    include,
                    exclude,
                    None,
                    None,
                    report,
                    gitignore,
                    verbose=False,
                    quiet=False,
                )
            )
        path.iterdir.assert_called()
        assert path.iterdir.call_count == 2
        child.resolve.assert_called()
        assert child.resolve.call_count == 2
        child.is_symlink.assert_called()
        assert child.is_symlink.call_count == 2

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin(self) -> None:
        src = ["-"]
        expected = ["-"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename(self) -> None:
        src = ["-"]
        stdin_filename = str(THIS_DIR / "data/collections.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
        # Exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        path = DATA_DIR / "include_exclude_tests"
        src = ["-"]
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
        # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        src = ["-"]
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            extend_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
        # Force exclude should exclude the file when passing it through
        # stdin_filename
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        assert_collected_sources(
            src=["-"],
            expected=[],
            force_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )
2069
2070
# Source lines of black's package module, used by `tracefunc` below to print
# function signatures. When black is compiled (black.COMPILED), its module file
# is presumably a binary extension that cannot be decoded as UTF-8; in that case
# the decode error is ignored and `black_source_lines` stays undefined.
try:
    with open(black.__file__, encoding="utf-8") as _bf:
        black_source_lines = list(_bf)
except UnicodeDecodeError:
    if not black.COMPILED:
        raise
2077
2078
def tracefunc(
    frame: types.FrameType, event: str, arg: Any
) -> Callable[[types.FrameType, str, Any], Any]:
    """Show function calls `from black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    For each "call" event in a frame whose code comes from
    ``black/__init__.py``, print the called line number and the source line of
    the function definition (read from the module-level ``black_source_lines``),
    indented by the current stack depth. Every other event or frame is ignored.
    Always returns itself so tracing continues into nested scopes.
    """
    if event != "call":
        return tracefunc

    filename = frame.f_code.co_filename
    # Normalize path separators so the check also matches Windows paths, and
    # bail out early: line numbers from other files must not be used to index
    # black's source (that can raise IndexError and wastes inspect.stack()).
    if "black/__init__.py" not in Path(filename).as_posix():
        return tracefunc

    # Indentation grows with stack depth; the 19 subtracted frames are
    # presumably the test/trace machinery beneath the code of interest.
    stack = (len(inspect.stack()) - 19) * 2
    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1  # 0-based index into black_source_lines
    funcname = black_source_lines[func_sig_lineno].strip()
    # Skip decorator lines to reach the actual `def` line.
    while funcname.startswith("@"):
        func_sig_lineno += 1
        funcname = black_source_lines[func_sig_lineno].strip()
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc
2100     return tracefunc