git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Deprecate ESP and move the functionality under --preview (#2789)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import inspect
5 import io
6 import logging
7 import multiprocessing
8 import os
9 import sys
10 import types
11 import unittest
12 from concurrent.futures import ThreadPoolExecutor
13 from contextlib import contextmanager
14 from dataclasses import replace
15 from io import BytesIO
16 from pathlib import Path
17 from platform import system
18 from tempfile import TemporaryDirectory
19 from typing import (
20     Any,
21     Callable,
22     Dict,
23     Iterator,
24     List,
25     Optional,
26     Sequence,
27     TypeVar,
28     Union,
29 )
30 from unittest.mock import MagicMock, patch
31
32 import click
33 import pytest
34 import re
35 from click import unstyle
36 from click.testing import CliRunner
37 from pathspec import PathSpec
38
39 import black
40 import black.files
41 from black import Feature, TargetVersion
42 from black import re_compile_maybe_verbose as compile_pattern
43 from black.cache import get_cache_file
44 from black.debug import DebugVisitor
45 from black.output import color_diff, diff
46 from black.report import Report
47
48 # Import other test classes
49 from tests.util import (
50     DATA_DIR,
51     DEFAULT_MODE,
52     DETERMINISTIC_HEADER,
53     PROJECT_ROOT,
54     PY36_VERSIONS,
55     THIS_DIR,
56     BlackBaseTestCase,
57     assert_format,
58     change_directory,
59     dump_to_stderr,
60     ff,
61     fs,
62     read_data,
63 )
64
# This test module's own path.
THIS_FILE = Path(__file__)
# One ``--target-version=<name>`` CLI flag per version in PY36_VERSIONS.
PY36_ARGS = [f"--target-version={version.name.lower()}" for version in PY36_VERSIONS]
# Black's default exclude/include patterns, compiled via
# black.re_compile_maybe_verbose.
DEFAULT_EXCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_EXCLUDES)
DEFAULT_INCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_INCLUDES)
# Generic type variables.
T = TypeVar("T")
R = TypeVar("R")

# Match the time output in a diff, but nothing else: a tab followed by a run of
# digits, spaces and the characters "-", ":", "+" and ".".
DIFF_TIME = re.compile(r"\t[\d\-:+\. ]+")
74
75
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point Black's cache at a throwaway directory for the ``with`` block.

    Patches ``black.cache.CACHE_DIR`` and yields the patched path. When
    *exists* is false, the yielded path is a ``new`` subdirectory of the fresh
    temporary workspace that has not been created on disk.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
84
85
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a brand-new asyncio event loop for the ``with`` block.

    The loop is created through the active event loop policy and registered
    as the current loop. It is closed when the block exits, even on error.
    The closed loop is not unregistered afterwards.
    """
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        fresh_loop.close()
96
97
class FakeContext(click.Context):
    """Minimal click Context stand-in for calling functions that need one.

    ``click.Context.__init__`` is intentionally not invoked; only ``obj`` and
    ``default_map`` are populated.
    """

    def __init__(self) -> None:
        # Dummy root: most tests don't care about it, so use the project root.
        fake_obj: Dict[str, Any] = {"root": PROJECT_ROOT}
        self.obj: Dict[str, Any] = fake_obj
        self.default_map: Dict[str, Any] = {}
105
106
class FakeParameter(click.Parameter):
    """Stand-in click Parameter for calling functions that require one.

    ``__init__`` is overridden to do nothing, so none of click.Parameter's
    constructor logic runs and none of its attributes are set.
    """

    def __init__(self) -> None:
        """Deliberately skip ``click.Parameter.__init__``."""
112
113
class BlackRunner(CliRunner):
    """CliRunner that captures Black's stdout and stderr as separate streams."""

    def __init__(self) -> None:
        # mix_stderr=False stops click from folding stderr into the captured
        # stdout, so tests can inspect each stream on its own.
        CliRunner.__init__(self, mix_stderr=False)
119
120
def invokeBlack(
    args: List[str], exit_code: int = 0, ignore_config: bool = True
) -> None:
    """Run Black's CLI in-process and assert on its exit status.

    Args:
        args: command-line arguments passed to ``black.main``.
        exit_code: the exit status the run is expected to end with.
        ignore_config: when true (the default), prepend ``--verbose`` and
            ``--config <THIS_DIR>/empty.toml``, so an explicit empty config
            file is used.

    Raises:
        AssertionError: if stdout/stderr were not captured, or the exit
            status differs from *exit_code*. The message lists the
            arguments, both captured streams and any exception raised.
    """
    if ignore_config:
        args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
    outcome = BlackRunner().invoke(black.main, args, catch_exceptions=False)
    assert outcome.stdout_bytes is not None
    assert outcome.stderr_bytes is not None
    details = "\n".join(
        [
            f"Failed with args: {args}",
            f"stdout: {outcome.stdout_bytes.decode()!r}",
            f"stderr: {outcome.stderr_bytes.decode()!r}",
            f"exception: {outcome.exception}",
        ]
    )
    assert outcome.exit_code == exit_code, details
