git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

readme: Include Zulip in used-by section (#1987)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import multiprocessing
3 import asyncio
4 import logging
5 from concurrent.futures import ThreadPoolExecutor
6 from contextlib import contextmanager
7 from dataclasses import replace
8 import inspect
9 from io import BytesIO, TextIOWrapper
10 import os
11 from pathlib import Path
12 from platform import system
13 import regex as re
14 import sys
15 from tempfile import TemporaryDirectory
16 import types
17 from typing import (
18     Any,
19     BinaryIO,
20     Callable,
21     Dict,
22     Generator,
23     List,
24     Iterator,
25     TypeVar,
26 )
27 import unittest
28 from unittest.mock import patch, MagicMock
29
30 import click
31 from click import unstyle
32 from click.testing import CliRunner
33
34 import black
35 from black import Feature, TargetVersion
36
37 from pathspec import PathSpec
38
39 # Import other test classes
40 from tests.util import (
41     THIS_DIR,
42     read_data,
43     DETERMINISTIC_HEADER,
44     BlackBaseTestCase,
45     DEFAULT_MODE,
46     fs,
47     ff,
48     dump_to_stderr,
49 )
50 from .test_primer import PrimerCLITests  # noqa: F401
51
52
THIS_FILE = Path(__file__)
# Target versions from Python 3.6 through 3.9.
PY36_VERSIONS = {
    TargetVersion.PY36,
    TargetVersion.PY37,
    TargetVersion.PY38,
    TargetVersion.PY39,
}
# One ``--target-version`` CLI flag per entry of PY36_VERSIONS. The list follows
# the set's iteration order, so the flag order is not guaranteed.
PY36_ARGS = [f"--target-version={tv.name.lower()}" for tv in PY36_VERSIONS]
# Module-level type variables for generic annotations.
T = TypeVar("T")
R = TypeVar("R")
63
64
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point ``black.CACHE_DIR`` at a throwaway directory for the block's duration.

    Yields the patched cache path. With ``exists=False`` the yielded path is a
    ``new`` child of the temporary directory that is not created here
    (presumably so tests can exercise cache-directory creation). On exit the
    patch is undone and the temporary directory is removed.
    """
    with TemporaryDirectory() as tmp_root:
        base = Path(tmp_root)
        target = base if exists else base / "new"
        with patch("black.CACHE_DIR", target):
            yield target
73
74
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a brand-new asyncio event loop as the current loop for the block.

    The loop comes from the active event loop policy. It is closed when the
    ``with`` block exits, whether normally or through an exception. The closed
    loop stays registered as the current loop afterwards.
    """
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        fresh_loop.close()
85
86
class FakeContext(click.Context):
    """A fake click Context for when calling functions that need it.

    ``click.Context.__init__`` is intentionally not called. The only state set
    up is an empty ``default_map``.
    """

    def __init__(self) -> None:
        empty_defaults: Dict[str, Any] = dict()
        self.default_map = empty_defaults
92
93
class FakeParameter(click.Parameter):
    """A fake click Parameter for when calling functions that need it."""

    def __init__(self) -> None:
        """Deliberately skip ``click.Parameter.__init__``; no state is needed."""
99
100
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    This is a hack that can be removed once we depend on Click 7.x

    After an invocation, the raw bytes written to stdout and stderr are
    available as ``stdout_bytes`` and ``stderr_bytes``.
    """

    def __init__(self) -> None:
        self.stderrbuf = BytesIO()
        # NOTE(review): ``stdoutbuf`` is not read anywhere in this class.
        self.stdoutbuf = BytesIO()
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Run Click's isolation with ``sys.stderr`` redirected into ``stderrbuf``.

        On exit, everything written to stdout and stderr is captured into
        ``stdout_bytes`` / ``stderr_bytes``. The previous ``sys.stderr`` is
        then restored, even if the wrapped code raised.
        """
        with super().isolation(*args, **kwargs) as output:
            # Take the stream to restore before swapping, so the finally clause
            # never sees an unbound name.
            hold_stderr = sys.stderr
            stderr_stream = TextIOWrapper(self.stderrbuf, encoding=self.charset)
            sys.stderr = stderr_stream
            try:
                yield output
            finally:
                sys.stdout.flush()
                self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                # TextIOWrapper keeps its own write buffer. Without a flush,
                # text written to stderr may not have reached ``stderrbuf`` yet.
                stderr_stream.flush()
                self.stderr_bytes = self.stderrbuf.getvalue()
                sys.stderr = hold_stderr
124
125
126 class BlackTestCase(BlackBaseTestCase):
127     def invokeBlack(
128         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
129     ) -> None:
130         runner = BlackRunner()
131         if ignore_config:
132             args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
133         result = runner.invoke(black.main, args)
134         self.assertEqual(
135             result.exit_code,
136             exit_code,
137             msg=(
138                 f"Failed with args: {args}\n"
139                 f"stdout: {runner.stdout_bytes.decode()!r}\n"
140                 f"stderr: {runner.stderr_bytes.decode()!r}\n"
141                 f"exception: {result.exception}"
142             ),
143         )
144
145     @patch("black.dump_to_file", dump_to_stderr)
146     def test_empty(self) -> None:
147         source = expected = ""
148         actual = fs(source)
149         self.assertFormatEqual(expected, actual)
150         black.assert_equivalent(source, actual)
151         black.assert_stable(source, actual, DEFAULT_MODE)
152
153     def test_empty_ff(self) -> None:
154         expected = ""
155         tmp_file = Path(black.dump_to_file())
156         try:
157             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
158             with open(tmp_file, encoding="utf8") as f:
159                 actual = f.read()
160         finally:
161             os.unlink(tmp_file)
162         self.assertFormatEqual(expected, actual)
163
164     def test_piping(self) -> None:
165         source, expected = read_data("src/black/__init__", data=False)
166         result = BlackRunner().invoke(
167             black.main,
168             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
169             input=BytesIO(source.encode("utf8")),
170         )
171         self.assertEqual(result.exit_code, 0)
172         self.assertFormatEqual(expected, result.output)
173         black.assert_equivalent(source, result.output)
174         black.assert_stable(source, result.output, DEFAULT_MODE)
175
176     def test_piping_diff(self) -> None:
177         diff_header = re.compile(
178             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
179             r"\+\d\d\d\d"
180         )
181         source, _ = read_data("expression.py")
182         expected, _ = read_data("expression.diff")
183         config = THIS_DIR / "data" / "empty_pyproject.toml"
184         args = [
185             "-",
186             "--fast",
187             f"--line-length={black.DEFAULT_LINE_LENGTH}",
188             "--diff",
189             f"--config={config}",
190         ]
191         result = BlackRunner().invoke(
192             black.main, args, input=BytesIO(source.encode("utf8"))
193         )
194         self.assertEqual(result.exit_code, 0)
195         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
196         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
197         self.assertEqual(expected, actual)
198
199     def test_piping_diff_with_color(self) -> None:
200         source, _ = read_data("expression.py")
201         config = THIS_DIR / "data" / "empty_pyproject.toml"
202         args = [
203             "-",
204             "--fast",
205             f"--line-length={black.DEFAULT_LINE_LENGTH}",
206             "--diff",
207             "--color",
208             f"--config={config}",
209         ]
210         result = BlackRunner().invoke(
211             black.main, args, input=BytesIO(source.encode("utf8"))
212         )
213         actual = result.output
214         # Again, the contents are checked in a different test, so only look for colors.
215         self.assertIn("\033[1;37m", actual)
216         self.assertIn("\033[36m", actual)
217         self.assertIn("\033[32m", actual)
218         self.assertIn("\033[31m", actual)
219         self.assertIn("\033[0m", actual)
220
221     @patch("black.dump_to_file", dump_to_stderr)
222     def _test_wip(self) -> None:
223         source, expected = read_data("wip")
224         sys.settrace(tracefunc)
225         mode = replace(
226             DEFAULT_MODE,
227             experimental_string_processing=False,
228             target_versions={black.TargetVersion.PY38},
229         )
230         actual = fs(source, mode=mode)
231         sys.settrace(None)
232         self.assertFormatEqual(expected, actual)
233         black.assert_equivalent(source, actual)
234         black.assert_stable(source, actual, black.FileMode())
235
236     @unittest.expectedFailure
237     @patch("black.dump_to_file", dump_to_stderr)
238     def test_trailing_comma_optional_parens_stability1(self) -> None:
239         source, _expected = read_data("trailing_comma_optional_parens1")
240         actual = fs(source)
241         black.assert_stable(source, actual, DEFAULT_MODE)
242
243     @unittest.expectedFailure
244     @patch("black.dump_to_file", dump_to_stderr)
245     def test_trailing_comma_optional_parens_stability2(self) -> None:
246         source, _expected = read_data("trailing_comma_optional_parens2")
247         actual = fs(source)
248         black.assert_stable(source, actual, DEFAULT_MODE)
249
250     @unittest.expectedFailure
251     @patch("black.dump_to_file", dump_to_stderr)
252     def test_trailing_comma_optional_parens_stability3(self) -> None:
253         source, _expected = read_data("trailing_comma_optional_parens3")
254         actual = fs(source)
255         black.assert_stable(source, actual, DEFAULT_MODE)
256
257     @patch("black.dump_to_file", dump_to_stderr)
258     def test_pep_572(self) -> None:
259         source, expected = read_data("pep_572")
260         actual = fs(source)
261         self.assertFormatEqual(expected, actual)
262         black.assert_stable(source, actual, DEFAULT_MODE)
263         if sys.version_info >= (3, 8):
264             black.assert_equivalent(source, actual)
265
266     def test_pep_572_version_detection(self) -> None:
267         source, _ = read_data("pep_572")
268         root = black.lib2to3_parse(source)
269         features = black.get_features_used(root)
270         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
271         versions = black.detect_target_versions(root)
272         self.assertIn(black.TargetVersion.PY38, versions)
273
274     def test_expression_ff(self) -> None:
275         source, expected = read_data("expression")
276         tmp_file = Path(black.dump_to_file(source))
277         try:
278             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
279             with open(tmp_file, encoding="utf8") as f:
280                 actual = f.read()
281         finally:
282             os.unlink(tmp_file)
283         self.assertFormatEqual(expected, actual)
284         with patch("black.dump_to_file", dump_to_stderr):
285             black.assert_equivalent(source, actual)
286             black.assert_stable(source, actual, DEFAULT_MODE)
287
288     def test_expression_diff(self) -> None:
289         source, _ = read_data("expression.py")
290         expected, _ = read_data("expression.diff")
291         tmp_file = Path(black.dump_to_file(source))
292         diff_header = re.compile(
293             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
294             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
295         )
296         try:
297             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
298             self.assertEqual(result.exit_code, 0)
299         finally:
300             os.unlink(tmp_file)
301         actual = result.output
302         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
303         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
304         if expected != actual:
305             dump = black.dump_to_file(actual)
306             msg = (
307                 "Expected diff isn't equal to the actual. If you made changes to"
308                 " expression.py and this is an anticipated difference, overwrite"
309                 f" tests/data/expression.diff with {dump}"
310             )
311             self.assertEqual(expected, actual, msg)
312
313     def test_expression_diff_with_color(self) -> None:
314         source, _ = read_data("expression.py")
315         expected, _ = read_data("expression.diff")
316         tmp_file = Path(black.dump_to_file(source))
317         try:
318             result = BlackRunner().invoke(
319                 black.main, ["--diff", "--color", str(tmp_file)]
320             )
321         finally:
322             os.unlink(tmp_file)
323         actual = result.output
324         # We check the contents of the diff in `test_expression_diff`. All
325         # we need to check here is that color codes exist in the result.
326         self.assertIn("\033[1;37m", actual)
327         self.assertIn("\033[36m", actual)
328         self.assertIn("\033[32m", actual)
329         self.assertIn("\033[31m", actual)
330         self.assertIn("\033[0m", actual)
331
332     @patch("black.dump_to_file", dump_to_stderr)
333     def test_pep_570(self) -> None:
334         source, expected = read_data("pep_570")
335         actual = fs(source)
336         self.assertFormatEqual(expected, actual)
337         black.assert_stable(source, actual, DEFAULT_MODE)
338         if sys.version_info >= (3, 8):
339             black.assert_equivalent(source, actual)
340
341     def test_detect_pos_only_arguments(self) -> None:
342         source, _ = read_data("pep_570")
343         root = black.lib2to3_parse(source)
344         features = black.get_features_used(root)
345         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
346         versions = black.detect_target_versions(root)
347         self.assertIn(black.TargetVersion.PY38, versions)
348
349     @patch("black.dump_to_file", dump_to_stderr)
350     def test_string_quotes(self) -> None:
351         source, expected = read_data("string_quotes")
352         actual = fs(source)
353         self.assertFormatEqual(expected, actual)
354         black.assert_equivalent(source, actual)
355         black.assert_stable(source, actual, DEFAULT_MODE)
356         mode = replace(DEFAULT_MODE, string_normalization=False)
357         not_normalized = fs(source, mode=mode)
358         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
359         black.assert_equivalent(source, not_normalized)
360         black.assert_stable(source, not_normalized, mode=mode)
361
362     @patch("black.dump_to_file", dump_to_stderr)
363     def test_docstring_no_string_normalization(self) -> None:
364         """Like test_docstring but with string normalization off."""
365         source, expected = read_data("docstring_no_string_normalization")
366         mode = replace(DEFAULT_MODE, string_normalization=False)
367         actual = fs(source, mode=mode)
368         self.assertFormatEqual(expected, actual)
369         black.assert_equivalent(source, actual)
370         black.assert_stable(source, actual, mode)
371
372     def test_long_strings_flag_disabled(self) -> None:
373         """Tests for turning off the string processing logic."""
374         source, expected = read_data("long_strings_flag_disabled")
375         mode = replace(DEFAULT_MODE, experimental_string_processing=False)
376         actual = fs(source, mode=mode)
377         self.assertFormatEqual(expected, actual)
378         black.assert_stable(expected, actual, mode)
379
380     @patch("black.dump_to_file", dump_to_stderr)
381     def test_numeric_literals(self) -> None:
382         source, expected = read_data("numeric_literals")
383         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
384         actual = fs(source, mode=mode)
385         self.assertFormatEqual(expected, actual)
386         black.assert_equivalent(source, actual)
387         black.assert_stable(source, actual, mode)
388
389     @patch("black.dump_to_file", dump_to_stderr)
390     def test_numeric_literals_ignoring_underscores(self) -> None:
391         source, expected = read_data("numeric_literals_skip_underscores")
392         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
393         actual = fs(source, mode=mode)
394         self.assertFormatEqual(expected, actual)
395         black.assert_equivalent(source, actual)
396         black.assert_stable(source, actual, mode)
397
398     def test_skip_magic_trailing_comma(self) -> None:
399         source, _ = read_data("expression.py")
400         expected, _ = read_data("expression_skip_magic_trailing_comma.diff")
401         tmp_file = Path(black.dump_to_file(source))
402         diff_header = re.compile(
403             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
404             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
405         )
406         try:
407             result = BlackRunner().invoke(black.main, ["-C", "--diff", str(tmp_file)])
408             self.assertEqual(result.exit_code, 0)
409         finally:
410             os.unlink(tmp_file)
411         actual = result.output
412         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
413         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
414         if expected != actual:
415             dump = black.dump_to_file(actual)
416             msg = (
417                 "Expected diff isn't equal to the actual. If you made changes to"
418                 " expression.py and this is an anticipated difference, overwrite"
419                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
420             )
421             self.assertEqual(expected, actual, msg)
422
423     @patch("black.dump_to_file", dump_to_stderr)
424     def test_python2_print_function(self) -> None:
425         source, expected = read_data("python2_print_function")
426         mode = replace(DEFAULT_MODE, target_versions={TargetVersion.PY27})
427         actual = fs(source, mode=mode)
428         self.assertFormatEqual(expected, actual)
429         black.assert_equivalent(source, actual)
430         black.assert_stable(source, actual, mode)
431
432     @patch("black.dump_to_file", dump_to_stderr)
433     def test_stub(self) -> None:
434         mode = replace(DEFAULT_MODE, is_pyi=True)
435         source, expected = read_data("stub.pyi")
436         actual = fs(source, mode=mode)
437         self.assertFormatEqual(expected, actual)
438         black.assert_stable(source, actual, mode)
439
440     @patch("black.dump_to_file", dump_to_stderr)
441     def test_async_as_identifier(self) -> None:
442         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
443         source, expected = read_data("async_as_identifier")
444         actual = fs(source)
445         self.assertFormatEqual(expected, actual)
446         major, minor = sys.version_info[:2]
447         if major < 3 or (major <= 3 and minor < 7):
448             black.assert_equivalent(source, actual)
449         black.assert_stable(source, actual, DEFAULT_MODE)
450         # ensure black can parse this when the target is 3.6
451         self.invokeBlack([str(source_path), "--target-version", "py36"])
452         # but not on 3.7, because async/await is no longer an identifier
453         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
454
455     @patch("black.dump_to_file", dump_to_stderr)
456     def test_python37(self) -> None:
457         source_path = (THIS_DIR / "data" / "python37.py").resolve()
458         source, expected = read_data("python37")
459         actual = fs(source)
460         self.assertFormatEqual(expected, actual)
461         major, minor = sys.version_info[:2]
462         if major > 3 or (major == 3 and minor >= 7):
463             black.assert_equivalent(source, actual)
464         black.assert_stable(source, actual, DEFAULT_MODE)
465         # ensure black can parse this when the target is 3.7
466         self.invokeBlack([str(source_path), "--target-version", "py37"])
467         # but not on 3.6, because we use async as a reserved keyword
468         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
469
470     @patch("black.dump_to_file", dump_to_stderr)
471     def test_python38(self) -> None:
472         source, expected = read_data("python38")
473         actual = fs(source)
474         self.assertFormatEqual(expected, actual)
475         major, minor = sys.version_info[:2]
476         if major > 3 or (major == 3 and minor >= 8):
477             black.assert_equivalent(source, actual)
478         black.assert_stable(source, actual, DEFAULT_MODE)
479
480     @patch("black.dump_to_file", dump_to_stderr)
481     def test_python39(self) -> None:
482         source, expected = read_data("python39")
483         actual = fs(source)
484         self.assertFormatEqual(expected, actual)
485         major, minor = sys.version_info[:2]
486         if major > 3 or (major == 3 and minor >= 9):
487             black.assert_equivalent(source, actual)
488         black.assert_stable(source, actual, DEFAULT_MODE)
489
490     def test_tab_comment_indentation(self) -> None:
491         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
492         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
493         self.assertFormatEqual(contents_spc, fs(contents_spc))
494         self.assertFormatEqual(contents_spc, fs(contents_tab))
495
496         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
497         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
498         self.assertFormatEqual(contents_spc, fs(contents_spc))
499         self.assertFormatEqual(contents_spc, fs(contents_tab))
500
501         # mixed tabs and spaces (valid Python 2 code)
502         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
503         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
504         self.assertFormatEqual(contents_spc, fs(contents_spc))
505         self.assertFormatEqual(contents_spc, fs(contents_tab))
506
507         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
508         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
509         self.assertFormatEqual(contents_spc, fs(contents_spc))
510         self.assertFormatEqual(contents_spc, fs(contents_tab))
511
512     def test_report_verbose(self) -> None:
513         report = black.Report(verbose=True)
514         out_lines = []
515         err_lines = []
516
517         def out(msg: str, **kwargs: Any) -> None:
518             out_lines.append(msg)
519
520         def err(msg: str, **kwargs: Any) -> None:
521             err_lines.append(msg)
522
523         with patch("black.out", out), patch("black.err", err):
524             report.done(Path("f1"), black.Changed.NO)
525             self.assertEqual(len(out_lines), 1)
526             self.assertEqual(len(err_lines), 0)
527             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
528             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
529             self.assertEqual(report.return_code, 0)
530             report.done(Path("f2"), black.Changed.YES)
531             self.assertEqual(len(out_lines), 2)
532             self.assertEqual(len(err_lines), 0)
533             self.assertEqual(out_lines[-1], "reformatted f2")
534             self.assertEqual(
535                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
536             )
537             report.done(Path("f3"), black.Changed.CACHED)
538             self.assertEqual(len(out_lines), 3)
539             self.assertEqual(len(err_lines), 0)
540             self.assertEqual(
541                 out_lines[-1], "f3 wasn't modified on disk since last run."
542             )
543             self.assertEqual(
544                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
545             )
546             self.assertEqual(report.return_code, 0)
547             report.check = True
548             self.assertEqual(report.return_code, 1)
549             report.check = False
550             report.failed(Path("e1"), "boom")
551             self.assertEqual(len(out_lines), 3)
552             self.assertEqual(len(err_lines), 1)
553             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
554             self.assertEqual(
555                 unstyle(str(report)),
556                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
557                 " reformat.",
558             )
559             self.assertEqual(report.return_code, 123)
560             report.done(Path("f3"), black.Changed.YES)
561             self.assertEqual(len(out_lines), 4)
562             self.assertEqual(len(err_lines), 1)
563             self.assertEqual(out_lines[-1], "reformatted f3")
564             self.assertEqual(
565                 unstyle(str(report)),
566                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
567                 " reformat.",
568             )
569             self.assertEqual(report.return_code, 123)
570             report.failed(Path("e2"), "boom")
571             self.assertEqual(len(out_lines), 4)
572             self.assertEqual(len(err_lines), 2)
573             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
574             self.assertEqual(
575                 unstyle(str(report)),
576                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
577                 " reformat.",
578             )
579             self.assertEqual(report.return_code, 123)
580             report.path_ignored(Path("wat"), "no match")
581             self.assertEqual(len(out_lines), 5)
582             self.assertEqual(len(err_lines), 2)
583             self.assertEqual(out_lines[-1], "wat ignored: no match")
584             self.assertEqual(
585                 unstyle(str(report)),
586                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
587                 " reformat.",
588             )
589             self.assertEqual(report.return_code, 123)
590             report.done(Path("f4"), black.Changed.NO)
591             self.assertEqual(len(out_lines), 6)
592             self.assertEqual(len(err_lines), 2)
593             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
594             self.assertEqual(
595                 unstyle(str(report)),
596                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
597                 " reformat.",
598             )
599             self.assertEqual(report.return_code, 123)
600             report.check = True
601             self.assertEqual(
602                 unstyle(str(report)),
603                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
604                 " would fail to reformat.",
605             )
606             report.check = False
607             report.diff = True
608             self.assertEqual(
609                 unstyle(str(report)),
610                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
611                 " would fail to reformat.",
612             )
613
614     def test_report_quiet(self) -> None:
615         report = black.Report(quiet=True)
616         out_lines = []
617         err_lines = []
618
619         def out(msg: str, **kwargs: Any) -> None:
620             out_lines.append(msg)
621
622         def err(msg: str, **kwargs: Any) -> None:
623             err_lines.append(msg)
624
625         with patch("black.out", out), patch("black.err", err):
626             report.done(Path("f1"), black.Changed.NO)
627             self.assertEqual(len(out_lines), 0)
628             self.assertEqual(len(err_lines), 0)
629             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
630             self.assertEqual(report.return_code, 0)
631             report.done(Path("f2"), black.Changed.YES)
632             self.assertEqual(len(out_lines), 0)
633             self.assertEqual(len(err_lines), 0)
634             self.assertEqual(
635                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
636             )
637             report.done(Path("f3"), black.Changed.CACHED)
638             self.assertEqual(len(out_lines), 0)
639             self.assertEqual(len(err_lines), 0)
640             self.assertEqual(
641                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
642             )
643             self.assertEqual(report.return_code, 0)
644             report.check = True
645             self.assertEqual(report.return_code, 1)
646             report.check = False
647             report.failed(Path("e1"), "boom")
648             self.assertEqual(len(out_lines), 0)
649             self.assertEqual(len(err_lines), 1)
650             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
651             self.assertEqual(
652                 unstyle(str(report)),
653                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
654                 " reformat.",
655             )
656             self.assertEqual(report.return_code, 123)
657             report.done(Path("f3"), black.Changed.YES)
658             self.assertEqual(len(out_lines), 0)
659             self.assertEqual(len(err_lines), 1)
660             self.assertEqual(
661                 unstyle(str(report)),
662                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
663                 " reformat.",
664             )
665             self.assertEqual(report.return_code, 123)
666             report.failed(Path("e2"), "boom")
667             self.assertEqual(len(out_lines), 0)
668             self.assertEqual(len(err_lines), 2)
669             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
670             self.assertEqual(
671                 unstyle(str(report)),
672                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
673                 " reformat.",
674             )
675             self.assertEqual(report.return_code, 123)
676             report.path_ignored(Path("wat"), "no match")
677             self.assertEqual(len(out_lines), 0)
678             self.assertEqual(len(err_lines), 2)
679             self.assertEqual(
680                 unstyle(str(report)),
681                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
682                 " reformat.",
683             )
684             self.assertEqual(report.return_code, 123)
685             report.done(Path("f4"), black.Changed.NO)
686             self.assertEqual(len(out_lines), 0)
687             self.assertEqual(len(err_lines), 2)
688             self.assertEqual(
689                 unstyle(str(report)),
690                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
691                 " reformat.",
692             )
693             self.assertEqual(report.return_code, 123)
694             report.check = True
695             self.assertEqual(
696                 unstyle(str(report)),
697                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
698                 " would fail to reformat.",
699             )
700             report.check = False
701             report.diff = True
702             self.assertEqual(
703                 unstyle(str(report)),
704                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
705                 " would fail to reformat.",
706             )
707
    def test_report_normal(self) -> None:
        # Drive a black.Report at default verbosity through a sequence of
        # done() / failed() / path_ignored() calls. After each step, check what
        # was written via black.out / black.err, the summary string, and the
        # return code. Steps build on each other, so their order matters.
        report = black.Report()
        out_lines: List[str] = []
        err_lines: List[str] = []

        # Stand-ins for black.out / black.err that record each message; extra
        # keyword arguments are accepted and ignored.
        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            # An unchanged file is counted but prints nothing.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # A reformatted file is announced on out.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file counts as unchanged and prints nothing.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, having reformatted files makes the return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure is reported on err and sets the return code to 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Counts are per call, not per path: reporting f3 again as changed
            # bumps "reformatted" without lowering "left unchanged".
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # An ignored path prints nothing at this verbosity and leaves the
            # summary untouched.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check mode and diff mode switch the summary to "would" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
804
805     def test_lib2to3_parse(self) -> None:
806         with self.assertRaises(black.InvalidInput):
807             black.lib2to3_parse("invalid syntax")
808
809         straddling = "x + y"
810         black.lib2to3_parse(straddling)
811         black.lib2to3_parse(straddling, {TargetVersion.PY27})
812         black.lib2to3_parse(straddling, {TargetVersion.PY36})
813         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
814
815         py2_only = "print x"
816         black.lib2to3_parse(py2_only)
817         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
818         with self.assertRaises(black.InvalidInput):
819             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
820         with self.assertRaises(black.InvalidInput):
821             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
822
823         py3_only = "exec(x, end=y)"
824         black.lib2to3_parse(py3_only)
825         with self.assertRaises(black.InvalidInput):
826             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
827         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
828         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
829
830     def test_get_features_used_decorator(self) -> None:
831         # Test the feature detection of new decorator syntax
832         # since this makes some test cases of test_get_features_used()
833         # fails if it fails, this is tested first so that a useful case
834         # is identified
835         simples, relaxed = read_data("decorators")
836         # skip explanation comments at the top of the file
837         for simple_test in simples.split("##")[1:]:
838             node = black.lib2to3_parse(simple_test)
839             decorator = str(node.children[0].children[0]).strip()
840             self.assertNotIn(
841                 Feature.RELAXED_DECORATORS,
842                 black.get_features_used(node),
843                 msg=(
844                     f"decorator '{decorator}' follows python<=3.8 syntax"
845                     "but is detected as 3.9+"
846                     # f"The full node is\n{node!r}"
847                 ),
848             )
849         # skip the '# output' comment at the top of the output part
850         for relaxed_test in relaxed.split("##")[1:]:
851             node = black.lib2to3_parse(relaxed_test)
852             decorator = str(node.children[0].children[0]).strip()
853             self.assertIn(
854                 Feature.RELAXED_DECORATORS,
855                 black.get_features_used(node),
856                 msg=(
857                     f"decorator '{decorator}' uses python3.9+ syntax"
858                     "but is detected as python<=3.8"
859                     # f"The full node is\n{node!r}"
860                 ),
861             )
862
863     def test_get_features_used(self) -> None:
864         node = black.lib2to3_parse("def f(*, arg): ...\n")
865         self.assertEqual(black.get_features_used(node), set())
866         node = black.lib2to3_parse("def f(*, arg,): ...\n")
867         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
868         node = black.lib2to3_parse("f(*arg,)\n")
869         self.assertEqual(
870             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
871         )
872         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
873         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
874         node = black.lib2to3_parse("123_456\n")
875         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
876         node = black.lib2to3_parse("123456\n")
877         self.assertEqual(black.get_features_used(node), set())
878         source, expected = read_data("function")
879         node = black.lib2to3_parse(source)
880         expected_features = {
881             Feature.TRAILING_COMMA_IN_CALL,
882             Feature.TRAILING_COMMA_IN_DEF,
883             Feature.F_STRINGS,
884         }
885         self.assertEqual(black.get_features_used(node), expected_features)
886         node = black.lib2to3_parse(expected)
887         self.assertEqual(black.get_features_used(node), expected_features)
888         source, expected = read_data("expression")
889         node = black.lib2to3_parse(source)
890         self.assertEqual(black.get_features_used(node), set())
891         node = black.lib2to3_parse(expected)
892         self.assertEqual(black.get_features_used(node), set())
893
894     def test_get_future_imports(self) -> None:
895         node = black.lib2to3_parse("\n")
896         self.assertEqual(set(), black.get_future_imports(node))
897         node = black.lib2to3_parse("from __future__ import black\n")
898         self.assertEqual({"black"}, black.get_future_imports(node))
899         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
900         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
901         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
902         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
903         node = black.lib2to3_parse(
904             "from __future__ import multiple\nfrom __future__ import imports\n"
905         )
906         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
907         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
908         self.assertEqual({"black"}, black.get_future_imports(node))
909         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
910         self.assertEqual({"black"}, black.get_future_imports(node))
911         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
912         self.assertEqual(set(), black.get_future_imports(node))
913         node = black.lib2to3_parse("from some.module import black\n")
914         self.assertEqual(set(), black.get_future_imports(node))
915         node = black.lib2to3_parse(
916             "from __future__ import unicode_literals as _unicode_literals"
917         )
918         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
919         node = black.lib2to3_parse(
920             "from __future__ import unicode_literals as _lol, print"
921         )
922         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
923
924     def test_debug_visitor(self) -> None:
925         source, _ = read_data("debug_visitor.py")
926         expected, _ = read_data("debug_visitor.out")
927         out_lines = []
928         err_lines = []
929
930         def out(msg: str, **kwargs: Any) -> None:
931             out_lines.append(msg)
932
933         def err(msg: str, **kwargs: Any) -> None:
934             err_lines.append(msg)
935
936         with patch("black.out", out), patch("black.err", err):
937             black.DebugVisitor.show(source)
938         actual = "\n".join(out_lines) + "\n"
939         log_name = ""
940         if expected != actual:
941             log_name = black.dump_to_file(*out_lines)
942         self.assertEqual(
943             expected,
944             actual,
945             f"AST print out is different. Actual version dumped to {log_name}",
946         )
947
948     def test_format_file_contents(self) -> None:
949         empty = ""
950         mode = DEFAULT_MODE
951         with self.assertRaises(black.NothingChanged):
952             black.format_file_contents(empty, mode=mode, fast=False)
953         just_nl = "\n"
954         with self.assertRaises(black.NothingChanged):
955             black.format_file_contents(just_nl, mode=mode, fast=False)
956         same = "j = [1, 2, 3]\n"
957         with self.assertRaises(black.NothingChanged):
958             black.format_file_contents(same, mode=mode, fast=False)
959         different = "j = [1,2,3]"
960         expected = same
961         actual = black.format_file_contents(different, mode=mode, fast=False)
962         self.assertEqual(expected, actual)
963         invalid = "return if you can"
964         with self.assertRaises(black.InvalidInput) as e:
965             black.format_file_contents(invalid, mode=mode, fast=False)
966         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
967
968     def test_endmarker(self) -> None:
969         n = black.lib2to3_parse("\n")
970         self.assertEqual(n.type, black.syms.file_input)
971         self.assertEqual(len(n.children), 1)
972         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
973
974     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
975     def test_assertFormatEqual(self) -> None:
976         out_lines = []
977         err_lines = []
978
979         def out(msg: str, **kwargs: Any) -> None:
980             out_lines.append(msg)
981
982         def err(msg: str, **kwargs: Any) -> None:
983             err_lines.append(msg)
984
985         with patch("black.out", out), patch("black.err", err):
986             with self.assertRaises(AssertionError):
987                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
988
989         out_str = "".join(out_lines)
990         self.assertTrue("Expected tree:" in out_str)
991         self.assertTrue("Actual tree:" in out_str)
992         self.assertEqual("".join(err_lines), "")
993
994     def test_cache_broken_file(self) -> None:
995         mode = DEFAULT_MODE
996         with cache_dir() as workspace:
997             cache_file = black.get_cache_file(mode)
998             with cache_file.open("w") as fobj:
999                 fobj.write("this is not a pickle")
1000             self.assertEqual(black.read_cache(mode), {})
1001             src = (workspace / "test.py").resolve()
1002             with src.open("w") as fobj:
1003                 fobj.write("print('hello')")
1004             self.invokeBlack([str(src)])
1005             cache = black.read_cache(mode)
1006             self.assertIn(str(src), cache)
1007
1008     def test_cache_single_file_already_cached(self) -> None:
1009         mode = DEFAULT_MODE
1010         with cache_dir() as workspace:
1011             src = (workspace / "test.py").resolve()
1012             with src.open("w") as fobj:
1013                 fobj.write("print('hello')")
1014             black.write_cache({}, [src], mode)
1015             self.invokeBlack([str(src)])
1016             with src.open("r") as fobj:
1017                 self.assertEqual(fobj.read(), "print('hello')")
1018
1019     @event_loop()
1020     def test_cache_multiple_files(self) -> None:
1021         mode = DEFAULT_MODE
1022         with cache_dir() as workspace, patch(
1023             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1024         ):
1025             one = (workspace / "one.py").resolve()
1026             with one.open("w") as fobj:
1027                 fobj.write("print('hello')")
1028             two = (workspace / "two.py").resolve()
1029             with two.open("w") as fobj:
1030                 fobj.write("print('hello')")
1031             black.write_cache({}, [one], mode)
1032             self.invokeBlack([str(workspace)])
1033             with one.open("r") as fobj:
1034                 self.assertEqual(fobj.read(), "print('hello')")
1035             with two.open("r") as fobj:
1036                 self.assertEqual(fobj.read(), 'print("hello")\n')
1037             cache = black.read_cache(mode)
1038             self.assertIn(str(one), cache)
1039             self.assertIn(str(two), cache)
1040
1041     def test_no_cache_when_writeback_diff(self) -> None:
1042         mode = DEFAULT_MODE
1043         with cache_dir() as workspace:
1044             src = (workspace / "test.py").resolve()
1045             with src.open("w") as fobj:
1046                 fobj.write("print('hello')")
1047             with patch("black.read_cache") as read_cache, patch(
1048                 "black.write_cache"
1049             ) as write_cache:
1050                 self.invokeBlack([str(src), "--diff"])
1051                 cache_file = black.get_cache_file(mode)
1052                 self.assertFalse(cache_file.exists())
1053                 write_cache.assert_not_called()
1054                 read_cache.assert_not_called()
1055
1056     def test_no_cache_when_writeback_color_diff(self) -> None:
1057         mode = DEFAULT_MODE
1058         with cache_dir() as workspace:
1059             src = (workspace / "test.py").resolve()
1060             with src.open("w") as fobj:
1061                 fobj.write("print('hello')")
1062             with patch("black.read_cache") as read_cache, patch(
1063                 "black.write_cache"
1064             ) as write_cache:
1065                 self.invokeBlack([str(src), "--diff", "--color"])
1066                 cache_file = black.get_cache_file(mode)
1067                 self.assertFalse(cache_file.exists())
1068                 write_cache.assert_not_called()
1069                 read_cache.assert_not_called()
1070
1071     @event_loop()
1072     def test_output_locking_when_writeback_diff(self) -> None:
1073         with cache_dir() as workspace:
1074             for tag in range(0, 4):
1075                 src = (workspace / f"test{tag}.py").resolve()
1076                 with src.open("w") as fobj:
1077                     fobj.write("print('hello')")
1078             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1079                 self.invokeBlack(["--diff", str(workspace)], exit_code=0)
1080                 # this isn't quite doing what we want, but if it _isn't_
1081                 # called then we cannot be using the lock it provides
1082                 mgr.assert_called()
1083
1084     @event_loop()
1085     def test_output_locking_when_writeback_color_diff(self) -> None:
1086         with cache_dir() as workspace:
1087             for tag in range(0, 4):
1088                 src = (workspace / f"test{tag}.py").resolve()
1089                 with src.open("w") as fobj:
1090                     fobj.write("print('hello')")
1091             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1092                 self.invokeBlack(["--diff", "--color", str(workspace)], exit_code=0)
1093                 # this isn't quite doing what we want, but if it _isn't_
1094                 # called then we cannot be using the lock it provides
1095                 mgr.assert_called()
1096
1097     def test_no_cache_when_stdin(self) -> None:
1098         mode = DEFAULT_MODE
1099         with cache_dir():
1100             result = CliRunner().invoke(
1101                 black.main, ["-"], input=BytesIO(b"print('hello')")
1102             )
1103             self.assertEqual(result.exit_code, 0)
1104             cache_file = black.get_cache_file(mode)
1105             self.assertFalse(cache_file.exists())
1106
1107     def test_read_cache_no_cachefile(self) -> None:
1108         mode = DEFAULT_MODE
1109         with cache_dir():
1110             self.assertEqual(black.read_cache(mode), {})
1111
1112     def test_write_cache_read_cache(self) -> None:
1113         mode = DEFAULT_MODE
1114         with cache_dir() as workspace:
1115             src = (workspace / "test.py").resolve()
1116             src.touch()
1117             black.write_cache({}, [src], mode)
1118             cache = black.read_cache(mode)
1119             self.assertIn(str(src), cache)
1120             self.assertEqual(cache[str(src)], black.get_cache_info(src))
1121
1122     def test_filter_cached(self) -> None:
1123         with TemporaryDirectory() as workspace:
1124             path = Path(workspace)
1125             uncached = (path / "uncached").resolve()
1126             cached = (path / "cached").resolve()
1127             cached_but_changed = (path / "changed").resolve()
1128             uncached.touch()
1129             cached.touch()
1130             cached_but_changed.touch()
1131             cache = {
1132                 str(cached): black.get_cache_info(cached),
1133                 str(cached_but_changed): (0.0, 0),
1134             }
1135             todo, done = black.filter_cached(
1136                 cache, {uncached, cached, cached_but_changed}
1137             )
1138             self.assertEqual(todo, {uncached, cached_but_changed})
1139             self.assertEqual(done, {cached})
1140
1141     def test_write_cache_creates_directory_if_needed(self) -> None:
1142         mode = DEFAULT_MODE
1143         with cache_dir(exists=False) as workspace:
1144             self.assertFalse(workspace.exists())
1145             black.write_cache({}, [], mode)
1146             self.assertTrue(workspace.exists())
1147
1148     @event_loop()
1149     def test_failed_formatting_does_not_get_cached(self) -> None:
1150         mode = DEFAULT_MODE
1151         with cache_dir() as workspace, patch(
1152             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1153         ):
1154             failing = (workspace / "failing.py").resolve()
1155             with failing.open("w") as fobj:
1156                 fobj.write("not actually python")
1157             clean = (workspace / "clean.py").resolve()
1158             with clean.open("w") as fobj:
1159                 fobj.write('print("hello")\n')
1160             self.invokeBlack([str(workspace)], exit_code=123)
1161             cache = black.read_cache(mode)
1162             self.assertNotIn(str(failing), cache)
1163             self.assertIn(str(clean), cache)
1164
1165     def test_write_cache_write_fail(self) -> None:
1166         mode = DEFAULT_MODE
1167         with cache_dir(), patch.object(Path, "open") as mock:
1168             mock.side_effect = OSError
1169             black.write_cache({}, [], mode)
1170
1171     @event_loop()
1172     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
1173     def test_works_in_mono_process_only_environment(self) -> None:
1174         with cache_dir() as workspace:
1175             for f in [
1176                 (workspace / "one.py").resolve(),
1177                 (workspace / "two.py").resolve(),
1178             ]:
1179                 f.write_text('print("hello")\n')
1180             self.invokeBlack([str(workspace)])
1181
1182     @event_loop()
1183     def test_check_diff_use_together(self) -> None:
1184         with cache_dir():
1185             # Files which will be reformatted.
1186             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1187             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1188             # Files which will not be reformatted.
1189             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1190             self.invokeBlack([str(src2), "--diff", "--check"])
1191             # Multi file command.
1192             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1193
1194     def test_no_files(self) -> None:
1195         with cache_dir():
1196             # Without an argument, black exits with error code 0.
1197             self.invokeBlack([])
1198
1199     def test_broken_symlink(self) -> None:
1200         with cache_dir() as workspace:
1201             symlink = workspace / "broken_link.py"
1202             try:
1203                 symlink.symlink_to("nonexistent.py")
1204             except OSError as e:
1205                 self.skipTest(f"Can't create symlinks: {e}")
1206             self.invokeBlack([str(workspace.resolve())])
1207
1208     def test_read_cache_line_lengths(self) -> None:
1209         mode = DEFAULT_MODE
1210         short_mode = replace(DEFAULT_MODE, line_length=1)
1211         with cache_dir() as workspace:
1212             path = (workspace / "file.py").resolve()
1213             path.touch()
1214             black.write_cache({}, [path], mode)
1215             one = black.read_cache(mode)
1216             self.assertIn(str(path), one)
1217             two = black.read_cache(short_mode)
1218             self.assertNotIn(str(path), two)
1219
1220     def test_single_file_force_pyi(self) -> None:
1221         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1222         contents, expected = read_data("force_pyi")
1223         with cache_dir() as workspace:
1224             path = (workspace / "file.py").resolve()
1225             with open(path, "w") as fh:
1226                 fh.write(contents)
1227             self.invokeBlack([str(path), "--pyi"])
1228             with open(path, "r") as fh:
1229                 actual = fh.read()
1230             # verify cache with --pyi is separate
1231             pyi_cache = black.read_cache(pyi_mode)
1232             self.assertIn(str(path), pyi_cache)
1233             normal_cache = black.read_cache(DEFAULT_MODE)
1234             self.assertNotIn(str(path), normal_cache)
1235         self.assertFormatEqual(expected, actual)
1236         black.assert_equivalent(contents, actual)
1237         black.assert_stable(contents, actual, pyi_mode)
1238
1239     @event_loop()
1240     def test_multi_file_force_pyi(self) -> None:
1241         reg_mode = DEFAULT_MODE
1242         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1243         contents, expected = read_data("force_pyi")
1244         with cache_dir() as workspace:
1245             paths = [
1246                 (workspace / "file1.py").resolve(),
1247                 (workspace / "file2.py").resolve(),
1248             ]
1249             for path in paths:
1250                 with open(path, "w") as fh:
1251                     fh.write(contents)
1252             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1253             for path in paths:
1254                 with open(path, "r") as fh:
1255                     actual = fh.read()
1256                 self.assertEqual(actual, expected)
1257             # verify cache with --pyi is separate
1258             pyi_cache = black.read_cache(pyi_mode)
1259             normal_cache = black.read_cache(reg_mode)
1260             for path in paths:
1261                 self.assertIn(str(path), pyi_cache)
1262                 self.assertNotIn(str(path), normal_cache)
1263
1264     def test_pipe_force_pyi(self) -> None:
1265         source, expected = read_data("force_pyi")
1266         result = CliRunner().invoke(
1267             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1268         )
1269         self.assertEqual(result.exit_code, 0)
1270         actual = result.output
1271         self.assertFormatEqual(actual, expected)
1272
1273     def test_single_file_force_py36(self) -> None:
1274         reg_mode = DEFAULT_MODE
1275         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1276         source, expected = read_data("force_py36")
1277         with cache_dir() as workspace:
1278             path = (workspace / "file.py").resolve()
1279             with open(path, "w") as fh:
1280                 fh.write(source)
1281             self.invokeBlack([str(path), *PY36_ARGS])
1282             with open(path, "r") as fh:
1283                 actual = fh.read()
1284             # verify cache with --target-version is separate
1285             py36_cache = black.read_cache(py36_mode)
1286             self.assertIn(str(path), py36_cache)
1287             normal_cache = black.read_cache(reg_mode)
1288             self.assertNotIn(str(path), normal_cache)
1289         self.assertEqual(actual, expected)
1290
1291     @event_loop()
1292     def test_multi_file_force_py36(self) -> None:
1293         reg_mode = DEFAULT_MODE
1294         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1295         source, expected = read_data("force_py36")
1296         with cache_dir() as workspace:
1297             paths = [
1298                 (workspace / "file1.py").resolve(),
1299                 (workspace / "file2.py").resolve(),
1300             ]
1301             for path in paths:
1302                 with open(path, "w") as fh:
1303                     fh.write(source)
1304             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1305             for path in paths:
1306                 with open(path, "r") as fh:
1307                     actual = fh.read()
1308                 self.assertEqual(actual, expected)
1309             # verify cache with --target-version is separate
1310             pyi_cache = black.read_cache(py36_mode)
1311             normal_cache = black.read_cache(reg_mode)
1312             for path in paths:
1313                 self.assertIn(str(path), pyi_cache)
1314                 self.assertNotIn(str(path), normal_cache)
1315
1316     def test_pipe_force_py36(self) -> None:
1317         source, expected = read_data("force_py36")
1318         result = CliRunner().invoke(
1319             black.main,
1320             ["-", "-q", "--target-version=py36"],
1321             input=BytesIO(source.encode("utf8")),
1322         )
1323         self.assertEqual(result.exit_code, 0)
1324         actual = result.output
1325         self.assertFormatEqual(actual, expected)
1326
1327     def test_include_exclude(self) -> None:
1328         path = THIS_DIR / "data" / "include_exclude_tests"
1329         include = re.compile(r"\.pyi?$")
1330         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1331         report = black.Report()
1332         gitignore = PathSpec.from_lines("gitwildmatch", [])
1333         sources: List[Path] = []
1334         expected = [
1335             Path(path / "b/dont_exclude/a.py"),
1336             Path(path / "b/dont_exclude/a.pyi"),
1337         ]
1338         this_abs = THIS_DIR.resolve()
1339         sources.extend(
1340             black.gen_python_files(
1341                 path.iterdir(), this_abs, include, exclude, None, report, gitignore
1342             )
1343         )
1344         self.assertEqual(sorted(expected), sorted(sources))
1345
1346     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1347     def test_exclude_for_issue_1572(self) -> None:
1348         # Exclude shouldn't touch files that were explicitly given to Black through the
1349         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
1350         # https://github.com/psf/black/issues/1572
1351         path = THIS_DIR / "data" / "include_exclude_tests"
1352         include = ""
1353         exclude = r"/exclude/|a\.py"
1354         src = str(path / "b/exclude/a.py")
1355         report = black.Report()
1356         expected = [Path(path / "b/exclude/a.py")]
1357         sources = list(
1358             black.get_sources(
1359                 ctx=FakeContext(),
1360                 src=(src,),
1361                 quiet=True,
1362                 verbose=False,
1363                 include=include,
1364                 exclude=exclude,
1365                 force_exclude=None,
1366                 report=report,
1367                 stdin_filename=None,
1368             )
1369         )
1370         self.assertEqual(sorted(expected), sorted(sources))
1371
1372     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1373     def test_get_sources_with_stdin(self) -> None:
1374         include = ""
1375         exclude = r"/exclude/|a\.py"
1376         src = "-"
1377         report = black.Report()
1378         expected = [Path("-")]
1379         sources = list(
1380             black.get_sources(
1381                 ctx=FakeContext(),
1382                 src=(src,),
1383                 quiet=True,
1384                 verbose=False,
1385                 include=include,
1386                 exclude=exclude,
1387                 force_exclude=None,
1388                 report=report,
1389                 stdin_filename=None,
1390             )
1391         )
1392         self.assertEqual(sorted(expected), sorted(sources))
1393
1394     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1395     def test_get_sources_with_stdin_filename(self) -> None:
1396         include = ""
1397         exclude = r"/exclude/|a\.py"
1398         src = "-"
1399         report = black.Report()
1400         stdin_filename = str(THIS_DIR / "data/collections.py")
1401         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1402         sources = list(
1403             black.get_sources(
1404                 ctx=FakeContext(),
1405                 src=(src,),
1406                 quiet=True,
1407                 verbose=False,
1408                 include=include,
1409                 exclude=exclude,
1410                 force_exclude=None,
1411                 report=report,
1412                 stdin_filename=stdin_filename,
1413             )
1414         )
1415         self.assertEqual(sorted(expected), sorted(sources))
1416
1417     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1418     def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
1419         # Exclude shouldn't exclude stdin_filename since it is mimicing the
1420         # file being passed directly. This is the same as
1421         # test_exclude_for_issue_1572
1422         path = THIS_DIR / "data" / "include_exclude_tests"
1423         include = ""
1424         exclude = r"/exclude/|a\.py"
1425         src = "-"
1426         report = black.Report()
1427         stdin_filename = str(path / "b/exclude/a.py")
1428         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1429         sources = list(
1430             black.get_sources(
1431                 ctx=FakeContext(),
1432                 src=(src,),
1433                 quiet=True,
1434                 verbose=False,
1435                 include=include,
1436                 exclude=exclude,
1437                 force_exclude=None,
1438                 report=report,
1439                 stdin_filename=stdin_filename,
1440             )
1441         )
1442         self.assertEqual(sorted(expected), sorted(sources))
1443
1444     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1445     def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
1446         # Force exclude should exclude the file when passing it through
1447         # stdin_filename
1448         path = THIS_DIR / "data" / "include_exclude_tests"
1449         include = ""
1450         force_exclude = r"/exclude/|a\.py"
1451         src = "-"
1452         report = black.Report()
1453         stdin_filename = str(path / "b/exclude/a.py")
1454         sources = list(
1455             black.get_sources(
1456                 ctx=FakeContext(),
1457                 src=(src,),
1458                 quiet=True,
1459                 verbose=False,
1460                 include=include,
1461                 exclude="",
1462                 force_exclude=force_exclude,
1463                 report=report,
1464                 stdin_filename=stdin_filename,
1465             )
1466         )
1467         self.assertEqual([], sorted(sources))
1468
1469     def test_reformat_one_with_stdin(self) -> None:
1470         with patch(
1471             "black.format_stdin_to_stdout",
1472             return_value=lambda *args, **kwargs: black.Changed.YES,
1473         ) as fsts:
1474             report = MagicMock()
1475             path = Path("-")
1476             black.reformat_one(
1477                 path,
1478                 fast=True,
1479                 write_back=black.WriteBack.YES,
1480                 mode=DEFAULT_MODE,
1481                 report=report,
1482             )
1483             fsts.assert_called_once()
1484             report.done.assert_called_with(path, black.Changed.YES)
1485
1486     def test_reformat_one_with_stdin_filename(self) -> None:
1487         with patch(
1488             "black.format_stdin_to_stdout",
1489             return_value=lambda *args, **kwargs: black.Changed.YES,
1490         ) as fsts:
1491             report = MagicMock()
1492             p = "foo.py"
1493             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1494             expected = Path(p)
1495             black.reformat_one(
1496                 path,
1497                 fast=True,
1498                 write_back=black.WriteBack.YES,
1499                 mode=DEFAULT_MODE,
1500                 report=report,
1501             )
1502             fsts.assert_called_once()
1503             # __BLACK_STDIN_FILENAME__ should have been striped
1504             report.done.assert_called_with(expected, black.Changed.YES)
1505
1506     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1507         with patch(
1508             "black.format_stdin_to_stdout",
1509             return_value=lambda *args, **kwargs: black.Changed.YES,
1510         ) as fsts:
1511             report = MagicMock()
1512             # Even with an existing file, since we are forcing stdin, black
1513             # should output to stdout and not modify the file inplace
1514             p = Path(str(THIS_DIR / "data/collections.py"))
1515             # Make sure is_file actually returns True
1516             self.assertTrue(p.is_file())
1517             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1518             expected = Path(p)
1519             black.reformat_one(
1520                 path,
1521                 fast=True,
1522                 write_back=black.WriteBack.YES,
1523                 mode=DEFAULT_MODE,
1524                 report=report,
1525             )
1526             fsts.assert_called_once()
1527             # __BLACK_STDIN_FILENAME__ should have been striped
1528             report.done.assert_called_with(expected, black.Changed.YES)
1529
1530     def test_gitignore_exclude(self) -> None:
1531         path = THIS_DIR / "data" / "include_exclude_tests"
1532         include = re.compile(r"\.pyi?$")
1533         exclude = re.compile(r"")
1534         report = black.Report()
1535         gitignore = PathSpec.from_lines(
1536             "gitwildmatch", ["exclude/", ".definitely_exclude"]
1537         )
1538         sources: List[Path] = []
1539         expected = [
1540             Path(path / "b/dont_exclude/a.py"),
1541             Path(path / "b/dont_exclude/a.pyi"),
1542         ]
1543         this_abs = THIS_DIR.resolve()
1544         sources.extend(
1545             black.gen_python_files(
1546                 path.iterdir(), this_abs, include, exclude, None, report, gitignore
1547             )
1548         )
1549         self.assertEqual(sorted(expected), sorted(sources))
1550
1551     def test_empty_include(self) -> None:
1552         path = THIS_DIR / "data" / "include_exclude_tests"
1553         report = black.Report()
1554         gitignore = PathSpec.from_lines("gitwildmatch", [])
1555         empty = re.compile(r"")
1556         sources: List[Path] = []
1557         expected = [
1558             Path(path / "b/exclude/a.pie"),
1559             Path(path / "b/exclude/a.py"),
1560             Path(path / "b/exclude/a.pyi"),
1561             Path(path / "b/dont_exclude/a.pie"),
1562             Path(path / "b/dont_exclude/a.py"),
1563             Path(path / "b/dont_exclude/a.pyi"),
1564             Path(path / "b/.definitely_exclude/a.pie"),
1565             Path(path / "b/.definitely_exclude/a.py"),
1566             Path(path / "b/.definitely_exclude/a.pyi"),
1567         ]
1568         this_abs = THIS_DIR.resolve()
1569         sources.extend(
1570             black.gen_python_files(
1571                 path.iterdir(),
1572                 this_abs,
1573                 empty,
1574                 re.compile(black.DEFAULT_EXCLUDES),
1575                 None,
1576                 report,
1577                 gitignore,
1578             )
1579         )
1580         self.assertEqual(sorted(expected), sorted(sources))
1581
1582     def test_empty_exclude(self) -> None:
1583         path = THIS_DIR / "data" / "include_exclude_tests"
1584         report = black.Report()
1585         gitignore = PathSpec.from_lines("gitwildmatch", [])
1586         empty = re.compile(r"")
1587         sources: List[Path] = []
1588         expected = [
1589             Path(path / "b/dont_exclude/a.py"),
1590             Path(path / "b/dont_exclude/a.pyi"),
1591             Path(path / "b/exclude/a.py"),
1592             Path(path / "b/exclude/a.pyi"),
1593             Path(path / "b/.definitely_exclude/a.py"),
1594             Path(path / "b/.definitely_exclude/a.pyi"),
1595         ]
1596         this_abs = THIS_DIR.resolve()
1597         sources.extend(
1598             black.gen_python_files(
1599                 path.iterdir(),
1600                 this_abs,
1601                 re.compile(black.DEFAULT_INCLUDES),
1602                 empty,
1603                 None,
1604                 report,
1605                 gitignore,
1606             )
1607         )
1608         self.assertEqual(sorted(expected), sorted(sources))
1609
1610     def test_invalid_include_exclude(self) -> None:
1611         for option in ["--include", "--exclude"]:
1612             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1613
1614     def test_preserves_line_endings(self) -> None:
1615         with TemporaryDirectory() as workspace:
1616             test_file = Path(workspace) / "test.py"
1617             for nl in ["\n", "\r\n"]:
1618                 contents = nl.join(["def f(  ):", "    pass"])
1619                 test_file.write_bytes(contents.encode())
1620                 ff(test_file, write_back=black.WriteBack.YES)
1621                 updated_contents: bytes = test_file.read_bytes()
1622                 self.assertIn(nl.encode(), updated_contents)
1623                 if nl == "\n":
1624                     self.assertNotIn(b"\r\n", updated_contents)
1625
1626     def test_preserves_line_endings_via_stdin(self) -> None:
1627         for nl in ["\n", "\r\n"]:
1628             contents = nl.join(["def f(  ):", "    pass"])
1629             runner = BlackRunner()
1630             result = runner.invoke(
1631                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1632             )
1633             self.assertEqual(result.exit_code, 0)
1634             output = runner.stdout_bytes
1635             self.assertIn(nl.encode("utf8"), output)
1636             if nl == "\n":
1637                 self.assertNotIn(b"\r\n", output)
1638
1639     def test_assert_equivalent_different_asts(self) -> None:
1640         with self.assertRaises(AssertionError):
1641             black.assert_equivalent("{}", "None")
1642
1643     def test_symlink_out_of_root_directory(self) -> None:
1644         path = MagicMock()
1645         root = THIS_DIR.resolve()
1646         child = MagicMock()
1647         include = re.compile(black.DEFAULT_INCLUDES)
1648         exclude = re.compile(black.DEFAULT_EXCLUDES)
1649         report = black.Report()
1650         gitignore = PathSpec.from_lines("gitwildmatch", [])
1651         # `child` should behave like a symlink which resolved path is clearly
1652         # outside of the `root` directory.
1653         path.iterdir.return_value = [child]
1654         child.resolve.return_value = Path("/a/b/c")
1655         child.as_posix.return_value = "/a/b/c"
1656         child.is_symlink.return_value = True
1657         try:
1658             list(
1659                 black.gen_python_files(
1660                     path.iterdir(), root, include, exclude, None, report, gitignore
1661                 )
1662             )
1663         except ValueError as ve:
1664             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1665         path.iterdir.assert_called_once()
1666         child.resolve.assert_called_once()
1667         child.is_symlink.assert_called_once()
1668         # `child` should behave like a strange file which resolved path is clearly
1669         # outside of the `root` directory.
1670         child.is_symlink.return_value = False
1671         with self.assertRaises(ValueError):
1672             list(
1673                 black.gen_python_files(
1674                     path.iterdir(), root, include, exclude, None, report, gitignore
1675                 )
1676             )
1677         path.iterdir.assert_called()
1678         self.assertEqual(path.iterdir.call_count, 2)
1679         child.resolve.assert_called()
1680         self.assertEqual(child.resolve.call_count, 2)
1681         child.is_symlink.assert_called()
1682         self.assertEqual(child.is_symlink.call_count, 2)
1683
1684     def test_shhh_click(self) -> None:
1685         try:
1686             from click import _unicodefun  # type: ignore
1687         except ModuleNotFoundError:
1688             self.skipTest("Incompatible Click version")
1689         if not hasattr(_unicodefun, "_verify_python3_env"):
1690             self.skipTest("Incompatible Click version")
1691         # First, let's see if Click is crashing with a preferred ASCII charset.
1692         with patch("locale.getpreferredencoding") as gpe:
1693             gpe.return_value = "ASCII"
1694             with self.assertRaises(RuntimeError):
1695                 _unicodefun._verify_python3_env()
1696         # Now, let's silence Click...
1697         black.patch_click()
1698         # ...and confirm it's silent.
1699         with patch("locale.getpreferredencoding") as gpe:
1700             gpe.return_value = "ASCII"
1701             try:
1702                 _unicodefun._verify_python3_env()
1703             except RuntimeError as re:
1704                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1705
1706     def test_root_logger_not_used_directly(self) -> None:
1707         def fail(*args: Any, **kwargs: Any) -> None:
1708             self.fail("Record created with root logger")
1709
1710         with patch.multiple(
1711             logging.root,
1712             debug=fail,
1713             info=fail,
1714             warning=fail,
1715             error=fail,
1716             critical=fail,
1717             log=fail,
1718         ):
1719             ff(THIS_FILE)
1720
1721     def test_invalid_config_return_code(self) -> None:
1722         tmp_file = Path(black.dump_to_file())
1723         try:
1724             tmp_config = Path(black.dump_to_file())
1725             tmp_config.unlink()
1726             args = ["--config", str(tmp_config), str(tmp_file)]
1727             self.invokeBlack(args, exit_code=2, ignore_config=False)
1728         finally:
1729             tmp_file.unlink()
1730
1731     def test_parse_pyproject_toml(self) -> None:
1732         test_toml_file = THIS_DIR / "test.toml"
1733         config = black.parse_pyproject_toml(str(test_toml_file))
1734         self.assertEqual(config["verbose"], 1)
1735         self.assertEqual(config["check"], "no")
1736         self.assertEqual(config["diff"], "y")
1737         self.assertEqual(config["color"], True)
1738         self.assertEqual(config["line_length"], 79)
1739         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1740         self.assertEqual(config["exclude"], r"\.pyi?$")
1741         self.assertEqual(config["include"], r"\.py?$")
1742
1743     def test_read_pyproject_toml(self) -> None:
1744         test_toml_file = THIS_DIR / "test.toml"
1745         fake_ctx = FakeContext()
1746         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1747         config = fake_ctx.default_map
1748         self.assertEqual(config["verbose"], "1")
1749         self.assertEqual(config["check"], "no")
1750         self.assertEqual(config["diff"], "y")
1751         self.assertEqual(config["color"], "True")
1752         self.assertEqual(config["line_length"], "79")
1753         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1754         self.assertEqual(config["exclude"], r"\.pyi?$")
1755         self.assertEqual(config["include"], r"\.py?$")
1756
1757     def test_find_project_root(self) -> None:
1758         with TemporaryDirectory() as workspace:
1759             root = Path(workspace)
1760             test_dir = root / "test"
1761             test_dir.mkdir()
1762
1763             src_dir = root / "src"
1764             src_dir.mkdir()
1765
1766             root_pyproject = root / "pyproject.toml"
1767             root_pyproject.touch()
1768             src_pyproject = src_dir / "pyproject.toml"
1769             src_pyproject.touch()
1770             src_python = src_dir / "foo.py"
1771             src_python.touch()
1772
1773             self.assertEqual(
1774                 black.find_project_root((src_dir, test_dir)), root.resolve()
1775             )
1776             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
1777             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
1778
1779     def test_bpo_33660_workaround(self) -> None:
1780         if system() == "Windows":
1781             return
1782
1783         # https://bugs.python.org/issue33660
1784
1785         old_cwd = Path.cwd()
1786         try:
1787             root = Path("/")
1788             os.chdir(str(root))
1789             path = Path("workspace") / "project"
1790             report = black.Report(verbose=True)
1791             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
1792             self.assertEqual(normalized_path, "workspace/project")
1793         finally:
1794             os.chdir(str(old_cwd))
1795
1796     def test_newline_comment_interaction(self) -> None:
1797         source = "class A:\\\r\n# type: ignore\n pass\n"
1798         output = black.format_str(source, mode=DEFAULT_MODE)
1799         black.assert_stable(source, output, mode=DEFAULT_MODE)
1800
1801
# Lines of black's package source (black/__init__.py), read once at import time
# so tracefunc() can print the signature line of each traced call.
with open(black.__file__, "r", encoding="utf-8") as _bf:
    black_source_lines = list(_bf)
1804
1805
def tracefunc(frame: types.FrameType, event: str, arg: Any) -> Callable:
    """Show function calls from `black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    For each "call" event in black/__init__.py, print the line number and the
    function's signature line (skipping decorator lines), indented by the
    current stack depth. Calls in other files are ignored.

    Always returns itself, which `sys.settrace()` then uses as the local trace
    function for the frame.
    """
    if event != "call":
        return tracefunc

    # Normalize separators so the filter also matches Windows paths.
    filename = frame.f_code.co_filename.replace(os.sep, "/")
    # Filter before indexing `black_source_lines`: line numbers from other
    # files may exceed its length (IndexError), and inspect.stack() is costly.
    if "black/__init__.py" not in filename:
        return tracefunc

    # NOTE(review): 19 looks like an empirically tuned count of harness frames
    # below the traced code; a negative result just yields no indentation.
    stack = (len(inspect.stack()) - 19) * 2
    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1
    funcname = black_source_lines[func_sig_lineno].strip()
    # A decorated function's first line is its decorator; walk down to `def`.
    while funcname.startswith("@"):
        func_sig_lineno += 1
        funcname = black_source_lines[func_sig_lineno].strip()
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc
1826
1827
# Allow running this file directly; unittest.main() loads its tests from the
# module named "test_black".
if __name__ == "__main__":
    unittest.main(module="test_black")