git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Update link pointing to how-black-wraps-lines (#1925)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import multiprocessing
3 import asyncio
4 import logging
5 from concurrent.futures import ThreadPoolExecutor
6 from contextlib import contextmanager
7 from dataclasses import replace
8 import inspect
9 from io import BytesIO, TextIOWrapper
10 import os
11 from pathlib import Path
12 from platform import system
13 import regex as re
14 import sys
15 from tempfile import TemporaryDirectory
16 import types
17 from typing import (
18     Any,
19     BinaryIO,
20     Callable,
21     Dict,
22     Generator,
23     List,
24     Iterator,
25     TypeVar,
26 )
27 import unittest
28 from unittest.mock import patch, MagicMock
29
30 import click
31 from click import unstyle
32 from click.testing import CliRunner
33
34 import black
35 from black import Feature, TargetVersion
36
37 from pathspec import PathSpec
38
39 # Import other test classes
40 from tests.util import (
41     THIS_DIR,
42     read_data,
43     DETERMINISTIC_HEADER,
44     BlackBaseTestCase,
45     DEFAULT_MODE,
46     fs,
47     ff,
48     dump_to_stderr,
49 )
50 from .test_primer import PrimerCLITests  # noqa: F401
51
52
THIS_FILE = Path(__file__)
# Target versions exercised by the "Python 3.6+" tests.
PY36_VERSIONS = {
    TargetVersion.PY36,
    TargetVersion.PY37,
    TargetVersion.PY38,
    TargetVersion.PY39,
}
# `--target-version` flags for PY36_VERSIONS. Iteration order of a set of enum
# members may vary between interpreter runs (string hash randomization), so
# the flags are sorted by member name to keep test invocations reproducible.
PY36_ARGS = [
    f"--target-version={version.name.lower()}"
    for version in sorted(PY36_VERSIONS, key=lambda version: version.name)
]
T = TypeVar("T")
R = TypeVar("R")
63
64
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch ``black.CACHE_DIR`` to a throwaway location for the ``with`` body.

    The location lives inside a fresh temporary directory. With ``exists=False``
    it is a ``new`` subdirectory that has not been created. The patched path is
    yielded, and the temporary directory is removed on exit.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
73
74
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a brand-new asyncio loop as the current loop for the ``with`` body.

    The loop is created by the active event-loop policy. It is closed on exit,
    even when the body raises.
    """
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        fresh_loop.close()
85
86
class FakeContext(click.Context):
    """Stand-in click Context for calling functions that merely require one.

    ``click.Context.__init__`` is deliberately not called; the only state set
    is an empty ``default_map``.
    """

    def __init__(self) -> None:
        # No configured defaults.
        self.default_map: Dict[str, Any] = dict()
92
93
class FakeParameter(click.Parameter):
    """Stand-in click Parameter for calling functions that merely require one."""

    def __init__(self) -> None:
        """Intentionally skip ``click.Parameter.__init__``; no state is set."""
99
100
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    During ``isolation`` stderr is redirected into ``self.stderrbuf``. When the
    isolated block finishes, the raw bytes written to stdout and stderr are
    stored in ``self.stdout_bytes`` and ``self.stderr_bytes``.

    This is a hack that can be removed once we depend on Click 7.x"""

    def __init__(self) -> None:
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        with super().isolation(*args, **kwargs) as output:
            # Do the redirection before `try`. If it fails, `finally` must not
            # try to read `.buffer.getvalue()` from the real stderr.
            hold_stderr = sys.stderr
            sys.stderr = TextIOWrapper(self.stderrbuf, encoding=self.charset)
            try:
                yield output
            finally:
                try:
                    # Push any text still buffered in the wrappers down into
                    # the underlying BytesIO objects before reading them.
                    sys.stdout.flush()
                    sys.stderr.flush()
                    self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                    self.stderr_bytes = sys.stderr.buffer.getvalue()  # type: ignore
                finally:
                    # Always restore the caller's stderr.
                    sys.stderr = hold_stderr
124
125
126 class BlackTestCase(BlackBaseTestCase):
127     def invokeBlack(
128         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
129     ) -> None:
130         runner = BlackRunner()
131         if ignore_config:
132             args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
133         result = runner.invoke(black.main, args)
134         self.assertEqual(
135             result.exit_code,
136             exit_code,
137             msg=(
138                 f"Failed with args: {args}\n"
139                 f"stdout: {runner.stdout_bytes.decode()!r}\n"
140                 f"stderr: {runner.stderr_bytes.decode()!r}\n"
141                 f"exception: {result.exception}"
142             ),
143         )
144
145     @patch("black.dump_to_file", dump_to_stderr)
146     def test_empty(self) -> None:
147         source = expected = ""
148         actual = fs(source)
149         self.assertFormatEqual(expected, actual)
150         black.assert_equivalent(source, actual)
151         black.assert_stable(source, actual, DEFAULT_MODE)
152
153     def test_empty_ff(self) -> None:
154         expected = ""
155         tmp_file = Path(black.dump_to_file())
156         try:
157             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
158             with open(tmp_file, encoding="utf8") as f:
159                 actual = f.read()
160         finally:
161             os.unlink(tmp_file)
162         self.assertFormatEqual(expected, actual)
163
164     def test_piping(self) -> None:
165         source, expected = read_data("src/black/__init__", data=False)
166         result = BlackRunner().invoke(
167             black.main,
168             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
169             input=BytesIO(source.encode("utf8")),
170         )
171         self.assertEqual(result.exit_code, 0)
172         self.assertFormatEqual(expected, result.output)
173         black.assert_equivalent(source, result.output)
174         black.assert_stable(source, result.output, DEFAULT_MODE)
175
176     def test_piping_diff(self) -> None:
177         diff_header = re.compile(
178             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
179             r"\+\d\d\d\d"
180         )
181         source, _ = read_data("expression.py")
182         expected, _ = read_data("expression.diff")
183         config = THIS_DIR / "data" / "empty_pyproject.toml"
184         args = [
185             "-",
186             "--fast",
187             f"--line-length={black.DEFAULT_LINE_LENGTH}",
188             "--diff",
189             f"--config={config}",
190         ]
191         result = BlackRunner().invoke(
192             black.main, args, input=BytesIO(source.encode("utf8"))
193         )
194         self.assertEqual(result.exit_code, 0)
195         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
196         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
197         self.assertEqual(expected, actual)
198
199     def test_piping_diff_with_color(self) -> None:
200         source, _ = read_data("expression.py")
201         config = THIS_DIR / "data" / "empty_pyproject.toml"
202         args = [
203             "-",
204             "--fast",
205             f"--line-length={black.DEFAULT_LINE_LENGTH}",
206             "--diff",
207             "--color",
208             f"--config={config}",
209         ]
210         result = BlackRunner().invoke(
211             black.main, args, input=BytesIO(source.encode("utf8"))
212         )
213         actual = result.output
214         # Again, the contents are checked in a different test, so only look for colors.
215         self.assertIn("\033[1;37m", actual)
216         self.assertIn("\033[36m", actual)
217         self.assertIn("\033[32m", actual)
218         self.assertIn("\033[31m", actual)
219         self.assertIn("\033[0m", actual)
220
221     @patch("black.dump_to_file", dump_to_stderr)
222     def _test_wip(self) -> None:
223         source, expected = read_data("wip")
224         sys.settrace(tracefunc)
225         mode = replace(
226             DEFAULT_MODE,
227             experimental_string_processing=False,
228             target_versions={black.TargetVersion.PY38},
229         )
230         actual = fs(source, mode=mode)
231         sys.settrace(None)
232         self.assertFormatEqual(expected, actual)
233         black.assert_equivalent(source, actual)
234         black.assert_stable(source, actual, black.FileMode())
235
236     @unittest.expectedFailure
237     @patch("black.dump_to_file", dump_to_stderr)
238     def test_trailing_comma_optional_parens_stability1(self) -> None:
239         source, _expected = read_data("trailing_comma_optional_parens1")
240         actual = fs(source)
241         black.assert_stable(source, actual, DEFAULT_MODE)
242
243     @unittest.expectedFailure
244     @patch("black.dump_to_file", dump_to_stderr)
245     def test_trailing_comma_optional_parens_stability2(self) -> None:
246         source, _expected = read_data("trailing_comma_optional_parens2")
247         actual = fs(source)
248         black.assert_stable(source, actual, DEFAULT_MODE)
249
250     @unittest.expectedFailure
251     @patch("black.dump_to_file", dump_to_stderr)
252     def test_trailing_comma_optional_parens_stability3(self) -> None:
253         source, _expected = read_data("trailing_comma_optional_parens3")
254         actual = fs(source)
255         black.assert_stable(source, actual, DEFAULT_MODE)
256
257     @patch("black.dump_to_file", dump_to_stderr)
258     def test_pep_572(self) -> None:
259         source, expected = read_data("pep_572")
260         actual = fs(source)
261         self.assertFormatEqual(expected, actual)
262         black.assert_stable(source, actual, DEFAULT_MODE)
263         if sys.version_info >= (3, 8):
264             black.assert_equivalent(source, actual)
265
266     def test_pep_572_version_detection(self) -> None:
267         source, _ = read_data("pep_572")
268         root = black.lib2to3_parse(source)
269         features = black.get_features_used(root)
270         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
271         versions = black.detect_target_versions(root)
272         self.assertIn(black.TargetVersion.PY38, versions)
273
274     def test_expression_ff(self) -> None:
275         source, expected = read_data("expression")
276         tmp_file = Path(black.dump_to_file(source))
277         try:
278             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
279             with open(tmp_file, encoding="utf8") as f:
280                 actual = f.read()
281         finally:
282             os.unlink(tmp_file)
283         self.assertFormatEqual(expected, actual)
284         with patch("black.dump_to_file", dump_to_stderr):
285             black.assert_equivalent(source, actual)
286             black.assert_stable(source, actual, DEFAULT_MODE)
287
288     def test_expression_diff(self) -> None:
289         source, _ = read_data("expression.py")
290         expected, _ = read_data("expression.diff")
291         tmp_file = Path(black.dump_to_file(source))
292         diff_header = re.compile(
293             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
294             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
295         )
296         try:
297             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
298             self.assertEqual(result.exit_code, 0)
299         finally:
300             os.unlink(tmp_file)
301         actual = result.output
302         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
303         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
304         if expected != actual:
305             dump = black.dump_to_file(actual)
306             msg = (
307                 "Expected diff isn't equal to the actual. If you made changes to"
308                 " expression.py and this is an anticipated difference, overwrite"
309                 f" tests/data/expression.diff with {dump}"
310             )
311             self.assertEqual(expected, actual, msg)
312
313     def test_expression_diff_with_color(self) -> None:
314         source, _ = read_data("expression.py")
315         expected, _ = read_data("expression.diff")
316         tmp_file = Path(black.dump_to_file(source))
317         try:
318             result = BlackRunner().invoke(
319                 black.main, ["--diff", "--color", str(tmp_file)]
320             )
321         finally:
322             os.unlink(tmp_file)
323         actual = result.output
324         # We check the contents of the diff in `test_expression_diff`. All
325         # we need to check here is that color codes exist in the result.
326         self.assertIn("\033[1;37m", actual)
327         self.assertIn("\033[36m", actual)
328         self.assertIn("\033[32m", actual)
329         self.assertIn("\033[31m", actual)
330         self.assertIn("\033[0m", actual)
331
332     @patch("black.dump_to_file", dump_to_stderr)
333     def test_pep_570(self) -> None:
334         source, expected = read_data("pep_570")
335         actual = fs(source)
336         self.assertFormatEqual(expected, actual)
337         black.assert_stable(source, actual, DEFAULT_MODE)
338         if sys.version_info >= (3, 8):
339             black.assert_equivalent(source, actual)
340
341     def test_detect_pos_only_arguments(self) -> None:
342         source, _ = read_data("pep_570")
343         root = black.lib2to3_parse(source)
344         features = black.get_features_used(root)
345         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
346         versions = black.detect_target_versions(root)
347         self.assertIn(black.TargetVersion.PY38, versions)
348
349     @patch("black.dump_to_file", dump_to_stderr)
350     def test_string_quotes(self) -> None:
351         source, expected = read_data("string_quotes")
352         actual = fs(source)
353         self.assertFormatEqual(expected, actual)
354         black.assert_equivalent(source, actual)
355         black.assert_stable(source, actual, DEFAULT_MODE)
356         mode = replace(DEFAULT_MODE, string_normalization=False)
357         not_normalized = fs(source, mode=mode)
358         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
359         black.assert_equivalent(source, not_normalized)
360         black.assert_stable(source, not_normalized, mode=mode)
361
362     @patch("black.dump_to_file", dump_to_stderr)
363     def test_docstring_no_string_normalization(self) -> None:
364         """Like test_docstring but with string normalization off."""
365         source, expected = read_data("docstring_no_string_normalization")
366         mode = replace(DEFAULT_MODE, string_normalization=False)
367         actual = fs(source, mode=mode)
368         self.assertFormatEqual(expected, actual)
369         black.assert_equivalent(source, actual)
370         black.assert_stable(source, actual, mode)
371
372     def test_long_strings_flag_disabled(self) -> None:
373         """Tests for turning off the string processing logic."""
374         source, expected = read_data("long_strings_flag_disabled")
375         mode = replace(DEFAULT_MODE, experimental_string_processing=False)
376         actual = fs(source, mode=mode)
377         self.assertFormatEqual(expected, actual)
378         black.assert_stable(expected, actual, mode)
379
380     @patch("black.dump_to_file", dump_to_stderr)
381     def test_numeric_literals(self) -> None:
382         source, expected = read_data("numeric_literals")
383         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
384         actual = fs(source, mode=mode)
385         self.assertFormatEqual(expected, actual)
386         black.assert_equivalent(source, actual)
387         black.assert_stable(source, actual, mode)
388
389     @patch("black.dump_to_file", dump_to_stderr)
390     def test_numeric_literals_ignoring_underscores(self) -> None:
391         source, expected = read_data("numeric_literals_skip_underscores")
392         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
393         actual = fs(source, mode=mode)
394         self.assertFormatEqual(expected, actual)
395         black.assert_equivalent(source, actual)
396         black.assert_stable(source, actual, mode)
397
398     @patch("black.dump_to_file", dump_to_stderr)
399     def test_python2_print_function(self) -> None:
400         source, expected = read_data("python2_print_function")
401         mode = replace(DEFAULT_MODE, target_versions={TargetVersion.PY27})
402         actual = fs(source, mode=mode)
403         self.assertFormatEqual(expected, actual)
404         black.assert_equivalent(source, actual)
405         black.assert_stable(source, actual, mode)
406
407     @patch("black.dump_to_file", dump_to_stderr)
408     def test_stub(self) -> None:
409         mode = replace(DEFAULT_MODE, is_pyi=True)
410         source, expected = read_data("stub.pyi")
411         actual = fs(source, mode=mode)
412         self.assertFormatEqual(expected, actual)
413         black.assert_stable(source, actual, mode)
414
415     @patch("black.dump_to_file", dump_to_stderr)
416     def test_async_as_identifier(self) -> None:
417         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
418         source, expected = read_data("async_as_identifier")
419         actual = fs(source)
420         self.assertFormatEqual(expected, actual)
421         major, minor = sys.version_info[:2]
422         if major < 3 or (major <= 3 and minor < 7):
423             black.assert_equivalent(source, actual)
424         black.assert_stable(source, actual, DEFAULT_MODE)
425         # ensure black can parse this when the target is 3.6
426         self.invokeBlack([str(source_path), "--target-version", "py36"])
427         # but not on 3.7, because async/await is no longer an identifier
428         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
429
430     @patch("black.dump_to_file", dump_to_stderr)
431     def test_python37(self) -> None:
432         source_path = (THIS_DIR / "data" / "python37.py").resolve()
433         source, expected = read_data("python37")
434         actual = fs(source)
435         self.assertFormatEqual(expected, actual)
436         major, minor = sys.version_info[:2]
437         if major > 3 or (major == 3 and minor >= 7):
438             black.assert_equivalent(source, actual)
439         black.assert_stable(source, actual, DEFAULT_MODE)
440         # ensure black can parse this when the target is 3.7
441         self.invokeBlack([str(source_path), "--target-version", "py37"])
442         # but not on 3.6, because we use async as a reserved keyword
443         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
444
445     @patch("black.dump_to_file", dump_to_stderr)
446     def test_python38(self) -> None:
447         source, expected = read_data("python38")
448         actual = fs(source)
449         self.assertFormatEqual(expected, actual)
450         major, minor = sys.version_info[:2]
451         if major > 3 or (major == 3 and minor >= 8):
452             black.assert_equivalent(source, actual)
453         black.assert_stable(source, actual, DEFAULT_MODE)
454
455     @patch("black.dump_to_file", dump_to_stderr)
456     def test_python39(self) -> None:
457         source, expected = read_data("python39")
458         actual = fs(source)
459         self.assertFormatEqual(expected, actual)
460         major, minor = sys.version_info[:2]
461         if major > 3 or (major == 3 and minor >= 9):
462             black.assert_equivalent(source, actual)
463         black.assert_stable(source, actual, DEFAULT_MODE)
464
465     def test_tab_comment_indentation(self) -> None:
466         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
467         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
468         self.assertFormatEqual(contents_spc, fs(contents_spc))
469         self.assertFormatEqual(contents_spc, fs(contents_tab))
470
471         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
472         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
473         self.assertFormatEqual(contents_spc, fs(contents_spc))
474         self.assertFormatEqual(contents_spc, fs(contents_tab))
475
476         # mixed tabs and spaces (valid Python 2 code)
477         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
478         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
479         self.assertFormatEqual(contents_spc, fs(contents_spc))
480         self.assertFormatEqual(contents_spc, fs(contents_tab))
481
482         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
483         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
484         self.assertFormatEqual(contents_spc, fs(contents_spc))
485         self.assertFormatEqual(contents_spc, fs(contents_tab))
486
    def test_report_verbose(self) -> None:
        """Check ``black.Report(verbose=True)`` messages, summary and return code.

        ``black.out`` and ``black.err`` are patched to record messages. A fixed
        sequence of ``done``/``failed``/``path_ignored`` calls is then made.
        After each call, the recorded messages, the unstyled ``str(report)``
        summary and ``report.return_code`` are asserted.
        """
        report = black.Report(verbose=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            # Record instead of printing.
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            # Record instead of printing.
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            # Verbose mode reports unchanged files on the "out" channel.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # CACHED files get their own message and count as unchanged.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, having reformatted files yields return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures go to the "err" channel and make the return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Counts are per call: reporting "f3" again adds a reformatted file.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are reported but do not change the summary counts.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff mode switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
588
    def test_report_quiet(self) -> None:
        """Check ``black.Report(quiet=True)`` messages, summary and return code.

        The call sequence mirrors test_report_verbose. In quiet mode nothing is
        recorded on the "out" channel; only failures appear on "err". The
        summary and return code are unaffected.
        """
        report = black.Report(quiet=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            # Record instead of printing.
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            # Record instead of printing.
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, having reformatted files yields return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures are still reported on "err", even in quiet mode.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths produce no message in quiet mode.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff mode switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
682
683     def test_report_normal(self) -> None:
684         report = black.Report()
685         out_lines = []
686         err_lines = []
687
688         def out(msg: str, **kwargs: Any) -> None:
689             out_lines.append(msg)
690
691         def err(msg: str, **kwargs: Any) -> None:
692             err_lines.append(msg)
693
694         with patch("black.out", out), patch("black.err", err):
695             report.done(Path("f1"), black.Changed.NO)
696             self.assertEqual(len(out_lines), 0)
697             self.assertEqual(len(err_lines), 0)
698             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
699             self.assertEqual(report.return_code, 0)
700             report.done(Path("f2"), black.Changed.YES)
701             self.assertEqual(len(out_lines), 1)
702             self.assertEqual(len(err_lines), 0)
703             self.assertEqual(out_lines[-1], "reformatted f2")
704             self.assertEqual(
705                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
706             )
707             report.done(Path("f3"), black.Changed.CACHED)
708             self.assertEqual(len(out_lines), 1)
709             self.assertEqual(len(err_lines), 0)
710             self.assertEqual(out_lines[-1], "reformatted f2")
711             self.assertEqual(
712                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
713             )
714             self.assertEqual(report.return_code, 0)
715             report.check = True
716             self.assertEqual(report.return_code, 1)
717             report.check = False
718             report.failed(Path("e1"), "boom")
719             self.assertEqual(len(out_lines), 1)
720             self.assertEqual(len(err_lines), 1)
721             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
722             self.assertEqual(
723                 unstyle(str(report)),
724                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
725                 " reformat.",
726             )
727             self.assertEqual(report.return_code, 123)
728             report.done(Path("f3"), black.Changed.YES)
729             self.assertEqual(len(out_lines), 2)
730             self.assertEqual(len(err_lines), 1)
731             self.assertEqual(out_lines[-1], "reformatted f3")
732             self.assertEqual(
733                 unstyle(str(report)),
734                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
735                 " reformat.",
736             )
737             self.assertEqual(report.return_code, 123)
738             report.failed(Path("e2"), "boom")
739             self.assertEqual(len(out_lines), 2)
740             self.assertEqual(len(err_lines), 2)
741             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
742             self.assertEqual(
743                 unstyle(str(report)),
744                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
745                 " reformat.",
746             )
747             self.assertEqual(report.return_code, 123)
748             report.path_ignored(Path("wat"), "no match")
749             self.assertEqual(len(out_lines), 2)
750             self.assertEqual(len(err_lines), 2)
751             self.assertEqual(
752                 unstyle(str(report)),
753                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
754                 " reformat.",
755             )
756             self.assertEqual(report.return_code, 123)
757             report.done(Path("f4"), black.Changed.NO)
758             self.assertEqual(len(out_lines), 2)
759             self.assertEqual(len(err_lines), 2)
760             self.assertEqual(
761                 unstyle(str(report)),
762                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
763                 " reformat.",
764             )
765             self.assertEqual(report.return_code, 123)
766             report.check = True
767             self.assertEqual(
768                 unstyle(str(report)),
769                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
770                 " would fail to reformat.",
771             )
772             report.check = False
773             report.diff = True
774             self.assertEqual(
775                 unstyle(str(report)),
776                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
777                 " would fail to reformat.",
778             )
779
780     def test_lib2to3_parse(self) -> None:
781         with self.assertRaises(black.InvalidInput):
782             black.lib2to3_parse("invalid syntax")
783
784         straddling = "x + y"
785         black.lib2to3_parse(straddling)
786         black.lib2to3_parse(straddling, {TargetVersion.PY27})
787         black.lib2to3_parse(straddling, {TargetVersion.PY36})
788         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
789
790         py2_only = "print x"
791         black.lib2to3_parse(py2_only)
792         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
793         with self.assertRaises(black.InvalidInput):
794             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
795         with self.assertRaises(black.InvalidInput):
796             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
797
798         py3_only = "exec(x, end=y)"
799         black.lib2to3_parse(py3_only)
800         with self.assertRaises(black.InvalidInput):
801             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
802         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
803         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
804
805     def test_get_features_used_decorator(self) -> None:
806         # Test the feature detection of new decorator syntax
807         # since this makes some test cases of test_get_features_used()
808         # fails if it fails, this is tested first so that a useful case
809         # is identified
810         simples, relaxed = read_data("decorators")
811         # skip explanation comments at the top of the file
812         for simple_test in simples.split("##")[1:]:
813             node = black.lib2to3_parse(simple_test)
814             decorator = str(node.children[0].children[0]).strip()
815             self.assertNotIn(
816                 Feature.RELAXED_DECORATORS,
817                 black.get_features_used(node),
818                 msg=(
819                     f"decorator '{decorator}' follows python<=3.8 syntax"
820                     "but is detected as 3.9+"
821                     # f"The full node is\n{node!r}"
822                 ),
823             )
824         # skip the '# output' comment at the top of the output part
825         for relaxed_test in relaxed.split("##")[1:]:
826             node = black.lib2to3_parse(relaxed_test)
827             decorator = str(node.children[0].children[0]).strip()
828             self.assertIn(
829                 Feature.RELAXED_DECORATORS,
830                 black.get_features_used(node),
831                 msg=(
832                     f"decorator '{decorator}' uses python3.9+ syntax"
833                     "but is detected as python<=3.8"
834                     # f"The full node is\n{node!r}"
835                 ),
836             )
837
838     def test_get_features_used(self) -> None:
839         node = black.lib2to3_parse("def f(*, arg): ...\n")
840         self.assertEqual(black.get_features_used(node), set())
841         node = black.lib2to3_parse("def f(*, arg,): ...\n")
842         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
843         node = black.lib2to3_parse("f(*arg,)\n")
844         self.assertEqual(
845             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
846         )
847         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
848         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
849         node = black.lib2to3_parse("123_456\n")
850         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
851         node = black.lib2to3_parse("123456\n")
852         self.assertEqual(black.get_features_used(node), set())
853         source, expected = read_data("function")
854         node = black.lib2to3_parse(source)
855         expected_features = {
856             Feature.TRAILING_COMMA_IN_CALL,
857             Feature.TRAILING_COMMA_IN_DEF,
858             Feature.F_STRINGS,
859         }
860         self.assertEqual(black.get_features_used(node), expected_features)
861         node = black.lib2to3_parse(expected)
862         self.assertEqual(black.get_features_used(node), expected_features)
863         source, expected = read_data("expression")
864         node = black.lib2to3_parse(source)
865         self.assertEqual(black.get_features_used(node), set())
866         node = black.lib2to3_parse(expected)
867         self.assertEqual(black.get_features_used(node), set())
868
869     def test_get_future_imports(self) -> None:
870         node = black.lib2to3_parse("\n")
871         self.assertEqual(set(), black.get_future_imports(node))
872         node = black.lib2to3_parse("from __future__ import black\n")
873         self.assertEqual({"black"}, black.get_future_imports(node))
874         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
875         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
876         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
877         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
878         node = black.lib2to3_parse(
879             "from __future__ import multiple\nfrom __future__ import imports\n"
880         )
881         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
882         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
883         self.assertEqual({"black"}, black.get_future_imports(node))
884         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
885         self.assertEqual({"black"}, black.get_future_imports(node))
886         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
887         self.assertEqual(set(), black.get_future_imports(node))
888         node = black.lib2to3_parse("from some.module import black\n")
889         self.assertEqual(set(), black.get_future_imports(node))
890         node = black.lib2to3_parse(
891             "from __future__ import unicode_literals as _unicode_literals"
892         )
893         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
894         node = black.lib2to3_parse(
895             "from __future__ import unicode_literals as _lol, print"
896         )
897         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
898
899     def test_debug_visitor(self) -> None:
900         source, _ = read_data("debug_visitor.py")
901         expected, _ = read_data("debug_visitor.out")
902         out_lines = []
903         err_lines = []
904
905         def out(msg: str, **kwargs: Any) -> None:
906             out_lines.append(msg)
907
908         def err(msg: str, **kwargs: Any) -> None:
909             err_lines.append(msg)
910
911         with patch("black.out", out), patch("black.err", err):
912             black.DebugVisitor.show(source)
913         actual = "\n".join(out_lines) + "\n"
914         log_name = ""
915         if expected != actual:
916             log_name = black.dump_to_file(*out_lines)
917         self.assertEqual(
918             expected,
919             actual,
920             f"AST print out is different. Actual version dumped to {log_name}",
921         )
922
923     def test_format_file_contents(self) -> None:
924         empty = ""
925         mode = DEFAULT_MODE
926         with self.assertRaises(black.NothingChanged):
927             black.format_file_contents(empty, mode=mode, fast=False)
928         just_nl = "\n"
929         with self.assertRaises(black.NothingChanged):
930             black.format_file_contents(just_nl, mode=mode, fast=False)
931         same = "j = [1, 2, 3]\n"
932         with self.assertRaises(black.NothingChanged):
933             black.format_file_contents(same, mode=mode, fast=False)
934         different = "j = [1,2,3]"
935         expected = same
936         actual = black.format_file_contents(different, mode=mode, fast=False)
937         self.assertEqual(expected, actual)
938         invalid = "return if you can"
939         with self.assertRaises(black.InvalidInput) as e:
940             black.format_file_contents(invalid, mode=mode, fast=False)
941         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
942
943     def test_endmarker(self) -> None:
944         n = black.lib2to3_parse("\n")
945         self.assertEqual(n.type, black.syms.file_input)
946         self.assertEqual(len(n.children), 1)
947         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
948
949     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
950     def test_assertFormatEqual(self) -> None:
951         out_lines = []
952         err_lines = []
953
954         def out(msg: str, **kwargs: Any) -> None:
955             out_lines.append(msg)
956
957         def err(msg: str, **kwargs: Any) -> None:
958             err_lines.append(msg)
959
960         with patch("black.out", out), patch("black.err", err):
961             with self.assertRaises(AssertionError):
962                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
963
964         out_str = "".join(out_lines)
965         self.assertTrue("Expected tree:" in out_str)
966         self.assertTrue("Actual tree:" in out_str)
967         self.assertEqual("".join(err_lines), "")
968
969     def test_cache_broken_file(self) -> None:
970         mode = DEFAULT_MODE
971         with cache_dir() as workspace:
972             cache_file = black.get_cache_file(mode)
973             with cache_file.open("w") as fobj:
974                 fobj.write("this is not a pickle")
975             self.assertEqual(black.read_cache(mode), {})
976             src = (workspace / "test.py").resolve()
977             with src.open("w") as fobj:
978                 fobj.write("print('hello')")
979             self.invokeBlack([str(src)])
980             cache = black.read_cache(mode)
981             self.assertIn(src, cache)
982
983     def test_cache_single_file_already_cached(self) -> None:
984         mode = DEFAULT_MODE
985         with cache_dir() as workspace:
986             src = (workspace / "test.py").resolve()
987             with src.open("w") as fobj:
988                 fobj.write("print('hello')")
989             black.write_cache({}, [src], mode)
990             self.invokeBlack([str(src)])
991             with src.open("r") as fobj:
992                 self.assertEqual(fobj.read(), "print('hello')")
993
994     @event_loop()
995     def test_cache_multiple_files(self) -> None:
996         mode = DEFAULT_MODE
997         with cache_dir() as workspace, patch(
998             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
999         ):
1000             one = (workspace / "one.py").resolve()
1001             with one.open("w") as fobj:
1002                 fobj.write("print('hello')")
1003             two = (workspace / "two.py").resolve()
1004             with two.open("w") as fobj:
1005                 fobj.write("print('hello')")
1006             black.write_cache({}, [one], mode)
1007             self.invokeBlack([str(workspace)])
1008             with one.open("r") as fobj:
1009                 self.assertEqual(fobj.read(), "print('hello')")
1010             with two.open("r") as fobj:
1011                 self.assertEqual(fobj.read(), 'print("hello")\n')
1012             cache = black.read_cache(mode)
1013             self.assertIn(one, cache)
1014             self.assertIn(two, cache)
1015
1016     def test_no_cache_when_writeback_diff(self) -> None:
1017         mode = DEFAULT_MODE
1018         with cache_dir() as workspace:
1019             src = (workspace / "test.py").resolve()
1020             with src.open("w") as fobj:
1021                 fobj.write("print('hello')")
1022             with patch("black.read_cache") as read_cache, patch(
1023                 "black.write_cache"
1024             ) as write_cache:
1025                 self.invokeBlack([str(src), "--diff"])
1026                 cache_file = black.get_cache_file(mode)
1027                 self.assertFalse(cache_file.exists())
1028                 write_cache.assert_not_called()
1029                 read_cache.assert_not_called()
1030
1031     def test_no_cache_when_writeback_color_diff(self) -> None:
1032         mode = DEFAULT_MODE
1033         with cache_dir() as workspace:
1034             src = (workspace / "test.py").resolve()
1035             with src.open("w") as fobj:
1036                 fobj.write("print('hello')")
1037             with patch("black.read_cache") as read_cache, patch(
1038                 "black.write_cache"
1039             ) as write_cache:
1040                 self.invokeBlack([str(src), "--diff", "--color"])
1041                 cache_file = black.get_cache_file(mode)
1042                 self.assertFalse(cache_file.exists())
1043                 write_cache.assert_not_called()
1044                 read_cache.assert_not_called()
1045
1046     @event_loop()
1047     def test_output_locking_when_writeback_diff(self) -> None:
1048         with cache_dir() as workspace:
1049             for tag in range(0, 4):
1050                 src = (workspace / f"test{tag}.py").resolve()
1051                 with src.open("w") as fobj:
1052                     fobj.write("print('hello')")
1053             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1054                 self.invokeBlack(["--diff", str(workspace)], exit_code=0)
1055                 # this isn't quite doing what we want, but if it _isn't_
1056                 # called then we cannot be using the lock it provides
1057                 mgr.assert_called()
1058
1059     @event_loop()
1060     def test_output_locking_when_writeback_color_diff(self) -> None:
1061         with cache_dir() as workspace:
1062             for tag in range(0, 4):
1063                 src = (workspace / f"test{tag}.py").resolve()
1064                 with src.open("w") as fobj:
1065                     fobj.write("print('hello')")
1066             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1067                 self.invokeBlack(["--diff", "--color", str(workspace)], exit_code=0)
1068                 # this isn't quite doing what we want, but if it _isn't_
1069                 # called then we cannot be using the lock it provides
1070                 mgr.assert_called()
1071
1072     def test_no_cache_when_stdin(self) -> None:
1073         mode = DEFAULT_MODE
1074         with cache_dir():
1075             result = CliRunner().invoke(
1076                 black.main, ["-"], input=BytesIO(b"print('hello')")
1077             )
1078             self.assertEqual(result.exit_code, 0)
1079             cache_file = black.get_cache_file(mode)
1080             self.assertFalse(cache_file.exists())
1081
1082     def test_read_cache_no_cachefile(self) -> None:
1083         mode = DEFAULT_MODE
1084         with cache_dir():
1085             self.assertEqual(black.read_cache(mode), {})
1086
1087     def test_write_cache_read_cache(self) -> None:
1088         mode = DEFAULT_MODE
1089         with cache_dir() as workspace:
1090             src = (workspace / "test.py").resolve()
1091             src.touch()
1092             black.write_cache({}, [src], mode)
1093             cache = black.read_cache(mode)
1094             self.assertIn(src, cache)
1095             self.assertEqual(cache[src], black.get_cache_info(src))
1096
1097     def test_filter_cached(self) -> None:
1098         with TemporaryDirectory() as workspace:
1099             path = Path(workspace)
1100             uncached = (path / "uncached").resolve()
1101             cached = (path / "cached").resolve()
1102             cached_but_changed = (path / "changed").resolve()
1103             uncached.touch()
1104             cached.touch()
1105             cached_but_changed.touch()
1106             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
1107             todo, done = black.filter_cached(
1108                 cache, {uncached, cached, cached_but_changed}
1109             )
1110             self.assertEqual(todo, {uncached, cached_but_changed})
1111             self.assertEqual(done, {cached})
1112
1113     def test_write_cache_creates_directory_if_needed(self) -> None:
1114         mode = DEFAULT_MODE
1115         with cache_dir(exists=False) as workspace:
1116             self.assertFalse(workspace.exists())
1117             black.write_cache({}, [], mode)
1118             self.assertTrue(workspace.exists())
1119
1120     @event_loop()
1121     def test_failed_formatting_does_not_get_cached(self) -> None:
1122         mode = DEFAULT_MODE
1123         with cache_dir() as workspace, patch(
1124             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1125         ):
1126             failing = (workspace / "failing.py").resolve()
1127             with failing.open("w") as fobj:
1128                 fobj.write("not actually python")
1129             clean = (workspace / "clean.py").resolve()
1130             with clean.open("w") as fobj:
1131                 fobj.write('print("hello")\n')
1132             self.invokeBlack([str(workspace)], exit_code=123)
1133             cache = black.read_cache(mode)
1134             self.assertNotIn(failing, cache)
1135             self.assertIn(clean, cache)
1136
1137     def test_write_cache_write_fail(self) -> None:
1138         mode = DEFAULT_MODE
1139         with cache_dir(), patch.object(Path, "open") as mock:
1140             mock.side_effect = OSError
1141             black.write_cache({}, [], mode)
1142
1143     @event_loop()
1144     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
1145     def test_works_in_mono_process_only_environment(self) -> None:
1146         with cache_dir() as workspace:
1147             for f in [
1148                 (workspace / "one.py").resolve(),
1149                 (workspace / "two.py").resolve(),
1150             ]:
1151                 f.write_text('print("hello")\n')
1152             self.invokeBlack([str(workspace)])
1153
1154     @event_loop()
1155     def test_check_diff_use_together(self) -> None:
1156         with cache_dir():
1157             # Files which will be reformatted.
1158             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1159             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1160             # Files which will not be reformatted.
1161             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1162             self.invokeBlack([str(src2), "--diff", "--check"])
1163             # Multi file command.
1164             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1165
1166     def test_no_files(self) -> None:
1167         with cache_dir():
1168             # Without an argument, black exits with error code 0.
1169             self.invokeBlack([])
1170
1171     def test_broken_symlink(self) -> None:
1172         with cache_dir() as workspace:
1173             symlink = workspace / "broken_link.py"
1174             try:
1175                 symlink.symlink_to("nonexistent.py")
1176             except OSError as e:
1177                 self.skipTest(f"Can't create symlinks: {e}")
1178             self.invokeBlack([str(workspace.resolve())])
1179
1180     def test_read_cache_line_lengths(self) -> None:
1181         mode = DEFAULT_MODE
1182         short_mode = replace(DEFAULT_MODE, line_length=1)
1183         with cache_dir() as workspace:
1184             path = (workspace / "file.py").resolve()
1185             path.touch()
1186             black.write_cache({}, [path], mode)
1187             one = black.read_cache(mode)
1188             self.assertIn(path, one)
1189             two = black.read_cache(short_mode)
1190             self.assertNotIn(path, two)
1191
1192     def test_single_file_force_pyi(self) -> None:
1193         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1194         contents, expected = read_data("force_pyi")
1195         with cache_dir() as workspace:
1196             path = (workspace / "file.py").resolve()
1197             with open(path, "w") as fh:
1198                 fh.write(contents)
1199             self.invokeBlack([str(path), "--pyi"])
1200             with open(path, "r") as fh:
1201                 actual = fh.read()
1202             # verify cache with --pyi is separate
1203             pyi_cache = black.read_cache(pyi_mode)
1204             self.assertIn(path, pyi_cache)
1205             normal_cache = black.read_cache(DEFAULT_MODE)
1206             self.assertNotIn(path, normal_cache)
1207         self.assertFormatEqual(expected, actual)
1208         black.assert_equivalent(contents, actual)
1209         black.assert_stable(contents, actual, pyi_mode)
1210
1211     @event_loop()
1212     def test_multi_file_force_pyi(self) -> None:
1213         reg_mode = DEFAULT_MODE
1214         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1215         contents, expected = read_data("force_pyi")
1216         with cache_dir() as workspace:
1217             paths = [
1218                 (workspace / "file1.py").resolve(),
1219                 (workspace / "file2.py").resolve(),
1220             ]
1221             for path in paths:
1222                 with open(path, "w") as fh:
1223                     fh.write(contents)
1224             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1225             for path in paths:
1226                 with open(path, "r") as fh:
1227                     actual = fh.read()
1228                 self.assertEqual(actual, expected)
1229             # verify cache with --pyi is separate
1230             pyi_cache = black.read_cache(pyi_mode)
1231             normal_cache = black.read_cache(reg_mode)
1232             for path in paths:
1233                 self.assertIn(path, pyi_cache)
1234                 self.assertNotIn(path, normal_cache)
1235
1236     def test_pipe_force_pyi(self) -> None:
1237         source, expected = read_data("force_pyi")
1238         result = CliRunner().invoke(
1239             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1240         )
1241         self.assertEqual(result.exit_code, 0)
1242         actual = result.output
1243         self.assertFormatEqual(actual, expected)
1244
1245     def test_single_file_force_py36(self) -> None:
1246         reg_mode = DEFAULT_MODE
1247         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1248         source, expected = read_data("force_py36")
1249         with cache_dir() as workspace:
1250             path = (workspace / "file.py").resolve()
1251             with open(path, "w") as fh:
1252                 fh.write(source)
1253             self.invokeBlack([str(path), *PY36_ARGS])
1254             with open(path, "r") as fh:
1255                 actual = fh.read()
1256             # verify cache with --target-version is separate
1257             py36_cache = black.read_cache(py36_mode)
1258             self.assertIn(path, py36_cache)
1259             normal_cache = black.read_cache(reg_mode)
1260             self.assertNotIn(path, normal_cache)
1261         self.assertEqual(actual, expected)
1262
1263     @event_loop()
1264     def test_multi_file_force_py36(self) -> None:
1265         reg_mode = DEFAULT_MODE
1266         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1267         source, expected = read_data("force_py36")
1268         with cache_dir() as workspace:
1269             paths = [
1270                 (workspace / "file1.py").resolve(),
1271                 (workspace / "file2.py").resolve(),
1272             ]
1273             for path in paths:
1274                 with open(path, "w") as fh:
1275                     fh.write(source)
1276             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1277             for path in paths:
1278                 with open(path, "r") as fh:
1279                     actual = fh.read()
1280                 self.assertEqual(actual, expected)
1281             # verify cache with --target-version is separate
1282             pyi_cache = black.read_cache(py36_mode)
1283             normal_cache = black.read_cache(reg_mode)
1284             for path in paths:
1285                 self.assertIn(path, pyi_cache)
1286                 self.assertNotIn(path, normal_cache)
1287
1288     def test_pipe_force_py36(self) -> None:
1289         source, expected = read_data("force_py36")
1290         result = CliRunner().invoke(
1291             black.main,
1292             ["-", "-q", "--target-version=py36"],
1293             input=BytesIO(source.encode("utf8")),
1294         )
1295         self.assertEqual(result.exit_code, 0)
1296         actual = result.output
1297         self.assertFormatEqual(actual, expected)
1298
1299     def test_include_exclude(self) -> None:
1300         path = THIS_DIR / "data" / "include_exclude_tests"
1301         include = re.compile(r"\.pyi?$")
1302         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1303         report = black.Report()
1304         gitignore = PathSpec.from_lines("gitwildmatch", [])
1305         sources: List[Path] = []
1306         expected = [
1307             Path(path / "b/dont_exclude/a.py"),
1308             Path(path / "b/dont_exclude/a.pyi"),
1309         ]
1310         this_abs = THIS_DIR.resolve()
1311         sources.extend(
1312             black.gen_python_files(
1313                 path.iterdir(), this_abs, include, exclude, None, report, gitignore
1314             )
1315         )
1316         self.assertEqual(sorted(expected), sorted(sources))
1317
1318     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1319     def test_exclude_for_issue_1572(self) -> None:
1320         # Exclude shouldn't touch files that were explicitly given to Black through the
1321         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
1322         # https://github.com/psf/black/issues/1572
1323         path = THIS_DIR / "data" / "include_exclude_tests"
1324         include = ""
1325         exclude = r"/exclude/|a\.py"
1326         src = str(path / "b/exclude/a.py")
1327         report = black.Report()
1328         expected = [Path(path / "b/exclude/a.py")]
1329         sources = list(
1330             black.get_sources(
1331                 ctx=FakeContext(),
1332                 src=(src,),
1333                 quiet=True,
1334                 verbose=False,
1335                 include=include,
1336                 exclude=exclude,
1337                 force_exclude=None,
1338                 report=report,
1339                 stdin_filename=None,
1340             )
1341         )
1342         self.assertEqual(sorted(expected), sorted(sources))
1343
1344     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1345     def test_get_sources_with_stdin(self) -> None:
1346         include = ""
1347         exclude = r"/exclude/|a\.py"
1348         src = "-"
1349         report = black.Report()
1350         expected = [Path("-")]
1351         sources = list(
1352             black.get_sources(
1353                 ctx=FakeContext(),
1354                 src=(src,),
1355                 quiet=True,
1356                 verbose=False,
1357                 include=include,
1358                 exclude=exclude,
1359                 force_exclude=None,
1360                 report=report,
1361                 stdin_filename=None,
1362             )
1363         )
1364         self.assertEqual(sorted(expected), sorted(sources))
1365
1366     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1367     def test_get_sources_with_stdin_filename(self) -> None:
1368         include = ""
1369         exclude = r"/exclude/|a\.py"
1370         src = "-"
1371         report = black.Report()
1372         stdin_filename = str(THIS_DIR / "data/collections.py")
1373         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1374         sources = list(
1375             black.get_sources(
1376                 ctx=FakeContext(),
1377                 src=(src,),
1378                 quiet=True,
1379                 verbose=False,
1380                 include=include,
1381                 exclude=exclude,
1382                 force_exclude=None,
1383                 report=report,
1384                 stdin_filename=stdin_filename,
1385             )
1386         )
1387         self.assertEqual(sorted(expected), sorted(sources))
1388
1389     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1390     def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
1391         # Exclude shouldn't exclude stdin_filename since it is mimicing the
1392         # file being passed directly. This is the same as
1393         # test_exclude_for_issue_1572
1394         path = THIS_DIR / "data" / "include_exclude_tests"
1395         include = ""
1396         exclude = r"/exclude/|a\.py"
1397         src = "-"
1398         report = black.Report()
1399         stdin_filename = str(path / "b/exclude/a.py")
1400         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1401         sources = list(
1402             black.get_sources(
1403                 ctx=FakeContext(),
1404                 src=(src,),
1405                 quiet=True,
1406                 verbose=False,
1407                 include=include,
1408                 exclude=exclude,
1409                 force_exclude=None,
1410                 report=report,
1411                 stdin_filename=stdin_filename,
1412             )
1413         )
1414         self.assertEqual(sorted(expected), sorted(sources))
1415
1416     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1417     def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
1418         # Force exclude should exclude the file when passing it through
1419         # stdin_filename
1420         path = THIS_DIR / "data" / "include_exclude_tests"
1421         include = ""
1422         force_exclude = r"/exclude/|a\.py"
1423         src = "-"
1424         report = black.Report()
1425         stdin_filename = str(path / "b/exclude/a.py")
1426         sources = list(
1427             black.get_sources(
1428                 ctx=FakeContext(),
1429                 src=(src,),
1430                 quiet=True,
1431                 verbose=False,
1432                 include=include,
1433                 exclude="",
1434                 force_exclude=force_exclude,
1435                 report=report,
1436                 stdin_filename=stdin_filename,
1437             )
1438         )
1439         self.assertEqual([], sorted(sources))
1440
1441     def test_reformat_one_with_stdin(self) -> None:
1442         with patch(
1443             "black.format_stdin_to_stdout",
1444             return_value=lambda *args, **kwargs: black.Changed.YES,
1445         ) as fsts:
1446             report = MagicMock()
1447             path = Path("-")
1448             black.reformat_one(
1449                 path,
1450                 fast=True,
1451                 write_back=black.WriteBack.YES,
1452                 mode=DEFAULT_MODE,
1453                 report=report,
1454             )
1455             fsts.assert_called_once()
1456             report.done.assert_called_with(path, black.Changed.YES)
1457
1458     def test_reformat_one_with_stdin_filename(self) -> None:
1459         with patch(
1460             "black.format_stdin_to_stdout",
1461             return_value=lambda *args, **kwargs: black.Changed.YES,
1462         ) as fsts:
1463             report = MagicMock()
1464             p = "foo.py"
1465             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1466             expected = Path(p)
1467             black.reformat_one(
1468                 path,
1469                 fast=True,
1470                 write_back=black.WriteBack.YES,
1471                 mode=DEFAULT_MODE,
1472                 report=report,
1473             )
1474             fsts.assert_called_once()
1475             # __BLACK_STDIN_FILENAME__ should have been striped
1476             report.done.assert_called_with(expected, black.Changed.YES)
1477
1478     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1479         with patch(
1480             "black.format_stdin_to_stdout",
1481             return_value=lambda *args, **kwargs: black.Changed.YES,
1482         ) as fsts:
1483             report = MagicMock()
1484             # Even with an existing file, since we are forcing stdin, black
1485             # should output to stdout and not modify the file inplace
1486             p = Path(str(THIS_DIR / "data/collections.py"))
1487             # Make sure is_file actually returns True
1488             self.assertTrue(p.is_file())
1489             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1490             expected = Path(p)
1491             black.reformat_one(
1492                 path,
1493                 fast=True,
1494                 write_back=black.WriteBack.YES,
1495                 mode=DEFAULT_MODE,
1496                 report=report,
1497             )
1498             fsts.assert_called_once()
1499             # __BLACK_STDIN_FILENAME__ should have been striped
1500             report.done.assert_called_with(expected, black.Changed.YES)
1501
1502     def test_gitignore_exclude(self) -> None:
1503         path = THIS_DIR / "data" / "include_exclude_tests"
1504         include = re.compile(r"\.pyi?$")
1505         exclude = re.compile(r"")
1506         report = black.Report()
1507         gitignore = PathSpec.from_lines(
1508             "gitwildmatch", ["exclude/", ".definitely_exclude"]
1509         )
1510         sources: List[Path] = []
1511         expected = [
1512             Path(path / "b/dont_exclude/a.py"),
1513             Path(path / "b/dont_exclude/a.pyi"),
1514         ]
1515         this_abs = THIS_DIR.resolve()
1516         sources.extend(
1517             black.gen_python_files(
1518                 path.iterdir(), this_abs, include, exclude, None, report, gitignore
1519             )
1520         )
1521         self.assertEqual(sorted(expected), sorted(sources))
1522
1523     def test_empty_include(self) -> None:
1524         path = THIS_DIR / "data" / "include_exclude_tests"
1525         report = black.Report()
1526         gitignore = PathSpec.from_lines("gitwildmatch", [])
1527         empty = re.compile(r"")
1528         sources: List[Path] = []
1529         expected = [
1530             Path(path / "b/exclude/a.pie"),
1531             Path(path / "b/exclude/a.py"),
1532             Path(path / "b/exclude/a.pyi"),
1533             Path(path / "b/dont_exclude/a.pie"),
1534             Path(path / "b/dont_exclude/a.py"),
1535             Path(path / "b/dont_exclude/a.pyi"),
1536             Path(path / "b/.definitely_exclude/a.pie"),
1537             Path(path / "b/.definitely_exclude/a.py"),
1538             Path(path / "b/.definitely_exclude/a.pyi"),
1539         ]
1540         this_abs = THIS_DIR.resolve()
1541         sources.extend(
1542             black.gen_python_files(
1543                 path.iterdir(),
1544                 this_abs,
1545                 empty,
1546                 re.compile(black.DEFAULT_EXCLUDES),
1547                 None,
1548                 report,
1549                 gitignore,
1550             )
1551         )
1552         self.assertEqual(sorted(expected), sorted(sources))
1553
1554     def test_empty_exclude(self) -> None:
1555         path = THIS_DIR / "data" / "include_exclude_tests"
1556         report = black.Report()
1557         gitignore = PathSpec.from_lines("gitwildmatch", [])
1558         empty = re.compile(r"")
1559         sources: List[Path] = []
1560         expected = [
1561             Path(path / "b/dont_exclude/a.py"),
1562             Path(path / "b/dont_exclude/a.pyi"),
1563             Path(path / "b/exclude/a.py"),
1564             Path(path / "b/exclude/a.pyi"),
1565             Path(path / "b/.definitely_exclude/a.py"),
1566             Path(path / "b/.definitely_exclude/a.pyi"),
1567         ]
1568         this_abs = THIS_DIR.resolve()
1569         sources.extend(
1570             black.gen_python_files(
1571                 path.iterdir(),
1572                 this_abs,
1573                 re.compile(black.DEFAULT_INCLUDES),
1574                 empty,
1575                 None,
1576                 report,
1577                 gitignore,
1578             )
1579         )
1580         self.assertEqual(sorted(expected), sorted(sources))
1581
1582     def test_invalid_include_exclude(self) -> None:
1583         for option in ["--include", "--exclude"]:
1584             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1585
1586     def test_preserves_line_endings(self) -> None:
1587         with TemporaryDirectory() as workspace:
1588             test_file = Path(workspace) / "test.py"
1589             for nl in ["\n", "\r\n"]:
1590                 contents = nl.join(["def f(  ):", "    pass"])
1591                 test_file.write_bytes(contents.encode())
1592                 ff(test_file, write_back=black.WriteBack.YES)
1593                 updated_contents: bytes = test_file.read_bytes()
1594                 self.assertIn(nl.encode(), updated_contents)
1595                 if nl == "\n":
1596                     self.assertNotIn(b"\r\n", updated_contents)
1597
1598     def test_preserves_line_endings_via_stdin(self) -> None:
1599         for nl in ["\n", "\r\n"]:
1600             contents = nl.join(["def f(  ):", "    pass"])
1601             runner = BlackRunner()
1602             result = runner.invoke(
1603                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1604             )
1605             self.assertEqual(result.exit_code, 0)
1606             output = runner.stdout_bytes
1607             self.assertIn(nl.encode("utf8"), output)
1608             if nl == "\n":
1609                 self.assertNotIn(b"\r\n", output)
1610
1611     def test_assert_equivalent_different_asts(self) -> None:
1612         with self.assertRaises(AssertionError):
1613             black.assert_equivalent("{}", "None")
1614
1615     def test_symlink_out_of_root_directory(self) -> None:
1616         path = MagicMock()
1617         root = THIS_DIR.resolve()
1618         child = MagicMock()
1619         include = re.compile(black.DEFAULT_INCLUDES)
1620         exclude = re.compile(black.DEFAULT_EXCLUDES)
1621         report = black.Report()
1622         gitignore = PathSpec.from_lines("gitwildmatch", [])
1623         # `child` should behave like a symlink which resolved path is clearly
1624         # outside of the `root` directory.
1625         path.iterdir.return_value = [child]
1626         child.resolve.return_value = Path("/a/b/c")
1627         child.as_posix.return_value = "/a/b/c"
1628         child.is_symlink.return_value = True
1629         try:
1630             list(
1631                 black.gen_python_files(
1632                     path.iterdir(), root, include, exclude, None, report, gitignore
1633                 )
1634             )
1635         except ValueError as ve:
1636             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1637         path.iterdir.assert_called_once()
1638         child.resolve.assert_called_once()
1639         child.is_symlink.assert_called_once()
1640         # `child` should behave like a strange file which resolved path is clearly
1641         # outside of the `root` directory.
1642         child.is_symlink.return_value = False
1643         with self.assertRaises(ValueError):
1644             list(
1645                 black.gen_python_files(
1646                     path.iterdir(), root, include, exclude, None, report, gitignore
1647                 )
1648             )
1649         path.iterdir.assert_called()
1650         self.assertEqual(path.iterdir.call_count, 2)
1651         child.resolve.assert_called()
1652         self.assertEqual(child.resolve.call_count, 2)
1653         child.is_symlink.assert_called()
1654         self.assertEqual(child.is_symlink.call_count, 2)
1655
1656     def test_shhh_click(self) -> None:
1657         try:
1658             from click import _unicodefun  # type: ignore
1659         except ModuleNotFoundError:
1660             self.skipTest("Incompatible Click version")
1661         if not hasattr(_unicodefun, "_verify_python3_env"):
1662             self.skipTest("Incompatible Click version")
1663         # First, let's see if Click is crashing with a preferred ASCII charset.
1664         with patch("locale.getpreferredencoding") as gpe:
1665             gpe.return_value = "ASCII"
1666             with self.assertRaises(RuntimeError):
1667                 _unicodefun._verify_python3_env()
1668         # Now, let's silence Click...
1669         black.patch_click()
1670         # ...and confirm it's silent.
1671         with patch("locale.getpreferredencoding") as gpe:
1672             gpe.return_value = "ASCII"
1673             try:
1674                 _unicodefun._verify_python3_env()
1675             except RuntimeError as re:
1676                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1677
1678     def test_root_logger_not_used_directly(self) -> None:
1679         def fail(*args: Any, **kwargs: Any) -> None:
1680             self.fail("Record created with root logger")
1681
1682         with patch.multiple(
1683             logging.root,
1684             debug=fail,
1685             info=fail,
1686             warning=fail,
1687             error=fail,
1688             critical=fail,
1689             log=fail,
1690         ):
1691             ff(THIS_FILE)
1692
1693     def test_invalid_config_return_code(self) -> None:
1694         tmp_file = Path(black.dump_to_file())
1695         try:
1696             tmp_config = Path(black.dump_to_file())
1697             tmp_config.unlink()
1698             args = ["--config", str(tmp_config), str(tmp_file)]
1699             self.invokeBlack(args, exit_code=2, ignore_config=False)
1700         finally:
1701             tmp_file.unlink()
1702
1703     def test_parse_pyproject_toml(self) -> None:
1704         test_toml_file = THIS_DIR / "test.toml"
1705         config = black.parse_pyproject_toml(str(test_toml_file))
1706         self.assertEqual(config["verbose"], 1)
1707         self.assertEqual(config["check"], "no")
1708         self.assertEqual(config["diff"], "y")
1709         self.assertEqual(config["color"], True)
1710         self.assertEqual(config["line_length"], 79)
1711         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1712         self.assertEqual(config["exclude"], r"\.pyi?$")
1713         self.assertEqual(config["include"], r"\.py?$")
1714
1715     def test_read_pyproject_toml(self) -> None:
1716         test_toml_file = THIS_DIR / "test.toml"
1717         fake_ctx = FakeContext()
1718         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1719         config = fake_ctx.default_map
1720         self.assertEqual(config["verbose"], "1")
1721         self.assertEqual(config["check"], "no")
1722         self.assertEqual(config["diff"], "y")
1723         self.assertEqual(config["color"], "True")
1724         self.assertEqual(config["line_length"], "79")
1725         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1726         self.assertEqual(config["exclude"], r"\.pyi?$")
1727         self.assertEqual(config["include"], r"\.py?$")
1728
1729     def test_find_project_root(self) -> None:
1730         with TemporaryDirectory() as workspace:
1731             root = Path(workspace)
1732             test_dir = root / "test"
1733             test_dir.mkdir()
1734
1735             src_dir = root / "src"
1736             src_dir.mkdir()
1737
1738             root_pyproject = root / "pyproject.toml"
1739             root_pyproject.touch()
1740             src_pyproject = src_dir / "pyproject.toml"
1741             src_pyproject.touch()
1742             src_python = src_dir / "foo.py"
1743             src_python.touch()
1744
1745             self.assertEqual(
1746                 black.find_project_root((src_dir, test_dir)), root.resolve()
1747             )
1748             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
1749             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
1750
1751     def test_bpo_33660_workaround(self) -> None:
1752         if system() == "Windows":
1753             return
1754
1755         # https://bugs.python.org/issue33660
1756
1757         old_cwd = Path.cwd()
1758         try:
1759             root = Path("/")
1760             os.chdir(str(root))
1761             path = Path("workspace") / "project"
1762             report = black.Report(verbose=True)
1763             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
1764             self.assertEqual(normalized_path, "workspace/project")
1765         finally:
1766             os.chdir(str(old_cwd))
1767
1768
# Black's own source, one entry per line (newline kept), used by `tracefunc`
# to print the signature of each traced function.
with open(black.__file__, "r", encoding="utf-8") as _black_src:
    black_source_lines = list(_black_src)
1771
1772
def tracefunc(frame: types.FrameType, event: str, arg: Any) -> Callable:
    """Show function calls `from black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    Each call is printed as ``<lineno>:<source line>``, indented by call depth.
    Only frames whose code lives in ``black/__init__.py`` are printed.
    ``black_source_lines`` is indexed only for those frames, so line numbers
    from other modules can no longer run past its end (IndexError).

    Returns ``tracefunc`` itself so tracing continues in nested frames.
    """
    if event != "call":
        return tracefunc

    filename = frame.f_code.co_filename
    # Normalize Windows separators so the path test works on every platform.
    if "black/__init__.py" not in filename.replace("\\", "/"):
        return tracefunc

    # `inspect.stack()` is expensive; only compute it for frames we print.
    # NOTE(review): 19 is presumably the number of frames below the traced code
    # (test runner etc.); adjust if the indentation looks off.
    stack = (len(inspect.stack()) - 19) * 2
    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1
    funcname = black_source_lines[func_sig_lineno].strip()
    # Skip decorator lines so the printed text is the `def` line itself.
    while funcname.startswith("@"):
        func_sig_lineno += 1
        funcname = black_source_lines[func_sig_lineno].strip()
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc
1793
1794
if __name__ == "__main__":
    # Run the tests defined in this module. Hard-coding module="test_black"
    # breaks when the file runs as part of the `tests` package
    # (e.g. `python -m tests.test_black`), where no top-level `test_black` exists.
    unittest.main()