137
138
139 class BlackTestCase(BlackBaseTestCase):
140     invokeBlack = staticmethod(invokeBlack)
141
142     def test_empty_ff(self) -> None:
143         expected = ""
144         tmp_file = Path(black.dump_to_file())
145         try:
146             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
147             with open(tmp_file, encoding="utf8") as f:
148                 actual = f.read()
149         finally:
150             os.unlink(tmp_file)
151         self.assertFormatEqual(expected, actual)
152
153     def test_experimental_string_processing_warns(self) -> None:
154         self.assertWarns(
155             black.mode.Deprecated, black.Mode, experimental_string_processing=True
156         )
157
158     def test_piping(self) -> None:
159         source, expected = read_data("src/black/__init__", data=False)
160         result = BlackRunner().invoke(
161             black.main,
162             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
163             input=BytesIO(source.encode("utf8")),
164         )
165         self.assertEqual(result.exit_code, 0)
166         self.assertFormatEqual(expected, result.output)
167         if source != result.output:
168             black.assert_equivalent(source, result.output)
169             black.assert_stable(source, result.output, DEFAULT_MODE)
170
171     def test_piping_diff(self) -> None:
172         diff_header = re.compile(
173             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
174             r"\+\d\d\d\d"
175         )
176         source, _ = read_data("expression.py")
177         expected, _ = read_data("expression.diff")
178         config = THIS_DIR / "data" / "empty_pyproject.toml"
179         args = [
180             "-",
181             "--fast",
182             f"--line-length={black.DEFAULT_LINE_LENGTH}",
183             "--diff",
184             f"--config={config}",
185         ]
186         result = BlackRunner().invoke(
187             black.main, args, input=BytesIO(source.encode("utf8"))
188         )
189         self.assertEqual(result.exit_code, 0)
190         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
191         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
192         self.assertEqual(expected, actual)
193
194     def test_piping_diff_with_color(self) -> None:
195         source, _ = read_data("expression.py")
196         config = THIS_DIR / "data" / "empty_pyproject.toml"
197         args = [
198             "-",
199             "--fast",
200             f"--line-length={black.DEFAULT_LINE_LENGTH}",
201             "--diff",
202             "--color",
203             f"--config={config}",
204         ]
205         result = BlackRunner().invoke(
206             black.main, args, input=BytesIO(source.encode("utf8"))
207         )
208         actual = result.output
209         # Again, the contents are checked in a different test, so only look for colors.
210         self.assertIn("\033[1m", actual)
211         self.assertIn("\033[36m", actual)
212         self.assertIn("\033[32m", actual)
213         self.assertIn("\033[31m", actual)
214         self.assertIn("\033[0m", actual)
215
216     @patch("black.dump_to_file", dump_to_stderr)
217     def _test_wip(self) -> None:
218         source, expected = read_data("wip")
219         sys.settrace(tracefunc)
220         mode = replace(
221             DEFAULT_MODE,
222             experimental_string_processing=False,
223             target_versions={black.TargetVersion.PY38},
224         )
225         actual = fs(source, mode=mode)
226         sys.settrace(None)
227         self.assertFormatEqual(expected, actual)
228         black.assert_equivalent(source, actual)
229         black.assert_stable(source, actual, black.FileMode())
230
231     @unittest.expectedFailure
232     @patch("black.dump_to_file", dump_to_stderr)
233     def test_trailing_comma_optional_parens_stability1(self) -> None:
234         source, _expected = read_data("trailing_comma_optional_parens1")
235         actual = fs(source)
236         black.assert_stable(source, actual, DEFAULT_MODE)
237
238     @unittest.expectedFailure
239     @patch("black.dump_to_file", dump_to_stderr)
240     def test_trailing_comma_optional_parens_stability2(self) -> None:
241         source, _expected = read_data("trailing_comma_optional_parens2")
242         actual = fs(source)
243         black.assert_stable(source, actual, DEFAULT_MODE)
244
245     @unittest.expectedFailure
246     @patch("black.dump_to_file", dump_to_stderr)
247     def test_trailing_comma_optional_parens_stability3(self) -> None:
248         source, _expected = read_data("trailing_comma_optional_parens3")
249         actual = fs(source)
250         black.assert_stable(source, actual, DEFAULT_MODE)
251
252     @patch("black.dump_to_file", dump_to_stderr)
253     def test_trailing_comma_optional_parens_stability1_pass2(self) -> None:
254         source, _expected = read_data("trailing_comma_optional_parens1")
255         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
256         black.assert_stable(source, actual, DEFAULT_MODE)
257
258     @patch("black.dump_to_file", dump_to_stderr)
259     def test_trailing_comma_optional_parens_stability2_pass2(self) -> None:
260         source, _expected = read_data("trailing_comma_optional_parens2")
261         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
262         black.assert_stable(source, actual, DEFAULT_MODE)
263
264     @patch("black.dump_to_file", dump_to_stderr)
265     def test_trailing_comma_optional_parens_stability3_pass2(self) -> None:
266         source, _expected = read_data("trailing_comma_optional_parens3")
267         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
268         black.assert_stable(source, actual, DEFAULT_MODE)
269
270     def test_pep_572_version_detection(self) -> None:
271         source, _ = read_data("pep_572")
272         root = black.lib2to3_parse(source)
273         features = black.get_features_used(root)
274         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
275         versions = black.detect_target_versions(root)
276         self.assertIn(black.TargetVersion.PY38, versions)
277
278     def test_expression_ff(self) -> None:
279         source, expected = read_data("expression")
280         tmp_file = Path(black.dump_to_file(source))
281         try:
282             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
283             with open(tmp_file, encoding="utf8") as f:
284                 actual = f.read()
285         finally:
286             os.unlink(tmp_file)
287         self.assertFormatEqual(expected, actual)
288         with patch("black.dump_to_file", dump_to_stderr):
289             black.assert_equivalent(source, actual)
290             black.assert_stable(source, actual, DEFAULT_MODE)
291
292     def test_expression_diff(self) -> None:
293         source, _ = read_data("expression.py")
294         config = THIS_DIR / "data" / "empty_pyproject.toml"
295         expected, _ = read_data("expression.diff")
296         tmp_file = Path(black.dump_to_file(source))
297         diff_header = re.compile(
298             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
299             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
300         )
301         try:
302             result = BlackRunner().invoke(
303                 black.main, ["--diff", str(tmp_file), f"--config={config}"]
304             )
305             self.assertEqual(result.exit_code, 0)
306         finally:
307             os.unlink(tmp_file)
308         actual = result.output
309         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
310         if expected != actual:
311             dump = black.dump_to_file(actual)
312             msg = (
313                 "Expected diff isn't equal to the actual. If you made changes to"
314                 " expression.py and this is an anticipated difference, overwrite"
315                 f" tests/data/expression.diff with {dump}"
316             )
317             self.assertEqual(expected, actual, msg)
318
319     def test_expression_diff_with_color(self) -> None:
320         source, _ = read_data("expression.py")
321         config = THIS_DIR / "data" / "empty_pyproject.toml"
322         expected, _ = read_data("expression.diff")
323         tmp_file = Path(black.dump_to_file(source))
324         try:
325             result = BlackRunner().invoke(
326                 black.main, ["--diff", "--color", str(tmp_file), f"--config={config}"]
327             )
328         finally:
329             os.unlink(tmp_file)
330         actual = result.output
331         # We check the contents of the diff in `test_expression_diff`. All
332         # we need to check here is that color codes exist in the result.
333         self.assertIn("\033[1m", actual)
334         self.assertIn("\033[36m", actual)
335         self.assertIn("\033[32m", actual)
336         self.assertIn("\033[31m", actual)
337         self.assertIn("\033[0m", actual)
338
339     def test_detect_pos_only_arguments(self) -> None:
340         source, _ = read_data("pep_570")
341         root = black.lib2to3_parse(source)
342         features = black.get_features_used(root)
343         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
344         versions = black.detect_target_versions(root)
345         self.assertIn(black.TargetVersion.PY38, versions)
346
347     @patch("black.dump_to_file", dump_to_stderr)
348     def test_string_quotes(self) -> None:
349         source, expected = read_data("string_quotes")
350         mode = black.Mode(preview=True)
351         assert_format(source, expected, mode)
352         mode = replace(mode, string_normalization=False)
353         not_normalized = fs(source, mode=mode)
354         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
355         black.assert_equivalent(source, not_normalized)
356         black.assert_stable(source, not_normalized, mode=mode)
357
358     def test_skip_magic_trailing_comma(self) -> None:
359         source, _ = read_data("expression.py")
360         expected, _ = read_data("expression_skip_magic_trailing_comma.diff")
361         tmp_file = Path(black.dump_to_file(source))
362         diff_header = re.compile(
363             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
364             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
365         )
366         try:
367             result = BlackRunner().invoke(black.main, ["-C", "--diff", str(tmp_file)])
368             self.assertEqual(result.exit_code, 0)
369         finally:
370             os.unlink(tmp_file)
371         actual = result.output
372         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
373         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
374         if expected != actual:
375             dump = black.dump_to_file(actual)
376             msg = (
377                 "Expected diff isn't equal to the actual. If you made changes to"
378                 " expression.py and this is an anticipated difference, overwrite"
379                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
380             )
381             self.assertEqual(expected, actual, msg)
382
383     @patch("black.dump_to_file", dump_to_stderr)
384     def test_async_as_identifier(self) -> None:
385         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
386         source, expected = read_data("async_as_identifier")
387         actual = fs(source)
388         self.assertFormatEqual(expected, actual)
389         major, minor = sys.version_info[:2]
390         if major < 3 or (major <= 3 and minor < 7):
391             black.assert_equivalent(source, actual)
392         black.assert_stable(source, actual, DEFAULT_MODE)
393         # ensure black can parse this when the target is 3.6
394         self.invokeBlack([str(source_path), "--target-version", "py36"])
395         # but not on 3.7, because async/await is no longer an identifier
396         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
397
398     @patch("black.dump_to_file", dump_to_stderr)
399     def test_python37(self) -> None:
400         source_path = (THIS_DIR / "data" / "python37.py").resolve()
401         source, expected = read_data("python37")
402         actual = fs(source)
403         self.assertFormatEqual(expected, actual)
404         major, minor = sys.version_info[:2]
405         if major > 3 or (major == 3 and minor >= 7):
406             black.assert_equivalent(source, actual)
407         black.assert_stable(source, actual, DEFAULT_MODE)
408         # ensure black can parse this when the target is 3.7
409         self.invokeBlack([str(source_path), "--target-version", "py37"])
410         # but not on 3.6, because we use async as a reserved keyword
411         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
412
413     def test_tab_comment_indentation(self) -> None:
414         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
415         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
416         self.assertFormatEqual(contents_spc, fs(contents_spc))
417         self.assertFormatEqual(contents_spc, fs(contents_tab))
418
419         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
420         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
421         self.assertFormatEqual(contents_spc, fs(contents_spc))
422         self.assertFormatEqual(contents_spc, fs(contents_tab))
423
424         # mixed tabs and spaces (valid Python 2 code)
425         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
426         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
427         self.assertFormatEqual(contents_spc, fs(contents_spc))
428         self.assertFormatEqual(contents_spc, fs(contents_tab))
429
430         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
431         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
432         self.assertFormatEqual(contents_spc, fs(contents_spc))
433         self.assertFormatEqual(contents_spc, fs(contents_tab))
434
    def test_report_verbose(self) -> None:
        """Drive a verbose Report through each outcome and check its output.

        ``black.output._out`` and ``black.output._err`` are patched to collect
        messages, so the test can check the per-file lines emitted as well as
        the running summary (``str(report)``) and ``return_code``.
        """
        report = Report(verbose=True)
        # Messages captured from the patched output functions.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            # Verbose mode logs unchanged and cached files as well as
            # reformatted ones.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled, a reformatted file makes the return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures go to stderr and make the return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are logged in verbose mode but leave the summary
            # unchanged.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
536
    def test_report_quiet(self) -> None:
        """Drive a quiet Report through each outcome and check its output.

        In quiet mode nothing is written to ``_out``; only failures reach
        ``_err``. The summary and ``return_code`` still track every outcome.
        """
        report = Report(quiet=True)
        # Messages captured from the patched output functions.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled, a reformatted file makes the return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures are still written to stderr and make the return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are neither logged nor counted in the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
630
    def test_report_normal(self) -> None:
        """Drive a default (non-verbose, non-quiet) Report through each outcome.

        Only reformatted files are written to ``_out`` and only failures to
        ``_err``. Unchanged, cached and ignored files produce no output lines.
        """
        report = black.Report()
        # Messages captured from the patched output functions.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file adds no output line; the last line is still f2's.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled, a reformatted file makes the return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures go to stderr and make the return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are neither logged nor counted in the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
727
728     def test_lib2to3_parse(self) -> None:
729         with self.assertRaises(black.InvalidInput):
730             black.lib2to3_parse("invalid syntax")
731
732         straddling = "x + y"
733         black.lib2to3_parse(straddling)
734         black.lib2to3_parse(straddling, {TargetVersion.PY36})
735
736         py2_only = "print x"
737         with self.assertRaises(black.InvalidInput):
738             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
739
740         py3_only = "exec(x, end=y)"
741         black.lib2to3_parse(py3_only)
742         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
743
744     def test_get_features_used_decorator(self) -> None:
745         # Test the feature detection of new decorator syntax
746         # since this makes some test cases of test_get_features_used()
747         # fails if it fails, this is tested first so that a useful case
748         # is identified
749         simples, relaxed = read_data("decorators")
750         # skip explanation comments at the top of the file
751         for simple_test in simples.split("##")[1:]:
752             node = black.lib2to3_parse(simple_test)
753             decorator = str(node.children[0].children[0]).strip()
754             self.assertNotIn(
755                 Feature.RELAXED_DECORATORS,
756                 black.get_features_used(node),
757                 msg=(
758                     f"decorator '{decorator}' follows python<=3.8 syntax"
759                     "but is detected as 3.9+"
760                     # f"The full node is\n{node!r}"
761                 ),
762             )
763         # skip the '# output' comment at the top of the output part
764         for relaxed_test in relaxed.split("##")[1:]:
765             node = black.lib2to3_parse(relaxed_test)
766             decorator = str(node.children[0].children[0]).strip()
767             self.assertIn(
768                 Feature.RELAXED_DECORATORS,
769                 black.get_features_used(node),
770                 msg=(
771                     f"decorator '{decorator}' uses python3.9+ syntax"
772                     "but is detected as python<=3.8"
773                     # f"The full node is\n{node!r}"
774                 ),
775             )
776
777     def test_get_features_used(self) -> None:
778         node = black.lib2to3_parse("def f(*, arg): ...\n")
779         self.assertEqual(black.get_features_used(node), set())
780         node = black.lib2to3_parse("def f(*, arg,): ...\n")
781         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
782         node = black.lib2to3_parse("f(*arg,)\n")
783         self.assertEqual(
784             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
785         )
786         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
787         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
788         node = black.lib2to3_parse("123_456\n")
789         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
790         node = black.lib2to3_parse("123456\n")
791         self.assertEqual(black.get_features_used(node), set())
792         source, expected = read_data("function")
793         node = black.lib2to3_parse(source)
794         expected_features = {
795             Feature.TRAILING_COMMA_IN_CALL,
796             Feature.TRAILING_COMMA_IN_DEF,
797             Feature.F_STRINGS,
798         }
799         self.assertEqual(black.get_features_used(node), expected_features)
800         node = black.lib2to3_parse(expected)
801         self.assertEqual(black.get_features_used(node), expected_features)
802         source, expected = read_data("expression")
803         node = black.lib2to3_parse(source)
804         self.assertEqual(black.get_features_used(node), set())
805         node = black.lib2to3_parse(expected)
806         self.assertEqual(black.get_features_used(node), set())
807         node = black.lib2to3_parse("lambda a, /, b: ...")
808         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
809         node = black.lib2to3_parse("def fn(a, /, b): ...")
810         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
811         node = black.lib2to3_parse("def fn(): yield a, b")
812         self.assertEqual(black.get_features_used(node), set())
813         node = black.lib2to3_parse("def fn(): return a, b")
814         self.assertEqual(black.get_features_used(node), set())
815         node = black.lib2to3_parse("def fn(): yield *b, c")
816         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
817         node = black.lib2to3_parse("def fn(): return a, *b, c")
818         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
819         node = black.lib2to3_parse("x = a, *b, c")
820         self.assertEqual(black.get_features_used(node), set())
821         node = black.lib2to3_parse("x: Any = regular")
822         self.assertEqual(black.get_features_used(node), set())
823         node = black.lib2to3_parse("x: Any = (regular, regular)")
824         self.assertEqual(black.get_features_used(node), set())
825         node = black.lib2to3_parse("x: Any = Complex(Type(1))[something]")
826         self.assertEqual(black.get_features_used(node), set())
827         node = black.lib2to3_parse("x: Tuple[int, ...] = a, b, c")
828         self.assertEqual(
829             black.get_features_used(node), {Feature.ANN_ASSIGN_EXTENDED_RHS}
830         )
831
832     def test_get_features_used_for_future_flags(self) -> None:
833         for src, features in [
834             ("from __future__ import annotations", {Feature.FUTURE_ANNOTATIONS}),
835             (
836                 "from __future__ import (other, annotations)",
837                 {Feature.FUTURE_ANNOTATIONS},
838             ),
839             ("a = 1 + 2\nfrom something import annotations", set()),
840             ("from __future__ import x, y", set()),
841         ]:
842             with self.subTest(src=src, features=features):
843                 node = black.lib2to3_parse(src)
844                 future_imports = black.get_future_imports(node)
845                 self.assertEqual(
846                     black.get_features_used(node, future_imports=future_imports),
847                     features,
848                 )
849
850     def test_get_future_imports(self) -> None:
851         node = black.lib2to3_parse("\n")
852         self.assertEqual(set(), black.get_future_imports(node))
853         node = black.lib2to3_parse("from __future__ import black\n")
854         self.assertEqual({"black"}, black.get_future_imports(node))
855         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
856         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
857         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
858         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
859         node = black.lib2to3_parse(
860             "from __future__ import multiple\nfrom __future__ import imports\n"
861         )
862         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
863         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
864         self.assertEqual({"black"}, black.get_future_imports(node))
865         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
866         self.assertEqual({"black"}, black.get_future_imports(node))
867         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
868         self.assertEqual(set(), black.get_future_imports(node))
869         node = black.lib2to3_parse("from some.module import black\n")
870         self.assertEqual(set(), black.get_future_imports(node))
871         node = black.lib2to3_parse(
872             "from __future__ import unicode_literals as _unicode_literals"
873         )
874         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
875         node = black.lib2to3_parse(
876             "from __future__ import unicode_literals as _lol, print"
877         )
878         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
879
880     @pytest.mark.incompatible_with_mypyc
881     def test_debug_visitor(self) -> None:
882         source, _ = read_data("debug_visitor.py")
883         expected, _ = read_data("debug_visitor.out")
884         out_lines = []
885         err_lines = []
886
887         def out(msg: str, **kwargs: Any) -> None:
888             out_lines.append(msg)
889
890         def err(msg: str, **kwargs: Any) -> None:
891             err_lines.append(msg)
892
893         with patch("black.debug.out", out):
894             DebugVisitor.show(source)
895         actual = "\n".join(out_lines) + "\n"
896         log_name = ""
897         if expected != actual:
898             log_name = black.dump_to_file(*out_lines)
899         self.assertEqual(
900             expected,
901             actual,
902             f"AST print out is different. Actual version dumped to {log_name}",
903         )
904
905     def test_format_file_contents(self) -> None:
906         empty = ""
907         mode = DEFAULT_MODE
908         with self.assertRaises(black.NothingChanged):
909             black.format_file_contents(empty, mode=mode, fast=False)
910         just_nl = "\n"
911         with self.assertRaises(black.NothingChanged):
912             black.format_file_contents(just_nl, mode=mode, fast=False)
913         same = "j = [1, 2, 3]\n"
914         with self.assertRaises(black.NothingChanged):
915             black.format_file_contents(same, mode=mode, fast=False)
916         different = "j = [1,2,3]"
917         expected = same
918         actual = black.format_file_contents(different, mode=mode, fast=False)
919         self.assertEqual(expected, actual)
920         invalid = "return if you can"
921         with self.assertRaises(black.InvalidInput) as e:
922             black.format_file_contents(invalid, mode=mode, fast=False)
923         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
924
925     def test_endmarker(self) -> None:
926         n = black.lib2to3_parse("\n")
927         self.assertEqual(n.type, black.syms.file_input)
928         self.assertEqual(len(n.children), 1)
929         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
930
931     @pytest.mark.incompatible_with_mypyc
932     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
933     def test_assertFormatEqual(self) -> None:
934         out_lines = []
935         err_lines = []
936
937         def out(msg: str, **kwargs: Any) -> None:
938             out_lines.append(msg)
939
940         def err(msg: str, **kwargs: Any) -> None:
941             err_lines.append(msg)
942
943         with patch("black.output._out", out), patch("black.output._err", err):
944             with self.assertRaises(AssertionError):
945                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
946
947         out_str = "".join(out_lines)
948         self.assertIn("Expected tree:", out_str)
949         self.assertIn("Actual tree:", out_str)
950         self.assertEqual("".join(err_lines), "")
951
952     @event_loop()
953     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
954     def test_works_in_mono_process_only_environment(self) -> None:
955         with cache_dir() as workspace:
956             for f in [
957                 (workspace / "one.py").resolve(),
958                 (workspace / "two.py").resolve(),
959             ]:
960                 f.write_text('print("hello")\n')
961             self.invokeBlack([str(workspace)])
962
963     @event_loop()
964     def test_check_diff_use_together(self) -> None:
965         with cache_dir():
966             # Files which will be reformatted.
967             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
968             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
969             # Files which will not be reformatted.
970             src2 = (THIS_DIR / "data" / "composition.py").resolve()
971             self.invokeBlack([str(src2), "--diff", "--check"])
972             # Multi file command.
973             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
974
975     def test_no_files(self) -> None:
976         with cache_dir():
977             # Without an argument, black exits with error code 0.
978             self.invokeBlack([])
979
980     def test_broken_symlink(self) -> None:
981         with cache_dir() as workspace:
982             symlink = workspace / "broken_link.py"
983             try:
984                 symlink.symlink_to("nonexistent.py")
985             except (OSError, NotImplementedError) as e:
986                 self.skipTest(f"Can't create symlinks: {e}")
987             self.invokeBlack([str(workspace.resolve())])
988
989     def test_single_file_force_pyi(self) -> None:
990         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
991         contents, expected = read_data("force_pyi")
992         with cache_dir() as workspace:
993             path = (workspace / "file.py").resolve()
994             with open(path, "w") as fh:
995                 fh.write(contents)
996             self.invokeBlack([str(path), "--pyi"])
997             with open(path, "r") as fh:
998                 actual = fh.read()
999             # verify cache with --pyi is separate
1000             pyi_cache = black.read_cache(pyi_mode)
1001             self.assertIn(str(path), pyi_cache)
1002             normal_cache = black.read_cache(DEFAULT_MODE)
1003             self.assertNotIn(str(path), normal_cache)
1004         self.assertFormatEqual(expected, actual)
1005         black.assert_equivalent(contents, actual)
1006         black.assert_stable(contents, actual, pyi_mode)
1007
1008     @event_loop()
1009     def test_multi_file_force_pyi(self) -> None:
1010         reg_mode = DEFAULT_MODE
1011         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1012         contents, expected = read_data("force_pyi")
1013         with cache_dir() as workspace:
1014             paths = [
1015                 (workspace / "file1.py").resolve(),
1016                 (workspace / "file2.py").resolve(),
1017             ]
1018             for path in paths:
1019                 with open(path, "w") as fh:
1020                     fh.write(contents)
1021             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1022             for path in paths:
1023                 with open(path, "r") as fh:
1024                     actual = fh.read()
1025                 self.assertEqual(actual, expected)
1026             # verify cache with --pyi is separate
1027             pyi_cache = black.read_cache(pyi_mode)
1028             normal_cache = black.read_cache(reg_mode)
1029             for path in paths:
1030                 self.assertIn(str(path), pyi_cache)
1031                 self.assertNotIn(str(path), normal_cache)
1032
1033     def test_pipe_force_pyi(self) -> None:
1034         source, expected = read_data("force_pyi")
1035         result = CliRunner().invoke(
1036             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1037         )
1038         self.assertEqual(result.exit_code, 0)
1039         actual = result.output
1040         self.assertFormatEqual(actual, expected)
1041
1042     def test_single_file_force_py36(self) -> None:
1043         reg_mode = DEFAULT_MODE
1044         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1045         source, expected = read_data("force_py36")
1046         with cache_dir() as workspace:
1047             path = (workspace / "file.py").resolve()
1048             with open(path, "w") as fh:
1049                 fh.write(source)
1050             self.invokeBlack([str(path), *PY36_ARGS])
1051             with open(path, "r") as fh:
1052                 actual = fh.read()
1053             # verify cache with --target-version is separate
1054             py36_cache = black.read_cache(py36_mode)
1055             self.assertIn(str(path), py36_cache)
1056             normal_cache = black.read_cache(reg_mode)
1057             self.assertNotIn(str(path), normal_cache)
1058         self.assertEqual(actual, expected)
1059
1060     @event_loop()
1061     def test_multi_file_force_py36(self) -> None:
1062         reg_mode = DEFAULT_MODE
1063         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1064         source, expected = read_data("force_py36")
1065         with cache_dir() as workspace:
1066             paths = [
1067                 (workspace / "file1.py").resolve(),
1068                 (workspace / "file2.py").resolve(),
1069             ]
1070             for path in paths:
1071                 with open(path, "w") as fh:
1072                     fh.write(source)
1073             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1074             for path in paths:
1075                 with open(path, "r") as fh:
1076                     actual = fh.read()
1077                 self.assertEqual(actual, expected)
1078             # verify cache with --target-version is separate
1079             pyi_cache = black.read_cache(py36_mode)
1080             normal_cache = black.read_cache(reg_mode)
1081             for path in paths:
1082                 self.assertIn(str(path), pyi_cache)
1083                 self.assertNotIn(str(path), normal_cache)
1084
1085     def test_pipe_force_py36(self) -> None:
1086         source, expected = read_data("force_py36")
1087         result = CliRunner().invoke(
1088             black.main,
1089             ["-", "-q", "--target-version=py36"],
1090             input=BytesIO(source.encode("utf8")),
1091         )
1092         self.assertEqual(result.exit_code, 0)
1093         actual = result.output
1094         self.assertFormatEqual(actual, expected)
1095
1096     @pytest.mark.incompatible_with_mypyc
1097     def test_reformat_one_with_stdin(self) -> None:
1098         with patch(
1099             "black.format_stdin_to_stdout",
1100             return_value=lambda *args, **kwargs: black.Changed.YES,
1101         ) as fsts:
1102             report = MagicMock()
1103             path = Path("-")
1104             black.reformat_one(
1105                 path,
1106                 fast=True,
1107                 write_back=black.WriteBack.YES,
1108                 mode=DEFAULT_MODE,
1109                 report=report,
1110             )
1111             fsts.assert_called_once()
1112             report.done.assert_called_with(path, black.Changed.YES)
1113
1114     @pytest.mark.incompatible_with_mypyc
1115     def test_reformat_one_with_stdin_filename(self) -> None:
1116         with patch(
1117             "black.format_stdin_to_stdout",
1118             return_value=lambda *args, **kwargs: black.Changed.YES,
1119         ) as fsts:
1120             report = MagicMock()
1121             p = "foo.py"
1122             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1123             expected = Path(p)
1124             black.reformat_one(
1125                 path,
1126                 fast=True,
1127                 write_back=black.WriteBack.YES,
1128                 mode=DEFAULT_MODE,
1129                 report=report,
1130             )
1131             fsts.assert_called_once_with(
1132                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1133             )
1134             # __BLACK_STDIN_FILENAME__ should have been stripped
1135             report.done.assert_called_with(expected, black.Changed.YES)
1136
1137     @pytest.mark.incompatible_with_mypyc
1138     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1139         with patch(
1140             "black.format_stdin_to_stdout",
1141             return_value=lambda *args, **kwargs: black.Changed.YES,
1142         ) as fsts:
1143             report = MagicMock()
1144             p = "foo.pyi"
1145             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1146             expected = Path(p)
1147             black.reformat_one(
1148                 path,
1149                 fast=True,
1150                 write_back=black.WriteBack.YES,
1151                 mode=DEFAULT_MODE,
1152                 report=report,
1153             )
1154             fsts.assert_called_once_with(
1155                 fast=True,
1156                 write_back=black.WriteBack.YES,
1157                 mode=replace(DEFAULT_MODE, is_pyi=True),
1158             )
1159             # __BLACK_STDIN_FILENAME__ should have been stripped
1160             report.done.assert_called_with(expected, black.Changed.YES)
1161
1162     @pytest.mark.incompatible_with_mypyc
1163     def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
1164         with patch(
1165             "black.format_stdin_to_stdout",
1166             return_value=lambda *args, **kwargs: black.Changed.YES,
1167         ) as fsts:
1168             report = MagicMock()
1169             p = "foo.ipynb"
1170             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1171             expected = Path(p)
1172             black.reformat_one(
1173                 path,
1174                 fast=True,
1175                 write_back=black.WriteBack.YES,
1176                 mode=DEFAULT_MODE,
1177                 report=report,
1178             )
1179             fsts.assert_called_once_with(
1180                 fast=True,
1181                 write_back=black.WriteBack.YES,
1182                 mode=replace(DEFAULT_MODE, is_ipynb=True),
1183             )
1184             # __BLACK_STDIN_FILENAME__ should have been stripped
1185             report.done.assert_called_with(expected, black.Changed.YES)
1186
1187     @pytest.mark.incompatible_with_mypyc
1188     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1189         with patch(
1190             "black.format_stdin_to_stdout",
1191             return_value=lambda *args, **kwargs: black.Changed.YES,
1192         ) as fsts:
1193             report = MagicMock()
1194             # Even with an existing file, since we are forcing stdin, black
1195             # should output to stdout and not modify the file inplace
1196             p = Path(str(THIS_DIR / "data/collections.py"))
1197             # Make sure is_file actually returns True
1198             self.assertTrue(p.is_file())
1199             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1200             expected = Path(p)
1201             black.reformat_one(
1202                 path,
1203                 fast=True,
1204                 write_back=black.WriteBack.YES,
1205                 mode=DEFAULT_MODE,
1206                 report=report,
1207             )
1208             fsts.assert_called_once()
1209             # __BLACK_STDIN_FILENAME__ should have been stripped
1210             report.done.assert_called_with(expected, black.Changed.YES)
1211
1212     def test_reformat_one_with_stdin_empty(self) -> None:
1213         output = io.StringIO()
1214         with patch("io.TextIOWrapper", lambda *args, **kwargs: output):
1215             try:
1216                 black.format_stdin_to_stdout(
1217                     fast=True,
1218                     content="",
1219                     write_back=black.WriteBack.YES,
1220                     mode=DEFAULT_MODE,
1221                 )
1222             except io.UnsupportedOperation:
1223                 pass  # StringIO does not support detach
1224             assert output.getvalue() == ""
1225
1226     def test_invalid_cli_regex(self) -> None:
1227         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1228             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1229
1230     def test_required_version_matches_version(self) -> None:
1231         self.invokeBlack(
1232             ["--required-version", black.__version__], exit_code=0, ignore_config=True
1233         )
1234
1235     def test_required_version_does_not_match_version(self) -> None:
1236         self.invokeBlack(
1237             ["--required-version", "20.99b"], exit_code=1, ignore_config=True
1238         )
1239
1240     def test_preserves_line_endings(self) -> None:
1241         with TemporaryDirectory() as workspace:
1242             test_file = Path(workspace) / "test.py"
1243             for nl in ["\n", "\r\n"]:
1244                 contents = nl.join(["def f(  ):", "    pass"])
1245                 test_file.write_bytes(contents.encode())
1246                 ff(test_file, write_back=black.WriteBack.YES)
1247                 updated_contents: bytes = test_file.read_bytes()
1248                 self.assertIn(nl.encode(), updated_contents)
1249                 if nl == "\n":
1250                     self.assertNotIn(b"\r\n", updated_contents)
1251
1252     def test_preserves_line_endings_via_stdin(self) -> None:
1253         for nl in ["\n", "\r\n"]:
1254             contents = nl.join(["def f(  ):", "    pass"])
1255             runner = BlackRunner()
1256             result = runner.invoke(
1257                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1258             )
1259             self.assertEqual(result.exit_code, 0)
1260             output = result.stdout_bytes
1261             self.assertIn(nl.encode("utf8"), output)
1262             if nl == "\n":
1263                 self.assertNotIn(b"\r\n", output)
1264
1265     def test_assert_equivalent_different_asts(self) -> None:
1266         with self.assertRaises(AssertionError):
1267             black.assert_equivalent("{}", "None")
1268
1269     def test_shhh_click(self) -> None:
1270         try:
1271             from click import _unicodefun
1272         except ModuleNotFoundError:
1273             self.skipTest("Incompatible Click version")
1274         if not hasattr(_unicodefun, "_verify_python3_env"):
1275             self.skipTest("Incompatible Click version")
1276         # First, let's see if Click is crashing with a preferred ASCII charset.
1277         with patch("locale.getpreferredencoding") as gpe:
1278             gpe.return_value = "ASCII"
1279             with self.assertRaises(RuntimeError):
1280                 _unicodefun._verify_python3_env()  # type: ignore
1281         # Now, let's silence Click...
1282         black.patch_click()
1283         # ...and confirm it's silent.
1284         with patch("locale.getpreferredencoding") as gpe:
1285             gpe.return_value = "ASCII"
1286             try:
1287                 _unicodefun._verify_python3_env()  # type: ignore
1288             except RuntimeError as re:
1289                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1290
1291     def test_root_logger_not_used_directly(self) -> None:
1292         def fail(*args: Any, **kwargs: Any) -> None:
1293             self.fail("Record created with root logger")
1294
1295         with patch.multiple(
1296             logging.root,
1297             debug=fail,
1298             info=fail,
1299             warning=fail,
1300             error=fail,
1301             critical=fail,
1302             log=fail,
1303         ):
1304             ff(THIS_DIR / "util.py")
1305
1306     def test_invalid_config_return_code(self) -> None:
1307         tmp_file = Path(black.dump_to_file())
1308         try:
1309             tmp_config = Path(black.dump_to_file())
1310             tmp_config.unlink()
1311             args = ["--config", str(tmp_config), str(tmp_file)]
1312             self.invokeBlack(args, exit_code=2, ignore_config=False)
1313         finally:
1314             tmp_file.unlink()
1315
1316     def test_parse_pyproject_toml(self) -> None:
1317         test_toml_file = THIS_DIR / "test.toml"
1318         config = black.parse_pyproject_toml(str(test_toml_file))
1319         self.assertEqual(config["verbose"], 1)
1320         self.assertEqual(config["check"], "no")
1321         self.assertEqual(config["diff"], "y")
1322         self.assertEqual(config["color"], True)
1323         self.assertEqual(config["line_length"], 79)
1324         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1325         self.assertEqual(config["exclude"], r"\.pyi?$")
1326         self.assertEqual(config["include"], r"\.py?$")
1327
1328     def test_read_pyproject_toml(self) -> None:
1329         test_toml_file = THIS_DIR / "test.toml"
1330         fake_ctx = FakeContext()
1331         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1332         config = fake_ctx.default_map
1333         self.assertEqual(config["verbose"], "1")
1334         self.assertEqual(config["check"], "no")
1335         self.assertEqual(config["diff"], "y")
1336         self.assertEqual(config["color"], "True")
1337         self.assertEqual(config["line_length"], "79")
1338         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1339         self.assertEqual(config["exclude"], r"\.pyi?$")
1340         self.assertEqual(config["include"], r"\.py?$")
1341
1342     @pytest.mark.incompatible_with_mypyc
1343     def test_find_project_root(self) -> None:
1344         with TemporaryDirectory() as workspace:
1345             root = Path(workspace)
1346             test_dir = root / "test"
1347             test_dir.mkdir()
1348
1349             src_dir = root / "src"
1350             src_dir.mkdir()
1351
1352             root_pyproject = root / "pyproject.toml"
1353             root_pyproject.touch()
1354             src_pyproject = src_dir / "pyproject.toml"
1355             src_pyproject.touch()
1356             src_python = src_dir / "foo.py"
1357             src_python.touch()
1358
1359             self.assertEqual(
1360                 black.find_project_root((src_dir, test_dir)),
1361                 (root.resolve(), "pyproject.toml"),
1362             )
1363             self.assertEqual(
1364                 black.find_project_root((src_dir,)),
1365                 (src_dir.resolve(), "pyproject.toml"),
1366             )
1367             self.assertEqual(
1368                 black.find_project_root((src_python,)),
1369                 (src_dir.resolve(), "pyproject.toml"),
1370             )
1371
1372     @patch(
1373         "black.files.find_user_pyproject_toml",
1374         black.files.find_user_pyproject_toml.__wrapped__,
1375     )
1376     def test_find_user_pyproject_toml_linux(self) -> None:
1377         if system() == "Windows":
1378             return
1379
1380         # Test if XDG_CONFIG_HOME is checked
1381         with TemporaryDirectory() as workspace:
1382             tmp_user_config = Path(workspace) / "black"
1383             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1384                 self.assertEqual(
1385                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
1386                 )
1387
1388         # Test fallback for XDG_CONFIG_HOME
1389         with patch.dict("os.environ"):
1390             os.environ.pop("XDG_CONFIG_HOME", None)
1391             fallback_user_config = Path("~/.config").expanduser() / "black"
1392             self.assertEqual(
1393                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
1394             )
1395
1396     def test_find_user_pyproject_toml_windows(self) -> None:
1397         if system() != "Windows":
1398             return
1399
1400         user_config_path = Path.home() / ".black"
1401         self.assertEqual(
1402             black.files.find_user_pyproject_toml(), user_config_path.resolve()
1403         )
1404
1405     def test_bpo_33660_workaround(self) -> None:
1406         if system() == "Windows":
1407             return
1408
1409         # https://bugs.python.org/issue33660
1410         root = Path("/")
1411         with change_directory(root):
1412             path = Path("workspace") / "project"
1413             report = black.Report(verbose=True)
1414             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
1415             self.assertEqual(normalized_path, "workspace/project")
1416
1417     def test_newline_comment_interaction(self) -> None:
1418         source = "class A:\\\r\n# type: ignore\n pass\n"
1419         output = black.format_str(source, mode=DEFAULT_MODE)
1420         black.assert_stable(source, output, mode=DEFAULT_MODE)
1421
1422     def test_bpo_2142_workaround(self) -> None:
1423
1424         # https://bugs.python.org/issue2142
1425
1426         source, _ = read_data("missing_final_newline.py")
1427         # read_data adds a trailing newline
1428         source = source.rstrip()
1429         expected, _ = read_data("missing_final_newline.diff")
1430         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
1431         diff_header = re.compile(
1432             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
1433             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
1434         )
1435         try:
1436             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
1437             self.assertEqual(result.exit_code, 0)
1438         finally:
1439             os.unlink(tmp_file)
1440         actual = result.output
1441         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
1442         self.assertEqual(actual, expected)
1443
1444     @staticmethod
1445     def compare_results(
1446         result: click.testing.Result, expected_value: str, expected_exit_code: int
1447     ) -> None:
1448         """Helper method to test the value and exit code of a click Result."""
1449         assert (
1450             result.output == expected_value
1451         ), "The output did not match the expected value."
1452         assert result.exit_code == expected_exit_code, "The exit code is incorrect."
1453
1454     def test_code_option(self) -> None:
1455         """Test the code option with no changes."""
1456         code = 'print("Hello world")\n'
1457         args = ["--code", code]
1458         result = CliRunner().invoke(black.main, args)
1459
1460         self.compare_results(result, code, 0)
1461
1462     def test_code_option_changed(self) -> None:
1463         """Test the code option when changes are required."""
1464         code = "print('hello world')"
1465         formatted = black.format_str(code, mode=DEFAULT_MODE)
1466
1467         args = ["--code", code]
1468         result = CliRunner().invoke(black.main, args)
1469
1470         self.compare_results(result, formatted, 0)
1471
1472     def test_code_option_check(self) -> None:
1473         """Test the code option when check is passed."""
1474         args = ["--check", "--code", 'print("Hello world")\n']
1475         result = CliRunner().invoke(black.main, args)
1476         self.compare_results(result, "", 0)
1477
1478     def test_code_option_check_changed(self) -> None:
1479         """Test the code option when changes are required, and check is passed."""
1480         args = ["--check", "--code", "print('hello world')"]
1481         result = CliRunner().invoke(black.main, args)
1482         self.compare_results(result, "", 1)
1483
1484     def test_code_option_diff(self) -> None:
1485         """Test the code option when diff is passed."""
1486         code = "print('hello world')"
1487         formatted = black.format_str(code, mode=DEFAULT_MODE)
1488         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1489
1490         args = ["--diff", "--code", code]
1491         result = CliRunner().invoke(black.main, args)
1492
1493         # Remove time from diff
1494         output = DIFF_TIME.sub("", result.output)
1495
1496         assert output == result_diff, "The output did not match the expected value."
1497         assert result.exit_code == 0, "The exit code is incorrect."
1498
1499     def test_code_option_color_diff(self) -> None:
1500         """Test the code option when color and diff are passed."""
1501         code = "print('hello world')"
1502         formatted = black.format_str(code, mode=DEFAULT_MODE)
1503
1504         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1505         result_diff = color_diff(result_diff)
1506
1507         args = ["--diff", "--color", "--code", code]
1508         result = CliRunner().invoke(black.main, args)
1509
1510         # Remove time from diff
1511         output = DIFF_TIME.sub("", result.output)
1512
1513         assert output == result_diff, "The output did not match the expected value."
1514         assert result.exit_code == 0, "The exit code is incorrect."
1515
1516     @pytest.mark.incompatible_with_mypyc
1517     def test_code_option_safe(self) -> None:
1518         """Test that the code option throws an error when the sanity checks fail."""
1519         # Patch black.assert_equivalent to ensure the sanity checks fail
1520         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1521             code = 'print("Hello world")'
1522             error_msg = f"{code}\nerror: cannot format <string>: \n"
1523
1524             args = ["--safe", "--code", code]
1525             result = CliRunner().invoke(black.main, args)
1526
1527             self.compare_results(result, error_msg, 123)
1528
1529     def test_code_option_fast(self) -> None:
1530         """Test that the code option ignores errors when the sanity checks fail."""
1531         # Patch black.assert_equivalent to ensure the sanity checks fail
1532         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1533             code = 'print("Hello world")'
1534             formatted = black.format_str(code, mode=DEFAULT_MODE)
1535
1536             args = ["--fast", "--code", code]
1537             result = CliRunner().invoke(black.main, args)
1538
1539             self.compare_results(result, formatted, 0)
1540
1541     @pytest.mark.incompatible_with_mypyc
1542     def test_code_option_config(self) -> None:
1543         """
1544         Test that the code option finds the pyproject.toml in the current directory.
1545         """
1546         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1547             args = ["--code", "print"]
1548             # This is the only directory known to contain a pyproject.toml
1549             with change_directory(PROJECT_ROOT):
1550                 CliRunner().invoke(black.main, args)
1551                 pyproject_path = Path(Path.cwd(), "pyproject.toml").resolve()
1552
1553             assert (
1554                 len(parse.mock_calls) >= 1
1555             ), "Expected config parse to be called with the current directory."
1556
1557             _, call_args, _ = parse.mock_calls[0]
1558             assert (
1559                 call_args[0].lower() == str(pyproject_path).lower()
1560             ), "Incorrect config loaded."
1561
1562     @pytest.mark.incompatible_with_mypyc
1563     def test_code_option_parent_config(self) -> None:
1564         """
1565         Test that the code option finds the pyproject.toml in the parent directory.
1566         """
1567         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1568             with change_directory(THIS_DIR):
1569                 args = ["--code", "print"]
1570                 CliRunner().invoke(black.main, args)
1571
1572                 pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
1573                 assert (
1574                     len(parse.mock_calls) >= 1
1575                 ), "Expected config parse to be called with the current directory."
1576
1577                 _, call_args, _ = parse.mock_calls[0]
1578                 assert (
1579                     call_args[0].lower() == str(pyproject_path).lower()
1580                 ), "Incorrect config loaded."
1581
1582     def test_for_handled_unexpected_eof_error(self) -> None:
1583         """
1584         Test that an unexpected EOF SyntaxError is nicely presented.
1585         """
1586         with pytest.raises(black.parsing.InvalidInput) as exc_info:
1587             black.lib2to3_parse("print(", {})
1588
1589         exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")
1590
1591     def test_equivalency_ast_parse_failure_includes_error(self) -> None:
1592         with pytest.raises(AssertionError) as err:
1593             black.assert_equivalent("a«»a  = 1", "a«»a  = 1")
1594
1595         err.match("--safe")
1596         # Unfortunately the SyntaxError message has changed in newer versions so we
1597         # can't match it directly.
1598         err.match("invalid character")
1599         err.match(r"\(<unknown>, line 1\)")
1600
1601
class TestCaching:
    """Tests for black's on-disk cache of already-formatted files."""

    def test_cache_broken_file(self) -> None:
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            # An unreadable cache file is treated as an empty cache ...
            get_cache_file(mode).write_text("this is not a pickle")
            assert black.read_cache(mode) == {}
            # ... and is repopulated once a file has been formatted.
            target = (workspace / "test.py").resolve()
            target.write_text("print('hello')")
            invokeBlack([str(target)])
            assert str(target) in black.read_cache(mode)

    def test_cache_single_file_already_cached(self) -> None:
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            target = (workspace / "test.py").resolve()
            target.write_text("print('hello')")
            # Register the unformatted file as cached so black leaves it alone.
            black.write_cache({}, [target], mode)
            invokeBlack([str(target)])
            assert target.read_text() == "print('hello')"

    @event_loop()
    def test_cache_multiple_files(self) -> None:
        mode = DEFAULT_MODE
        with cache_dir() as workspace, patch(
            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            cached = (workspace / "one.py").resolve()
            cached.write_text("print('hello')")
            fresh = (workspace / "two.py").resolve()
            fresh.write_text("print('hello')")
            black.write_cache({}, [cached], mode)
            invokeBlack([str(workspace)])
            # Only the file that was not in the cache is reformatted ...
            assert cached.read_text() == "print('hello')"
            assert fresh.read_text() == 'print("hello")\n'
            # ... and afterwards both files are recorded in the cache.
            cache = black.read_cache(mode)
            assert str(cached) in cache
            assert str(fresh) in cache

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    def test_no_cache_when_writeback_diff(self, color: bool) -> None:
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            target = (workspace / "test.py").resolve()
            target.write_text("print('hello')")
            with patch("black.read_cache") as read_cache, patch(
                "black.write_cache"
            ) as write_cache:
                cmd = [str(target), "--diff"] + (["--color"] if color else [])
                invokeBlack(cmd)
                # --diff must neither read nor write the cache.
                assert not get_cache_file(mode).exists()
                write_cache.assert_not_called()
                read_cache.assert_not_called()

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    @event_loop()
    def test_output_locking_when_writeback_diff(self, color: bool) -> None:
        with cache_dir() as workspace:
            for index in range(4):
                (workspace / f"test{index}.py").resolve().write_text("print('hello')")
            with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
                cmd = ["--diff", str(workspace)]
                if color:
                    cmd.append("--color")
                invokeBlack(cmd, exit_code=0)
                # this isn't quite doing what we want, but if it _isn't_
                # called then we cannot be using the lock it provides
                mgr.assert_called()

    def test_no_cache_when_stdin(self) -> None:
        mode = DEFAULT_MODE
        with cache_dir():
            stdin = BytesIO(b"print('hello')")
            result = CliRunner().invoke(black.main, ["-"], input=stdin)
            assert not result.exit_code
            # Formatting stdin must not create a cache file.
            assert not get_cache_file(mode).exists()

    def test_read_cache_no_cachefile(self) -> None:
        with cache_dir():
            assert black.read_cache(DEFAULT_MODE) == {}

    def test_write_cache_read_cache(self) -> None:
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            target = (workspace / "test.py").resolve()
            target.touch()
            black.write_cache({}, [target], mode)
            cache = black.read_cache(mode)
            key = str(target)
            assert key in cache
            assert cache[key] == black.get_cache_info(target)

    def test_filter_cached(self) -> None:
        with TemporaryDirectory() as workspace:
            base = Path(workspace)
            uncached, cached, cached_but_changed = (
                (base / name).resolve() for name in ("uncached", "cached", "changed")
            )
            for file in (uncached, cached, cached_but_changed):
                file.touch()
            cache = {
                str(cached): black.get_cache_info(cached),
                # Dummy entry that does not match the file's real cache info.
                str(cached_but_changed): (0.0, 0),
            }
            todo, done = black.filter_cached(
                cache, {uncached, cached, cached_but_changed}
            )
            assert todo == {uncached, cached_but_changed}
            assert done == {cached}

    def test_write_cache_creates_directory_if_needed(self) -> None:
        with cache_dir(exists=False) as workspace:
            assert not workspace.exists()
            black.write_cache({}, [], DEFAULT_MODE)
            assert workspace.exists()

    @event_loop()
    def test_failed_formatting_does_not_get_cached(self) -> None:
        mode = DEFAULT_MODE
        with cache_dir() as workspace, patch(
            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            failing = (workspace / "failing.py").resolve()
            failing.write_text("not actually python")
            clean = (workspace / "clean.py").resolve()
            clean.write_text('print("hello")\n')
            invokeBlack([str(workspace)], exit_code=123)
            cache = black.read_cache(mode)
            assert str(failing) not in cache
            assert str(clean) in cache

    def test_write_cache_write_fail(self) -> None:
        # Every Path.open call raises OSError. The test passes only if
        # write_cache does not let that error escape.
        with cache_dir(), patch.object(Path, "open", side_effect=OSError):
            black.write_cache({}, [], DEFAULT_MODE)

    def test_read_cache_line_lengths(self) -> None:
        mode = DEFAULT_MODE
        short_mode = replace(DEFAULT_MODE, line_length=1)
        with cache_dir() as workspace:
            path = (workspace / "file.py").resolve()
            path.touch()
            black.write_cache({}, [path], mode)
            # The cache is per mode, so an entry written under the default mode
            # is not visible under a mode with a different line length.
            assert str(path) in black.read_cache(mode)
            assert str(path) not in black.read_cache(short_mode)
1767
1768
def assert_collected_sources(
    src: Sequence[Union[str, Path]],
    expected: Sequence[Union[str, Path]],
    *,
    ctx: Optional[FakeContext] = None,
    exclude: Optional[str] = None,
    include: Optional[str] = None,
    extend_exclude: Optional[str] = None,
    force_exclude: Optional[str] = None,
    stdin_filename: Optional[str] = None,
) -> None:
    """Assert that `black.get_sources` collects exactly *expected* from *src*.

    Regex options left as ``None`` are passed through as ``None``, except
    *include*, which falls back to ``DEFAULT_INCLUDE``. Results are compared
    order-independently.
    """

    def _maybe_compile(pattern: Optional[str]) -> Any:
        # None means "option not given"; any string, even "", becomes a regex.
        return None if pattern is None else compile_pattern(pattern)

    src_strings = tuple(str(Path(item)) for item in src)
    expected_paths = [Path(item) for item in expected]
    exclude_re = _maybe_compile(exclude)
    include_re = compile_pattern(include) if include is not None else DEFAULT_INCLUDE
    extend_exclude_re = _maybe_compile(extend_exclude)
    force_exclude_re = _maybe_compile(force_exclude)

    collected = black.get_sources(
        ctx=ctx or FakeContext(),
        src=src_strings,
        quiet=False,
        verbose=False,
        include=include_re,
        exclude=exclude_re,
        extend_exclude=extend_exclude_re,
        force_exclude=force_exclude_re,
        report=black.Report(),
        stdin_filename=stdin_filename,
    )
    assert sorted(collected) == sorted(expected_paths)
1801
1802
class TestFileCollection:
    """Tests for source discovery: include/exclude regexes, .gitignore, stdin."""

    def test_include_exclude(self) -> None:
        path = THIS_DIR / "data" / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        assert_collected_sources(
            src,
            expected,
            include=r"\.pyi?$",
            exclude=r"/exclude/|/\.definitely_exclude/",
        )

    def test_gitignore_used_as_default(self) -> None:
        base = DATA_DIR / "include_exclude_tests"
        expected = [
            base / "b/.definitely_exclude/a.py",
            base / "b/.definitely_exclude/a.pyi",
        ]
        src = [base / "b/"]
        ctx = FakeContext()
        ctx.obj["root"] = base
        assert_collected_sources(src, expected, ctx=ctx, extend_exclude=r"/exclude/")

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_exclude_for_issue_1572(self) -> None:
        # Exclude shouldn't touch files that were explicitly given to Black through the
        # CLI. Exclude is supposed to only apply to the recursive discovery of files.
        # https://github.com/psf/black/issues/1572
        path = DATA_DIR / "include_exclude_tests"
        src = [path / "b/exclude/a.py"]
        expected = [path / "b/exclude/a.py"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    def test_gitignore_exclude(self) -> None:
        path = THIS_DIR / "data" / "include_exclude_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        report = black.Report()
        gitignore = PathSpec.from_lines(
            "gitwildmatch", ["exclude/", ".definitely_exclude"]
        )
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        this_abs = THIS_DIR.resolve()
        sources = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                gitignore,
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_nested_gitignore(self) -> None:
        path = THIS_DIR / "data" / "nested_gitignore_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        root_gitignore = black.files.get_gitignore(path)
        report = black.Report()
        expected: List[Path] = [
            path / "x.py",
            path / "root/b.py",
            path / "root/c.py",
            path / "root/child/c.py",
        ]
        this_abs = THIS_DIR.resolve()
        sources = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                root_gitignore,
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_invalid_gitignore(self) -> None:
        path = THIS_DIR / "data" / "invalid_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_invalid_nested_gitignore(self) -> None:
        path = THIS_DIR / "data" / "invalid_nested_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / "a" / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_empty_include(self) -> None:
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.pie",
            path / "b/exclude/a.py",
            path / "b/exclude/a.pyi",
            path / "b/dont_exclude/a.pie",
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
            path / "b/.definitely_exclude/a.pie",
            path / "b/.definitely_exclude/a.py",
            path / "b/.definitely_exclude/a.pyi",
            path / ".gitignore",
            path / "pyproject.toml",
        ]
        # Setting exclude explicitly to an empty string to block .gitignore usage.
        assert_collected_sources(src, expected, include="", exclude="")

    def test_extend_exclude(self) -> None:
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.py",
            path / "b/dont_exclude/a.py",
        ]
        assert_collected_sources(
            src, expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude"
        )

    @pytest.mark.incompatible_with_mypyc
    def test_symlink_out_of_root_directory(self) -> None:
        path = MagicMock()
        root = THIS_DIR.resolve()
        child = MagicMock()
        include = re.compile(black.DEFAULT_INCLUDES)
        exclude = re.compile(black.DEFAULT_EXCLUDES)
        report = black.Report()
        gitignore = PathSpec.from_lines("gitwildmatch", [])
        # `child` should behave like a symlink which resolved path is clearly
        # outside of the `root` directory.
        path.iterdir.return_value = [child]
        child.resolve.return_value = Path("/a/b/c")
        child.as_posix.return_value = "/a/b/c"
        child.is_symlink.return_value = True
        try:
            list(
                black.gen_python_files(
                    path.iterdir(),
                    root,
                    include,
                    exclude,
                    None,
                    None,
                    report,
                    gitignore,
                    verbose=False,
                    quiet=False,
                )
            )
        except ValueError as ve:
            # Name the function actually under test in the failure message.
            pytest.fail(f"`gen_python_files()` failed: {ve}")
        path.iterdir.assert_called_once()
        child.resolve.assert_called_once()
        child.is_symlink.assert_called_once()
        # `child` should behave like a strange file which resolved path is clearly
        # outside of the `root` directory.
        child.is_symlink.return_value = False
        with pytest.raises(ValueError):
            list(
                black.gen_python_files(
                    path.iterdir(),
                    root,
                    include,
                    exclude,
                    None,
                    None,
                    report,
                    gitignore,
                    verbose=False,
                    quiet=False,
                )
            )
        path.iterdir.assert_called()
        assert path.iterdir.call_count == 2
        child.resolve.assert_called()
        assert child.resolve.call_count == 2
        child.is_symlink.assert_called()
        assert child.is_symlink.call_count == 2

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin(self) -> None:
        src = ["-"]
        expected = ["-"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename(self) -> None:
        src = ["-"]
        stdin_filename = str(THIS_DIR / "data/collections.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
        # Exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        path = DATA_DIR / "include_exclude_tests"
        src = ["-"]
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
        # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        src = ["-"]
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            extend_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
        # Force exclude should exclude the file when passing it through
        # stdin_filename
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        assert_collected_sources(
            src=["-"],
            expected=[],
            force_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )
2074
2075
# Source lines of black/__init__.py, used by `tracefunc` to show function names.
# The name starts out bound to an empty list so it always exists. Otherwise the
# compiled-build fallback below would leave it undefined, and any later use
# would raise NameError.
black_source_lines: List[str] = []
try:
    with open(black.__file__, "r", encoding="utf-8") as _bf:
        black_source_lines = _bf.readlines()
except UnicodeDecodeError:
    # Tolerated only for compiled builds, where black.__file__ presumably
    # points at a binary extension rather than Python source.
    if not black.COMPILED:
        raise
2082
2083
def tracefunc(
    frame: types.FrameType, event: str, arg: Any
) -> Callable[[types.FrameType, str, Any], Any]:
    """Show function calls `from black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    Only "call" events whose frame comes from ``black/__init__.py`` are
    printed, as ``<indent><lineno>:<source line>``. The function always returns
    itself so tracing continues.
    """
    if event != "call":
        return tracefunc

    # Filter on the originating file *before* touching `black_source_lines`.
    # Line numbers from other files are unrelated to black's source and may be
    # out of range. This check also avoids the costly `inspect.stack()` call for
    # every traced call. `as_posix` lets the check work with Windows separators.
    filename = Path(frame.f_code.co_filename).as_posix()
    if "black/__init__.py" not in filename:
        return tracefunc

    # Indentation grows with stack depth; 19 presumably offsets the frames of
    # the test harness itself (empirical constant).
    stack = (len(inspect.stack()) - 19) * 2
    lineno = frame.f_lineno

    # Skip decorator lines to reach the `def` line. Stay within the bounds of
    # `black_source_lines` (it may be empty, e.g. for compiled builds).
    funcname = "<unknown>"
    func_sig_lineno = lineno - 1
    while 0 <= func_sig_lineno < len(black_source_lines):
        candidate = black_source_lines[func_sig_lineno].strip()
        if not candidate.startswith("@"):
            funcname = candidate
            break
        func_sig_lineno += 1

    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc