git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Make pre-commit do less (#3838)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import inspect
5 import io
6 import logging
7 import multiprocessing
8 import os
9 import re
10 import sys
11 import types
12 import unittest
13 from concurrent.futures import ThreadPoolExecutor
14 from contextlib import contextmanager, redirect_stderr
15 from dataclasses import replace
16 from io import BytesIO
17 from pathlib import Path
18 from platform import system
19 from tempfile import TemporaryDirectory
20 from typing import (
21     Any,
22     Callable,
23     Dict,
24     Iterator,
25     List,
26     Optional,
27     Sequence,
28     Type,
29     TypeVar,
30     Union,
31 )
32 from unittest.mock import MagicMock, patch
33
34 import click
35 import pytest
36 from click import unstyle
37 from click.testing import CliRunner
38 from pathspec import PathSpec
39
40 import black
41 import black.files
42 from black import Feature, TargetVersion
43 from black import re_compile_maybe_verbose as compile_pattern
44 from black.cache import get_cache_dir, get_cache_file
45 from black.debug import DebugVisitor
46 from black.output import color_diff, diff
47 from black.report import Report
48
49 # Import other test classes
50 from tests.util import (
51     DATA_DIR,
52     DEFAULT_MODE,
53     DETERMINISTIC_HEADER,
54     PROJECT_ROOT,
55     PY36_VERSIONS,
56     THIS_DIR,
57     BlackBaseTestCase,
58     assert_format,
59     change_directory,
60     dump_to_stderr,
61     ff,
62     fs,
63     get_case_path,
64     read_data,
65     read_data_from_file,
66 )
67
# Path of this test module.
THIS_FILE = Path(__file__)
# Configuration file used to keep project settings out of CLI invocations.
EMPTY_CONFIG = THIS_DIR.joinpath("data", "empty_pyproject.toml")
# One ``--target-version=...`` CLI flag per version in PY36_VERSIONS.
PY36_ARGS = [
    "--target-version=" + target.name.lower() for target in PY36_VERSIONS
]
# Black's default exclude/include patterns, compiled once for reuse in tests.
DEFAULT_EXCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_EXCLUDES)
DEFAULT_INCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_INCLUDES)
# Generic type variables.
T = TypeVar("T")
R = TypeVar("R")

# Match the time output in a diff, but nothing else
DIFF_TIME = re.compile(r"\t[\d\-:+\. ]+")
78
79
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Redirect ``black.cache.CACHE_DIR`` to a throwaway directory.

    Args:
        exists: when True, the yielded path is the (existing) temporary
            directory itself; when False, it is a ``new`` subdirectory of it
            that has not been created.

    Yields:
        The path that ``black.cache.CACHE_DIR`` is patched to while the
        ``with`` block runs. The temporary directory is removed afterwards.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
88
89
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a brand-new asyncio event loop as the current loop.

    The loop comes from the active event loop policy and is closed when the
    ``with`` block exits, even if the block raises. Note that it is not unset
    as the current loop afterwards; only closed.
    """
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        fresh_loop.close()
100
101
class FakeContext(click.Context):
    """Stand-in click Context for calling functions that require one.

    ``click.Context.__init__`` is deliberately not invoked; only the three
    attributes below are set.
    """

    def __init__(self) -> None:
        # Both mappings start out empty.
        self.default_map: Dict[str, Any] = dict()
        self.params: Dict[str, Any] = dict()
        # Placeholder root; most tests never look at it.
        self.obj: Dict[str, Any] = dict(root=PROJECT_ROOT)
110
111
class FakeParameter(click.Parameter):
    """Stand-in click Parameter for calling functions that require one."""

    def __init__(self) -> None:
        """Skip ``click.Parameter.__init__`` entirely; no state is set up."""
117
118
class BlackRunner(CliRunner):
    """CLI runner that keeps Black's STDOUT and STDERR streams apart."""

    def __init__(self) -> None:
        # mix_stderr=False keeps stderr out of stdout so tests can inspect
        # each stream on its own.
        CliRunner.__init__(self, mix_stderr=False)
124
125
def invokeBlack(
    args: List[str], exit_code: int = 0, ignore_config: bool = True
) -> None:
    """Run Black's CLI with *args* and assert it exits with *exit_code*.

    Args:
        args: command-line arguments passed to ``black.main``.
        exit_code: the exit status the run is expected to produce.
        ignore_config: when True, ``--verbose`` and
            ``--config <THIS_DIR>/empty.toml`` are prepended to *args*.

    Exceptions raised inside Black propagate (``catch_exceptions=False``).
    On an exit-code mismatch, the assertion message lists the arguments,
    both decoded output streams, and the recorded exception.
    """
    if ignore_config:
        args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
    outcome = BlackRunner().invoke(black.main, args, catch_exceptions=False)
    out_bytes = outcome.stdout_bytes
    err_bytes = outcome.stderr_bytes
    assert out_bytes is not None
    assert err_bytes is not None
    details = "".join(
        [
            f"Failed with args: {args}\n",
            f"stdout: {out_bytes.decode()!r}\n",
            f"stderr: {err_bytes.decode()!r}\n",
            f"exception: {outcome.exception}",
        ]
    )
    assert outcome.exit_code == exit_code, details
142
143
144 class BlackTestCase(BlackBaseTestCase):
    # Expose the module-level helper on the test case so individual tests
    # can call ``self.invokeBlack(...)``.
    invokeBlack = staticmethod(invokeBlack)
146
147     def test_empty_ff(self) -> None:
148         expected = ""
149         tmp_file = Path(black.dump_to_file())
150         try:
151             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
152             actual = tmp_file.read_text(encoding="utf-8")
153         finally:
154             os.unlink(tmp_file)
155         self.assertFormatEqual(expected, actual)
156
157     @patch("black.dump_to_file", dump_to_stderr)
158     def test_one_empty_line(self) -> None:
159         mode = black.Mode(preview=True)
160         for nl in ["\n", "\r\n"]:
161             source = expected = nl
162             assert_format(source, expected, mode=mode)
163
164     def test_one_empty_line_ff(self) -> None:
165         mode = black.Mode(preview=True)
166         for nl in ["\n", "\r\n"]:
167             expected = nl
168             tmp_file = Path(black.dump_to_file(nl))
169             if system() == "Windows":
170                 # Writing files in text mode automatically uses the system newline,
171                 # but in this case we don't want this for testing reasons. See:
172                 # https://github.com/psf/black/pull/3348
173                 with open(tmp_file, "wb") as f:
174                     f.write(nl.encode("utf-8"))
175             try:
176                 self.assertFalse(
177                     ff(tmp_file, mode=mode, write_back=black.WriteBack.YES)
178                 )
179                 with open(tmp_file, "rb") as f:
180                     actual = f.read().decode("utf-8")
181             finally:
182                 os.unlink(tmp_file)
183             self.assertFormatEqual(expected, actual)
184
185     def test_experimental_string_processing_warns(self) -> None:
186         self.assertWarns(
187             black.mode.Deprecated, black.Mode, experimental_string_processing=True
188         )
189
190     def test_piping(self) -> None:
191         source, expected = read_data_from_file(PROJECT_ROOT / "src/black/__init__.py")
192         result = BlackRunner().invoke(
193             black.main,
194             [
195                 "-",
196                 "--fast",
197                 f"--line-length={black.DEFAULT_LINE_LENGTH}",
198                 f"--config={EMPTY_CONFIG}",
199             ],
200             input=BytesIO(source.encode("utf-8")),
201         )
202         self.assertEqual(result.exit_code, 0)
203         self.assertFormatEqual(expected, result.output)
204         if source != result.output:
205             black.assert_equivalent(source, result.output)
206             black.assert_stable(source, result.output, DEFAULT_MODE)
207
208     def test_piping_diff(self) -> None:
209         diff_header = re.compile(
210             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d"
211             r"\+\d\d:\d\d"
212         )
213         source, _ = read_data("simple_cases", "expression.py")
214         expected, _ = read_data("simple_cases", "expression.diff")
215         args = [
216             "-",
217             "--fast",
218             f"--line-length={black.DEFAULT_LINE_LENGTH}",
219             "--diff",
220             f"--config={EMPTY_CONFIG}",
221         ]
222         result = BlackRunner().invoke(
223             black.main, args, input=BytesIO(source.encode("utf-8"))
224         )
225         self.assertEqual(result.exit_code, 0)
226         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
227         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
228         self.assertEqual(expected, actual)
229
230     def test_piping_diff_with_color(self) -> None:
231         source, _ = read_data("simple_cases", "expression.py")
232         args = [
233             "-",
234             "--fast",
235             f"--line-length={black.DEFAULT_LINE_LENGTH}",
236             "--diff",
237             "--color",
238             f"--config={EMPTY_CONFIG}",
239         ]
240         result = BlackRunner().invoke(
241             black.main, args, input=BytesIO(source.encode("utf-8"))
242         )
243         actual = result.output
244         # Again, the contents are checked in a different test, so only look for colors.
245         self.assertIn("\033[1m", actual)
246         self.assertIn("\033[36m", actual)
247         self.assertIn("\033[32m", actual)
248         self.assertIn("\033[31m", actual)
249         self.assertIn("\033[0m", actual)
250
251     @patch("black.dump_to_file", dump_to_stderr)
252     def _test_wip(self) -> None:
253         source, expected = read_data("miscellaneous", "wip")
254         sys.settrace(tracefunc)
255         mode = replace(
256             DEFAULT_MODE,
257             experimental_string_processing=False,
258             target_versions={black.TargetVersion.PY38},
259         )
260         actual = fs(source, mode=mode)
261         sys.settrace(None)
262         self.assertFormatEqual(expected, actual)
263         black.assert_equivalent(source, actual)
264         black.assert_stable(source, actual, black.FileMode())
265
266     def test_pep_572_version_detection(self) -> None:
267         source, _ = read_data("py_38", "pep_572")
268         root = black.lib2to3_parse(source)
269         features = black.get_features_used(root)
270         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
271         versions = black.detect_target_versions(root)
272         self.assertIn(black.TargetVersion.PY38, versions)
273
274     def test_pep_695_version_detection(self) -> None:
275         for file in ("type_aliases", "type_params"):
276             source, _ = read_data("py_312", file)
277             root = black.lib2to3_parse(source)
278             features = black.get_features_used(root)
279             self.assertIn(black.Feature.TYPE_PARAMS, features)
280             versions = black.detect_target_versions(root)
281             self.assertIn(black.TargetVersion.PY312, versions)
282
283     def test_expression_ff(self) -> None:
284         source, expected = read_data("simple_cases", "expression.py")
285         tmp_file = Path(black.dump_to_file(source))
286         try:
287             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
288             actual = tmp_file.read_text(encoding="utf-8")
289         finally:
290             os.unlink(tmp_file)
291         self.assertFormatEqual(expected, actual)
292         with patch("black.dump_to_file", dump_to_stderr):
293             black.assert_equivalent(source, actual)
294             black.assert_stable(source, actual, DEFAULT_MODE)
295
296     def test_expression_diff(self) -> None:
297         source, _ = read_data("simple_cases", "expression.py")
298         expected, _ = read_data("simple_cases", "expression.diff")
299         tmp_file = Path(black.dump_to_file(source))
300         diff_header = re.compile(
301             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
302             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d\+\d\d:\d\d"
303         )
304         try:
305             result = BlackRunner().invoke(
306                 black.main, ["--diff", str(tmp_file), f"--config={EMPTY_CONFIG}"]
307             )
308             self.assertEqual(result.exit_code, 0)
309         finally:
310             os.unlink(tmp_file)
311         actual = result.output
312         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
313         if expected != actual:
314             dump = black.dump_to_file(actual)
315             msg = (
316                 "Expected diff isn't equal to the actual. If you made changes to"
317                 " expression.py and this is an anticipated difference, overwrite"
318                 f" tests/data/expression.diff with {dump}"
319             )
320             self.assertEqual(expected, actual, msg)
321
322     def test_expression_diff_with_color(self) -> None:
323         source, _ = read_data("simple_cases", "expression.py")
324         expected, _ = read_data("simple_cases", "expression.diff")
325         tmp_file = Path(black.dump_to_file(source))
326         try:
327             result = BlackRunner().invoke(
328                 black.main,
329                 ["--diff", "--color", str(tmp_file), f"--config={EMPTY_CONFIG}"],
330             )
331         finally:
332             os.unlink(tmp_file)
333         actual = result.output
334         # We check the contents of the diff in `test_expression_diff`. All
335         # we need to check here is that color codes exist in the result.
336         self.assertIn("\033[1m", actual)
337         self.assertIn("\033[36m", actual)
338         self.assertIn("\033[32m", actual)
339         self.assertIn("\033[31m", actual)
340         self.assertIn("\033[0m", actual)
341
342     def test_detect_pos_only_arguments(self) -> None:
343         source, _ = read_data("py_38", "pep_570")
344         root = black.lib2to3_parse(source)
345         features = black.get_features_used(root)
346         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
347         versions = black.detect_target_versions(root)
348         self.assertIn(black.TargetVersion.PY38, versions)
349
350     def test_detect_debug_f_strings(self) -> None:
351         root = black.lib2to3_parse("""f"{x=}" """)
352         features = black.get_features_used(root)
353         self.assertIn(black.Feature.DEBUG_F_STRINGS, features)
354         versions = black.detect_target_versions(root)
355         self.assertIn(black.TargetVersion.PY38, versions)
356
357         root = black.lib2to3_parse(
358             """f"{x}"\nf'{"="}'\nf'{(x:=5)}'\nf'{f(a="3=")}'\nf'{x:=10}'\n"""
359         )
360         features = black.get_features_used(root)
361         self.assertNotIn(black.Feature.DEBUG_F_STRINGS, features)
362
363         # We don't yet support feature version detection in nested f-strings
364         root = black.lib2to3_parse(
365             """f"heard a rumour that { f'{1+1=}' } ... seems like it could be true" """
366         )
367         features = black.get_features_used(root)
368         self.assertNotIn(black.Feature.DEBUG_F_STRINGS, features)
369
370     @patch("black.dump_to_file", dump_to_stderr)
371     def test_string_quotes(self) -> None:
372         source, expected = read_data("miscellaneous", "string_quotes")
373         mode = black.Mode(preview=True)
374         assert_format(source, expected, mode)
375         mode = replace(mode, string_normalization=False)
376         not_normalized = fs(source, mode=mode)
377         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
378         black.assert_equivalent(source, not_normalized)
379         black.assert_stable(source, not_normalized, mode=mode)
380
381     def test_skip_source_first_line(self) -> None:
382         source, _ = read_data("miscellaneous", "invalid_header")
383         tmp_file = Path(black.dump_to_file(source))
384         # Full source should fail (invalid syntax at header)
385         self.invokeBlack([str(tmp_file), "--diff", "--check"], exit_code=123)
386         # So, skipping the first line should work
387         result = BlackRunner().invoke(
388             black.main, [str(tmp_file), "-x", f"--config={EMPTY_CONFIG}"]
389         )
390         self.assertEqual(result.exit_code, 0)
391         actual = tmp_file.read_text(encoding="utf-8")
392         self.assertFormatEqual(source, actual)
393
394     def test_skip_source_first_line_when_mixing_newlines(self) -> None:
395         code_mixing_newlines = b"Header will be skipped\r\ni = [1,2,3]\nj = [1,2,3]\n"
396         expected = b"Header will be skipped\r\ni = [1, 2, 3]\nj = [1, 2, 3]\n"
397         with TemporaryDirectory() as workspace:
398             test_file = Path(workspace) / "skip_header.py"
399             test_file.write_bytes(code_mixing_newlines)
400             mode = replace(DEFAULT_MODE, skip_source_first_line=True)
401             ff(test_file, mode=mode, write_back=black.WriteBack.YES)
402             self.assertEqual(test_file.read_bytes(), expected)
403
404     def test_skip_magic_trailing_comma(self) -> None:
405         source, _ = read_data("simple_cases", "expression")
406         expected, _ = read_data(
407             "miscellaneous", "expression_skip_magic_trailing_comma.diff"
408         )
409         tmp_file = Path(black.dump_to_file(source))
410         diff_header = re.compile(
411             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
412             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d\+\d\d:\d\d"
413         )
414         try:
415             result = BlackRunner().invoke(
416                 black.main, ["-C", "--diff", str(tmp_file), f"--config={EMPTY_CONFIG}"]
417             )
418             self.assertEqual(result.exit_code, 0)
419         finally:
420             os.unlink(tmp_file)
421         actual = result.output
422         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
423         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
424         if expected != actual:
425             dump = black.dump_to_file(actual)
426             msg = (
427                 "Expected diff isn't equal to the actual. If you made changes to"
428                 " expression.py and this is an anticipated difference, overwrite"
429                 " tests/data/miscellaneous/expression_skip_magic_trailing_comma.diff"
430                 f" with {dump}"
431             )
432             self.assertEqual(expected, actual, msg)
433
434     @patch("black.dump_to_file", dump_to_stderr)
435     def test_async_as_identifier(self) -> None:
436         source_path = get_case_path("miscellaneous", "async_as_identifier")
437         source, expected = read_data_from_file(source_path)
438         actual = fs(source)
439         self.assertFormatEqual(expected, actual)
440         major, minor = sys.version_info[:2]
441         if major < 3 or (major <= 3 and minor < 7):
442             black.assert_equivalent(source, actual)
443         black.assert_stable(source, actual, DEFAULT_MODE)
444         # ensure black can parse this when the target is 3.6
445         self.invokeBlack([str(source_path), "--target-version", "py36"])
446         # but not on 3.7, because async/await is no longer an identifier
447         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
448
449     @patch("black.dump_to_file", dump_to_stderr)
450     def test_python37(self) -> None:
451         source_path = get_case_path("py_37", "python37")
452         source, expected = read_data_from_file(source_path)
453         actual = fs(source)
454         self.assertFormatEqual(expected, actual)
455         major, minor = sys.version_info[:2]
456         if major > 3 or (major == 3 and minor >= 7):
457             black.assert_equivalent(source, actual)
458         black.assert_stable(source, actual, DEFAULT_MODE)
459         # ensure black can parse this when the target is 3.7
460         self.invokeBlack([str(source_path), "--target-version", "py37"])
461         # but not on 3.6, because we use async as a reserved keyword
462         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
463
464     def test_tab_comment_indentation(self) -> None:
465         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
466         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
467         self.assertFormatEqual(contents_spc, fs(contents_spc))
468         self.assertFormatEqual(contents_spc, fs(contents_tab))
469
470         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
471         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
472         self.assertFormatEqual(contents_spc, fs(contents_spc))
473         self.assertFormatEqual(contents_spc, fs(contents_tab))
474
475         # mixed tabs and spaces (valid Python 2 code)
476         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
477         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
478         self.assertFormatEqual(contents_spc, fs(contents_spc))
479         self.assertFormatEqual(contents_spc, fs(contents_tab))
480
481         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
482         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
483         self.assertFormatEqual(contents_spc, fs(contents_spc))
484         self.assertFormatEqual(contents_spc, fs(contents_tab))
485
    def test_false_positive_symlink_output_issue_3384(self) -> None:
        """Regression test for issue 3384: no spurious symlink report.

        ``black.get_sources`` is run on ``./child`` with ``pathlib`` patched to
        mimic a CLI invocation from ``working_directory``. The report must not
        flag any path as a symbolic link pointing outside the root; the only
        ignore report expected is the ``.gitignore`` match for ``b.py``.
        """
        # Emulate the behavior when using the CLI (`black ./child  --verbose`), which
        # involves patching some `pathlib.Path` methods. In particular, `is_dir` is
        # patched only on its first call: when checking if "./child" is a directory it
        # should return True. The "./child" folder exists relative to the cwd when
        # running from CLI, but fails when running the tests because cwd is different
        project_root = Path(THIS_DIR / "data" / "nested_gitignore_tests")
        working_directory = project_root / "root"
        target_abspath = working_directory / "child"
        # Children of ./child expressed relative to working_directory; this
        # generator is handed to the patched ``Path.iterdir`` as its result.
        target_contents = (
            src.relative_to(working_directory) for src in target_abspath.iterdir()
        )

        # Build a zero-argument callable that returns the queued *responses*
        # in order, then False once the queue is exhausted.
        def mock_n_calls(responses: List[bool]) -> Callable[[], bool]:
            def _mocked_calls() -> bool:
                if responses:
                    return responses.pop(0)
                return False

            return _mocked_calls

        with patch("pathlib.Path.iterdir", return_value=target_contents), patch(
            "pathlib.Path.cwd", return_value=working_directory
        ), patch("pathlib.Path.is_dir", side_effect=mock_n_calls([True])):
            ctx = FakeContext()
            # Note that the root folder (project_root) isn't the folder
            # named "root" (aka working_directory)
            ctx.obj["root"] = project_root
            # MagicMock records every report call for the assertions below.
            report = MagicMock(verbose=True)
            black.get_sources(
                ctx=ctx,
                src=("./child",),
                quiet=False,
                verbose=True,
                include=DEFAULT_INCLUDE,
                exclude=None,
                report=report,
                extend_exclude=None,
                force_exclude=None,
                stdin_filename=None,
            )
        # Each mock call is (name, args, kwargs); args[1] is the ignore message.
        assert not any(
            mock_args[1].startswith("is a symbolic link that points outside")
            for _, mock_args, _ in report.path_ignored.mock_calls
        ), "A symbolic link was reported."
        report.path_ignored.assert_called_once_with(
            Path("root", "child", "b.py"), "matches a .gitignore file content"
        )
534
    def test_report_verbose(self) -> None:
        """Exercise ``Report(verbose=True)`` one event at a time.

        ``black.output._out`` and ``black.output._err`` are patched to capture
        messages. After each event the test checks the number of lines on
        each stream, the latest message, the summary (``str(report)``), and
        ``report.return_code``.
        """
        report = Report(verbose=True)
        # Messages captured from the patched output functions.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With ``check`` set, having reformatted a file yields return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure goes to the error stream and forces return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are echoed in verbose mode but don't change the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check mode and diff mode switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2"
                " files would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2"
                " files would fail to reformat.",
            )
619                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
620                 " reformat.",
621             )
622             self.assertEqual(report.return_code, 123)
623             report.check = True
624             self.assertEqual(
625                 unstyle(str(report)),
626                 "2 files would be reformatted, 3 files would be left unchanged, 2"
627                 " files would fail to reformat.",
628             )
629             report.check = False
630             report.diff = True
631             self.assertEqual(
632                 unstyle(str(report)),
633                 "2 files would be reformatted, 3 files would be left unchanged, 2"
634                 " files would fail to reformat.",
635             )
636
637     def test_report_quiet(self) -> None:
638         report = Report(quiet=True)
639         out_lines = []
640         err_lines = []
641
642         def out(msg: str, **kwargs: Any) -> None:
643             out_lines.append(msg)
644
645         def err(msg: str, **kwargs: Any) -> None:
646             err_lines.append(msg)
647
648         with patch("black.output._out", out), patch("black.output._err", err):
649             report.done(Path("f1"), black.Changed.NO)
650             self.assertEqual(len(out_lines), 0)
651             self.assertEqual(len(err_lines), 0)
652             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
653             self.assertEqual(report.return_code, 0)
654             report.done(Path("f2"), black.Changed.YES)
655             self.assertEqual(len(out_lines), 0)
656             self.assertEqual(len(err_lines), 0)
657             self.assertEqual(
658                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
659             )
660             report.done(Path("f3"), black.Changed.CACHED)
661             self.assertEqual(len(out_lines), 0)
662             self.assertEqual(len(err_lines), 0)
663             self.assertEqual(
664                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
665             )
666             self.assertEqual(report.return_code, 0)
667             report.check = True
668             self.assertEqual(report.return_code, 1)
669             report.check = False
670             report.failed(Path("e1"), "boom")
671             self.assertEqual(len(out_lines), 0)
672             self.assertEqual(len(err_lines), 1)
673             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
674             self.assertEqual(
675                 unstyle(str(report)),
676                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
677                 " reformat.",
678             )
679             self.assertEqual(report.return_code, 123)
680             report.done(Path("f3"), black.Changed.YES)
681             self.assertEqual(len(out_lines), 0)
682             self.assertEqual(len(err_lines), 1)
683             self.assertEqual(
684                 unstyle(str(report)),
685                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
686                 " reformat.",
687             )
688             self.assertEqual(report.return_code, 123)
689             report.failed(Path("e2"), "boom")
690             self.assertEqual(len(out_lines), 0)
691             self.assertEqual(len(err_lines), 2)
692             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
693             self.assertEqual(
694                 unstyle(str(report)),
695                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
696                 " reformat.",
697             )
698             self.assertEqual(report.return_code, 123)
699             report.path_ignored(Path("wat"), "no match")
700             self.assertEqual(len(out_lines), 0)
701             self.assertEqual(len(err_lines), 2)
702             self.assertEqual(
703                 unstyle(str(report)),
704                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
705                 " reformat.",
706             )
707             self.assertEqual(report.return_code, 123)
708             report.done(Path("f4"), black.Changed.NO)
709             self.assertEqual(len(out_lines), 0)
710             self.assertEqual(len(err_lines), 2)
711             self.assertEqual(
712                 unstyle(str(report)),
713                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
714                 " reformat.",
715             )
716             self.assertEqual(report.return_code, 123)
717             report.check = True
718             self.assertEqual(
719                 unstyle(str(report)),
720                 "2 files would be reformatted, 3 files would be left unchanged, 2"
721                 " files would fail to reformat.",
722             )
723             report.check = False
724             report.diff = True
725             self.assertEqual(
726                 unstyle(str(report)),
727                 "2 files would be reformatted, 3 files would be left unchanged, 2"
728                 " files would fail to reformat.",
729             )
730
731     def test_report_normal(self) -> None:
732         report = black.Report()
733         out_lines = []
734         err_lines = []
735
736         def out(msg: str, **kwargs: Any) -> None:
737             out_lines.append(msg)
738
739         def err(msg: str, **kwargs: Any) -> None:
740             err_lines.append(msg)
741
742         with patch("black.output._out", out), patch("black.output._err", err):
743             report.done(Path("f1"), black.Changed.NO)
744             self.assertEqual(len(out_lines), 0)
745             self.assertEqual(len(err_lines), 0)
746             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
747             self.assertEqual(report.return_code, 0)
748             report.done(Path("f2"), black.Changed.YES)
749             self.assertEqual(len(out_lines), 1)
750             self.assertEqual(len(err_lines), 0)
751             self.assertEqual(out_lines[-1], "reformatted f2")
752             self.assertEqual(
753                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
754             )
755             report.done(Path("f3"), black.Changed.CACHED)
756             self.assertEqual(len(out_lines), 1)
757             self.assertEqual(len(err_lines), 0)
758             self.assertEqual(out_lines[-1], "reformatted f2")
759             self.assertEqual(
760                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
761             )
762             self.assertEqual(report.return_code, 0)
763             report.check = True
764             self.assertEqual(report.return_code, 1)
765             report.check = False
766             report.failed(Path("e1"), "boom")
767             self.assertEqual(len(out_lines), 1)
768             self.assertEqual(len(err_lines), 1)
769             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
770             self.assertEqual(
771                 unstyle(str(report)),
772                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
773                 " reformat.",
774             )
775             self.assertEqual(report.return_code, 123)
776             report.done(Path("f3"), black.Changed.YES)
777             self.assertEqual(len(out_lines), 2)
778             self.assertEqual(len(err_lines), 1)
779             self.assertEqual(out_lines[-1], "reformatted f3")
780             self.assertEqual(
781                 unstyle(str(report)),
782                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
783                 " reformat.",
784             )
785             self.assertEqual(report.return_code, 123)
786             report.failed(Path("e2"), "boom")
787             self.assertEqual(len(out_lines), 2)
788             self.assertEqual(len(err_lines), 2)
789             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
790             self.assertEqual(
791                 unstyle(str(report)),
792                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
793                 " reformat.",
794             )
795             self.assertEqual(report.return_code, 123)
796             report.path_ignored(Path("wat"), "no match")
797             self.assertEqual(len(out_lines), 2)
798             self.assertEqual(len(err_lines), 2)
799             self.assertEqual(
800                 unstyle(str(report)),
801                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
802                 " reformat.",
803             )
804             self.assertEqual(report.return_code, 123)
805             report.done(Path("f4"), black.Changed.NO)
806             self.assertEqual(len(out_lines), 2)
807             self.assertEqual(len(err_lines), 2)
808             self.assertEqual(
809                 unstyle(str(report)),
810                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
811                 " reformat.",
812             )
813             self.assertEqual(report.return_code, 123)
814             report.check = True
815             self.assertEqual(
816                 unstyle(str(report)),
817                 "2 files would be reformatted, 3 files would be left unchanged, 2"
818                 " files would fail to reformat.",
819             )
820             report.check = False
821             report.diff = True
822             self.assertEqual(
823                 unstyle(str(report)),
824                 "2 files would be reformatted, 3 files would be left unchanged, 2"
825                 " files would fail to reformat.",
826             )
827
828     def test_lib2to3_parse(self) -> None:
829         with self.assertRaises(black.InvalidInput):
830             black.lib2to3_parse("invalid syntax")
831
832         straddling = "x + y"
833         black.lib2to3_parse(straddling)
834         black.lib2to3_parse(straddling, {TargetVersion.PY36})
835
836         py2_only = "print x"
837         with self.assertRaises(black.InvalidInput):
838             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
839
840         py3_only = "exec(x, end=y)"
841         black.lib2to3_parse(py3_only)
842         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
843
844     def test_get_features_used_decorator(self) -> None:
845         # Test the feature detection of new decorator syntax
846         # since this makes some test cases of test_get_features_used()
847         # fails if it fails, this is tested first so that a useful case
848         # is identified
849         simples, relaxed = read_data("miscellaneous", "decorators")
850         # skip explanation comments at the top of the file
851         for simple_test in simples.split("##")[1:]:
852             node = black.lib2to3_parse(simple_test)
853             decorator = str(node.children[0].children[0]).strip()
854             self.assertNotIn(
855                 Feature.RELAXED_DECORATORS,
856                 black.get_features_used(node),
857                 msg=(
858                     f"decorator '{decorator}' follows python<=3.8 syntax"
859                     "but is detected as 3.9+"
860                     # f"The full node is\n{node!r}"
861                 ),
862             )
863         # skip the '# output' comment at the top of the output part
864         for relaxed_test in relaxed.split("##")[1:]:
865             node = black.lib2to3_parse(relaxed_test)
866             decorator = str(node.children[0].children[0]).strip()
867             self.assertIn(
868                 Feature.RELAXED_DECORATORS,
869                 black.get_features_used(node),
870                 msg=(
871                     f"decorator '{decorator}' uses python3.9+ syntax"
872                     "but is detected as python<=3.8"
873                     # f"The full node is\n{node!r}"
874                 ),
875             )
876
877     def test_get_features_used(self) -> None:
878         node = black.lib2to3_parse("def f(*, arg): ...\n")
879         self.assertEqual(black.get_features_used(node), set())
880         node = black.lib2to3_parse("def f(*, arg,): ...\n")
881         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
882         node = black.lib2to3_parse("f(*arg,)\n")
883         self.assertEqual(
884             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
885         )
886         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
887         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
888         node = black.lib2to3_parse("123_456\n")
889         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
890         node = black.lib2to3_parse("123456\n")
891         self.assertEqual(black.get_features_used(node), set())
892         source, expected = read_data("simple_cases", "function")
893         node = black.lib2to3_parse(source)
894         expected_features = {
895             Feature.TRAILING_COMMA_IN_CALL,
896             Feature.TRAILING_COMMA_IN_DEF,
897             Feature.F_STRINGS,
898         }
899         self.assertEqual(black.get_features_used(node), expected_features)
900         node = black.lib2to3_parse(expected)
901         self.assertEqual(black.get_features_used(node), expected_features)
902         source, expected = read_data("simple_cases", "expression")
903         node = black.lib2to3_parse(source)
904         self.assertEqual(black.get_features_used(node), set())
905         node = black.lib2to3_parse(expected)
906         self.assertEqual(black.get_features_used(node), set())
907         node = black.lib2to3_parse("lambda a, /, b: ...")
908         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
909         node = black.lib2to3_parse("def fn(a, /, b): ...")
910         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
911         node = black.lib2to3_parse("def fn(): yield a, b")
912         self.assertEqual(black.get_features_used(node), set())
913         node = black.lib2to3_parse("def fn(): return a, b")
914         self.assertEqual(black.get_features_used(node), set())
915         node = black.lib2to3_parse("def fn(): yield *b, c")
916         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
917         node = black.lib2to3_parse("def fn(): return a, *b, c")
918         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
919         node = black.lib2to3_parse("x = a, *b, c")
920         self.assertEqual(black.get_features_used(node), set())
921         node = black.lib2to3_parse("x: Any = regular")
922         self.assertEqual(black.get_features_used(node), set())
923         node = black.lib2to3_parse("x: Any = (regular, regular)")
924         self.assertEqual(black.get_features_used(node), set())
925         node = black.lib2to3_parse("x: Any = Complex(Type(1))[something]")
926         self.assertEqual(black.get_features_used(node), set())
927         node = black.lib2to3_parse("x: Tuple[int, ...] = a, b, c")
928         self.assertEqual(
929             black.get_features_used(node), {Feature.ANN_ASSIGN_EXTENDED_RHS}
930         )
931         node = black.lib2to3_parse("try: pass\nexcept Something: pass")
932         self.assertEqual(black.get_features_used(node), set())
933         node = black.lib2to3_parse("try: pass\nexcept (*Something,): pass")
934         self.assertEqual(black.get_features_used(node), set())
935         node = black.lib2to3_parse("try: pass\nexcept *Group: pass")
936         self.assertEqual(black.get_features_used(node), {Feature.EXCEPT_STAR})
937         node = black.lib2to3_parse("a[*b]")
938         self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS})
939         node = black.lib2to3_parse("a[x, *y(), z] = t")
940         self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS})
941         node = black.lib2to3_parse("def fn(*args: *T): pass")
942         self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS})
943
944     def test_get_features_used_for_future_flags(self) -> None:
945         for src, features in [
946             ("from __future__ import annotations", {Feature.FUTURE_ANNOTATIONS}),
947             (
948                 "from __future__ import (other, annotations)",
949                 {Feature.FUTURE_ANNOTATIONS},
950             ),
951             ("a = 1 + 2\nfrom something import annotations", set()),
952             ("from __future__ import x, y", set()),
953         ]:
954             with self.subTest(src=src, features=features):
955                 node = black.lib2to3_parse(src)
956                 future_imports = black.get_future_imports(node)
957                 self.assertEqual(
958                     black.get_features_used(node, future_imports=future_imports),
959                     features,
960                 )
961
962     def test_get_future_imports(self) -> None:
963         node = black.lib2to3_parse("\n")
964         self.assertEqual(set(), black.get_future_imports(node))
965         node = black.lib2to3_parse("from __future__ import black\n")
966         self.assertEqual({"black"}, black.get_future_imports(node))
967         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
968         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
969         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
970         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
971         node = black.lib2to3_parse(
972             "from __future__ import multiple\nfrom __future__ import imports\n"
973         )
974         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
975         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
976         self.assertEqual({"black"}, black.get_future_imports(node))
977         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
978         self.assertEqual({"black"}, black.get_future_imports(node))
979         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
980         self.assertEqual(set(), black.get_future_imports(node))
981         node = black.lib2to3_parse("from some.module import black\n")
982         self.assertEqual(set(), black.get_future_imports(node))
983         node = black.lib2to3_parse(
984             "from __future__ import unicode_literals as _unicode_literals"
985         )
986         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
987         node = black.lib2to3_parse(
988             "from __future__ import unicode_literals as _lol, print"
989         )
990         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
991
992     @pytest.mark.incompatible_with_mypyc
993     def test_debug_visitor(self) -> None:
994         source, _ = read_data("miscellaneous", "debug_visitor")
995         expected, _ = read_data("miscellaneous", "debug_visitor.out")
996         out_lines = []
997         err_lines = []
998
999         def out(msg: str, **kwargs: Any) -> None:
1000             out_lines.append(msg)
1001
1002         def err(msg: str, **kwargs: Any) -> None:
1003             err_lines.append(msg)
1004
1005         with patch("black.debug.out", out):
1006             DebugVisitor.show(source)
1007         actual = "\n".join(out_lines) + "\n"
1008         log_name = ""
1009         if expected != actual:
1010             log_name = black.dump_to_file(*out_lines)
1011         self.assertEqual(
1012             expected,
1013             actual,
1014             f"AST print out is different. Actual version dumped to {log_name}",
1015         )
1016
1017     def test_format_file_contents(self) -> None:
1018         mode = DEFAULT_MODE
1019         empty = ""
1020         with self.assertRaises(black.NothingChanged):
1021             black.format_file_contents(empty, mode=mode, fast=False)
1022         just_nl = "\n"
1023         with self.assertRaises(black.NothingChanged):
1024             black.format_file_contents(just_nl, mode=mode, fast=False)
1025         same = "j = [1, 2, 3]\n"
1026         with self.assertRaises(black.NothingChanged):
1027             black.format_file_contents(same, mode=mode, fast=False)
1028         different = "j = [1,2,3]"
1029         expected = same
1030         actual = black.format_file_contents(different, mode=mode, fast=False)
1031         self.assertEqual(expected, actual)
1032         invalid = "return if you can"
1033         with self.assertRaises(black.InvalidInput) as e:
1034             black.format_file_contents(invalid, mode=mode, fast=False)
1035         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1036
1037         mode = black.Mode(preview=True)
1038         just_crlf = "\r\n"
1039         with self.assertRaises(black.NothingChanged):
1040             black.format_file_contents(just_crlf, mode=mode, fast=False)
1041         just_whitespace_nl = "\n\t\n \n\t \n \t\n\n"
1042         actual = black.format_file_contents(just_whitespace_nl, mode=mode, fast=False)
1043         self.assertEqual("\n", actual)
1044         just_whitespace_crlf = "\r\n\t\r\n \r\n\t \r\n \t\r\n\r\n"
1045         actual = black.format_file_contents(just_whitespace_crlf, mode=mode, fast=False)
1046         self.assertEqual("\r\n", actual)
1047
1048     def test_endmarker(self) -> None:
1049         n = black.lib2to3_parse("\n")
1050         self.assertEqual(n.type, black.syms.file_input)
1051         self.assertEqual(len(n.children), 1)
1052         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1053
1054     @pytest.mark.incompatible_with_mypyc
1055     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1056     def test_assertFormatEqual(self) -> None:
1057         out_lines = []
1058         err_lines = []
1059
1060         def out(msg: str, **kwargs: Any) -> None:
1061             out_lines.append(msg)
1062
1063         def err(msg: str, **kwargs: Any) -> None:
1064             err_lines.append(msg)
1065
1066         with patch("black.output._out", out), patch("black.output._err", err):
1067             with self.assertRaises(AssertionError):
1068                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1069
1070         out_str = "".join(out_lines)
1071         self.assertIn("Expected tree:", out_str)
1072         self.assertIn("Actual tree:", out_str)
1073         self.assertEqual("".join(err_lines), "")
1074
1075     @event_loop()
1076     @patch("concurrent.futures.ProcessPoolExecutor", MagicMock(side_effect=OSError))
1077     def test_works_in_mono_process_only_environment(self) -> None:
1078         with cache_dir() as workspace:
1079             for f in [
1080                 (workspace / "one.py").resolve(),
1081                 (workspace / "two.py").resolve(),
1082             ]:
1083                 f.write_text('print("hello")\n', encoding="utf-8")
1084             self.invokeBlack([str(workspace)])
1085
1086     @event_loop()
1087     def test_check_diff_use_together(self) -> None:
1088         with cache_dir():
1089             # Files which will be reformatted.
1090             src1 = get_case_path("miscellaneous", "string_quotes")
1091             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1092             # Files which will not be reformatted.
1093             src2 = get_case_path("simple_cases", "composition")
1094             self.invokeBlack([str(src2), "--diff", "--check"])
1095             # Multi file command.
1096             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1097
1098     def test_no_src_fails(self) -> None:
1099         with cache_dir():
1100             self.invokeBlack([], exit_code=1)
1101
1102     def test_src_and_code_fails(self) -> None:
1103         with cache_dir():
1104             self.invokeBlack([".", "-c", "0"], exit_code=1)
1105
1106     def test_broken_symlink(self) -> None:
1107         with cache_dir() as workspace:
1108             symlink = workspace / "broken_link.py"
1109             try:
1110                 symlink.symlink_to("nonexistent.py")
1111             except (OSError, NotImplementedError) as e:
1112                 self.skipTest(f"Can't create symlinks: {e}")
1113             self.invokeBlack([str(workspace.resolve())])
1114
1115     def test_single_file_force_pyi(self) -> None:
1116         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1117         contents, expected = read_data("miscellaneous", "force_pyi")
1118         with cache_dir() as workspace:
1119             path = (workspace / "file.py").resolve()
1120             path.write_text(contents, encoding="utf-8")
1121             self.invokeBlack([str(path), "--pyi"])
1122             actual = path.read_text(encoding="utf-8")
1123             # verify cache with --pyi is separate
1124             pyi_cache = black.read_cache(pyi_mode)
1125             self.assertIn(str(path), pyi_cache)
1126             normal_cache = black.read_cache(DEFAULT_MODE)
1127             self.assertNotIn(str(path), normal_cache)
1128         self.assertFormatEqual(expected, actual)
1129         black.assert_equivalent(contents, actual)
1130         black.assert_stable(contents, actual, pyi_mode)
1131
1132     @event_loop()
1133     def test_multi_file_force_pyi(self) -> None:
1134         reg_mode = DEFAULT_MODE
1135         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1136         contents, expected = read_data("miscellaneous", "force_pyi")
1137         with cache_dir() as workspace:
1138             paths = [
1139                 (workspace / "file1.py").resolve(),
1140                 (workspace / "file2.py").resolve(),
1141             ]
1142             for path in paths:
1143                 path.write_text(contents, encoding="utf-8")
1144             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1145             for path in paths:
1146                 actual = path.read_text(encoding="utf-8")
1147                 self.assertEqual(actual, expected)
1148             # verify cache with --pyi is separate
1149             pyi_cache = black.read_cache(pyi_mode)
1150             normal_cache = black.read_cache(reg_mode)
1151             for path in paths:
1152                 self.assertIn(str(path), pyi_cache)
1153                 self.assertNotIn(str(path), normal_cache)
1154
1155     def test_pipe_force_pyi(self) -> None:
1156         source, expected = read_data("miscellaneous", "force_pyi")
1157         result = CliRunner().invoke(
1158             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf-8"))
1159         )
1160         self.assertEqual(result.exit_code, 0)
1161         actual = result.output
1162         self.assertFormatEqual(actual, expected)
1163
1164     def test_single_file_force_py36(self) -> None:
1165         reg_mode = DEFAULT_MODE
1166         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1167         source, expected = read_data("miscellaneous", "force_py36")
1168         with cache_dir() as workspace:
1169             path = (workspace / "file.py").resolve()
1170             path.write_text(source, encoding="utf-8")
1171             self.invokeBlack([str(path), *PY36_ARGS])
1172             actual = path.read_text(encoding="utf-8")
1173             # verify cache with --target-version is separate
1174             py36_cache = black.read_cache(py36_mode)
1175             self.assertIn(str(path), py36_cache)
1176             normal_cache = black.read_cache(reg_mode)
1177             self.assertNotIn(str(path), normal_cache)
1178         self.assertEqual(actual, expected)
1179
1180     @event_loop()
1181     def test_multi_file_force_py36(self) -> None:
1182         reg_mode = DEFAULT_MODE
1183         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1184         source, expected = read_data("miscellaneous", "force_py36")
1185         with cache_dir() as workspace:
1186             paths = [
1187                 (workspace / "file1.py").resolve(),
1188                 (workspace / "file2.py").resolve(),
1189             ]
1190             for path in paths:
1191                 path.write_text(source, encoding="utf-8")
1192             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1193             for path in paths:
1194                 actual = path.read_text(encoding="utf-8")
1195                 self.assertEqual(actual, expected)
1196             # verify cache with --target-version is separate
1197             pyi_cache = black.read_cache(py36_mode)
1198             normal_cache = black.read_cache(reg_mode)
1199             for path in paths:
1200                 self.assertIn(str(path), pyi_cache)
1201                 self.assertNotIn(str(path), normal_cache)
1202
1203     def test_pipe_force_py36(self) -> None:
1204         source, expected = read_data("miscellaneous", "force_py36")
1205         result = CliRunner().invoke(
1206             black.main,
1207             ["-", "-q", "--target-version=py36"],
1208             input=BytesIO(source.encode("utf-8")),
1209         )
1210         self.assertEqual(result.exit_code, 0)
1211         actual = result.output
1212         self.assertFormatEqual(actual, expected)
1213
1214     @pytest.mark.incompatible_with_mypyc
1215     def test_reformat_one_with_stdin(self) -> None:
1216         with patch(
1217             "black.format_stdin_to_stdout",
1218             return_value=lambda *args, **kwargs: black.Changed.YES,
1219         ) as fsts:
1220             report = MagicMock()
1221             path = Path("-")
1222             black.reformat_one(
1223                 path,
1224                 fast=True,
1225                 write_back=black.WriteBack.YES,
1226                 mode=DEFAULT_MODE,
1227                 report=report,
1228             )
1229             fsts.assert_called_once()
1230             report.done.assert_called_with(path, black.Changed.YES)
1231
1232     @pytest.mark.incompatible_with_mypyc
1233     def test_reformat_one_with_stdin_filename(self) -> None:
1234         with patch(
1235             "black.format_stdin_to_stdout",
1236             return_value=lambda *args, **kwargs: black.Changed.YES,
1237         ) as fsts:
1238             report = MagicMock()
1239             p = "foo.py"
1240             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1241             expected = Path(p)
1242             black.reformat_one(
1243                 path,
1244                 fast=True,
1245                 write_back=black.WriteBack.YES,
1246                 mode=DEFAULT_MODE,
1247                 report=report,
1248             )
1249             fsts.assert_called_once_with(
1250                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1251             )
1252             # __BLACK_STDIN_FILENAME__ should have been stripped
1253             report.done.assert_called_with(expected, black.Changed.YES)
1254
1255     @pytest.mark.incompatible_with_mypyc
1256     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1257         with patch(
1258             "black.format_stdin_to_stdout",
1259             return_value=lambda *args, **kwargs: black.Changed.YES,
1260         ) as fsts:
1261             report = MagicMock()
1262             p = "foo.pyi"
1263             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1264             expected = Path(p)
1265             black.reformat_one(
1266                 path,
1267                 fast=True,
1268                 write_back=black.WriteBack.YES,
1269                 mode=DEFAULT_MODE,
1270                 report=report,
1271             )
1272             fsts.assert_called_once_with(
1273                 fast=True,
1274                 write_back=black.WriteBack.YES,
1275                 mode=replace(DEFAULT_MODE, is_pyi=True),
1276             )
1277             # __BLACK_STDIN_FILENAME__ should have been stripped
1278             report.done.assert_called_with(expected, black.Changed.YES)
1279
1280     @pytest.mark.incompatible_with_mypyc
1281     def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
1282         with patch(
1283             "black.format_stdin_to_stdout",
1284             return_value=lambda *args, **kwargs: black.Changed.YES,
1285         ) as fsts:
1286             report = MagicMock()
1287             p = "foo.ipynb"
1288             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1289             expected = Path(p)
1290             black.reformat_one(
1291                 path,
1292                 fast=True,
1293                 write_back=black.WriteBack.YES,
1294                 mode=DEFAULT_MODE,
1295                 report=report,
1296             )
1297             fsts.assert_called_once_with(
1298                 fast=True,
1299                 write_back=black.WriteBack.YES,
1300                 mode=replace(DEFAULT_MODE, is_ipynb=True),
1301             )
1302             # __BLACK_STDIN_FILENAME__ should have been stripped
1303             report.done.assert_called_with(expected, black.Changed.YES)
1304
1305     @pytest.mark.incompatible_with_mypyc
1306     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1307         with patch(
1308             "black.format_stdin_to_stdout",
1309             return_value=lambda *args, **kwargs: black.Changed.YES,
1310         ) as fsts:
1311             report = MagicMock()
1312             # Even with an existing file, since we are forcing stdin, black
1313             # should output to stdout and not modify the file inplace
1314             p = THIS_DIR / "data" / "simple_cases" / "collections.py"
1315             # Make sure is_file actually returns True
1316             self.assertTrue(p.is_file())
1317             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1318             expected = Path(p)
1319             black.reformat_one(
1320                 path,
1321                 fast=True,
1322                 write_back=black.WriteBack.YES,
1323                 mode=DEFAULT_MODE,
1324                 report=report,
1325             )
1326             fsts.assert_called_once()
1327             # __BLACK_STDIN_FILENAME__ should have been stripped
1328             report.done.assert_called_with(expected, black.Changed.YES)
1329
1330     def test_reformat_one_with_stdin_empty(self) -> None:
1331         cases = [
1332             ("", ""),
1333             ("\n", "\n"),
1334             ("\r\n", "\r\n"),
1335             (" \t", ""),
1336             (" \t\n\t ", "\n"),
1337             (" \t\r\n\t ", "\r\n"),
1338         ]
1339
1340         def _new_wrapper(
1341             output: io.StringIO, io_TextIOWrapper: Type[io.TextIOWrapper]
1342         ) -> Callable[[Any, Any], io.TextIOWrapper]:
1343             def get_output(*args: Any, **kwargs: Any) -> io.TextIOWrapper:
1344                 if args == (sys.stdout.buffer,):
1345                     # It's `format_stdin_to_stdout()` calling `io.TextIOWrapper()`,
1346                     # return our mock object.
1347                     return output
1348                 # It's something else (i.e. `decode_bytes()`) calling
1349                 # `io.TextIOWrapper()`, pass through to the original implementation.
1350                 # See discussion in https://github.com/psf/black/pull/2489
1351                 return io_TextIOWrapper(*args, **kwargs)
1352
1353             return get_output
1354
1355         mode = black.Mode(preview=True)
1356         for content, expected in cases:
1357             output = io.StringIO()
1358             io_TextIOWrapper = io.TextIOWrapper
1359
1360             with patch("io.TextIOWrapper", _new_wrapper(output, io_TextIOWrapper)):
1361                 try:
1362                     black.format_stdin_to_stdout(
1363                         fast=True,
1364                         content=content,
1365                         write_back=black.WriteBack.YES,
1366                         mode=mode,
1367                     )
1368                 except io.UnsupportedOperation:
1369                     pass  # StringIO does not support detach
1370                 assert output.getvalue() == expected
1371
1372         # An empty string is the only test case for `preview=False`
1373         output = io.StringIO()
1374         io_TextIOWrapper = io.TextIOWrapper
1375         with patch("io.TextIOWrapper", _new_wrapper(output, io_TextIOWrapper)):
1376             try:
1377                 black.format_stdin_to_stdout(
1378                     fast=True,
1379                     content="",
1380                     write_back=black.WriteBack.YES,
1381                     mode=DEFAULT_MODE,
1382                 )
1383             except io.UnsupportedOperation:
1384                 pass  # StringIO does not support detach
1385             assert output.getvalue() == ""
1386
1387     def test_invalid_cli_regex(self) -> None:
1388         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1389             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1390
1391     def test_required_version_matches_version(self) -> None:
1392         self.invokeBlack(
1393             ["--required-version", black.__version__, "-c", "0"],
1394             exit_code=0,
1395             ignore_config=True,
1396         )
1397
1398     def test_required_version_matches_partial_version(self) -> None:
1399         self.invokeBlack(
1400             ["--required-version", black.__version__.split(".")[0], "-c", "0"],
1401             exit_code=0,
1402             ignore_config=True,
1403         )
1404
1405     def test_required_version_does_not_match_on_minor_version(self) -> None:
1406         self.invokeBlack(
1407             ["--required-version", black.__version__.split(".")[0] + ".999", "-c", "0"],
1408             exit_code=1,
1409             ignore_config=True,
1410         )
1411
1412     def test_required_version_does_not_match_version(self) -> None:
1413         result = BlackRunner().invoke(
1414             black.main,
1415             ["--required-version", "20.99b", "-c", "0"],
1416         )
1417         self.assertEqual(result.exit_code, 1)
1418         self.assertIn("required version", result.stderr)
1419
1420     def test_preserves_line_endings(self) -> None:
1421         with TemporaryDirectory() as workspace:
1422             test_file = Path(workspace) / "test.py"
1423             for nl in ["\n", "\r\n"]:
1424                 contents = nl.join(["def f(  ):", "    pass"])
1425                 test_file.write_bytes(contents.encode())
1426                 ff(test_file, write_back=black.WriteBack.YES)
1427                 updated_contents: bytes = test_file.read_bytes()
1428                 self.assertIn(nl.encode(), updated_contents)
1429                 if nl == "\n":
1430                     self.assertNotIn(b"\r\n", updated_contents)
1431
1432     def test_preserves_line_endings_via_stdin(self) -> None:
1433         for nl in ["\n", "\r\n"]:
1434             contents = nl.join(["def f(  ):", "    pass"])
1435             runner = BlackRunner()
1436             result = runner.invoke(
1437                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf-8"))
1438             )
1439             self.assertEqual(result.exit_code, 0)
1440             output = result.stdout_bytes
1441             self.assertIn(nl.encode("utf-8"), output)
1442             if nl == "\n":
1443                 self.assertNotIn(b"\r\n", output)
1444
1445     def test_normalize_line_endings(self) -> None:
1446         with TemporaryDirectory() as workspace:
1447             test_file = Path(workspace) / "test.py"
1448             for data, expected in (
1449                 (b"c\r\nc\n ", b"c\r\nc\r\n"),
1450                 (b"l\nl\r\n ", b"l\nl\n"),
1451             ):
1452                 test_file.write_bytes(data)
1453                 ff(test_file, write_back=black.WriteBack.YES)
1454                 self.assertEqual(test_file.read_bytes(), expected)
1455
1456     def test_assert_equivalent_different_asts(self) -> None:
1457         with self.assertRaises(AssertionError):
1458             black.assert_equivalent("{}", "None")
1459
1460     def test_root_logger_not_used_directly(self) -> None:
1461         def fail(*args: Any, **kwargs: Any) -> None:
1462             self.fail("Record created with root logger")
1463
1464         with patch.multiple(
1465             logging.root,
1466             debug=fail,
1467             info=fail,
1468             warning=fail,
1469             error=fail,
1470             critical=fail,
1471             log=fail,
1472         ):
1473             ff(THIS_DIR / "util.py")
1474
1475     def test_invalid_config_return_code(self) -> None:
1476         tmp_file = Path(black.dump_to_file())
1477         try:
1478             tmp_config = Path(black.dump_to_file())
1479             tmp_config.unlink()
1480             args = ["--config", str(tmp_config), str(tmp_file)]
1481             self.invokeBlack(args, exit_code=2, ignore_config=False)
1482         finally:
1483             tmp_file.unlink()
1484
1485     def test_parse_pyproject_toml(self) -> None:
1486         test_toml_file = THIS_DIR / "test.toml"
1487         config = black.parse_pyproject_toml(str(test_toml_file))
1488         self.assertEqual(config["verbose"], 1)
1489         self.assertEqual(config["check"], "no")
1490         self.assertEqual(config["diff"], "y")
1491         self.assertEqual(config["color"], True)
1492         self.assertEqual(config["line_length"], 79)
1493         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1494         self.assertEqual(config["python_cell_magics"], ["custom1", "custom2"])
1495         self.assertEqual(config["exclude"], r"\.pyi?$")
1496         self.assertEqual(config["include"], r"\.py?$")
1497
1498     def test_parse_pyproject_toml_project_metadata(self) -> None:
1499         for test_toml, expected in [
1500             ("only_black_pyproject.toml", ["py310"]),
1501             ("only_metadata_pyproject.toml", ["py37", "py38", "py39", "py310"]),
1502             ("neither_pyproject.toml", None),
1503             ("both_pyproject.toml", ["py310"]),
1504         ]:
1505             test_toml_file = THIS_DIR / "data" / "project_metadata" / test_toml
1506             config = black.parse_pyproject_toml(str(test_toml_file))
1507             self.assertEqual(config.get("target_version"), expected)
1508
1509     def test_infer_target_version(self) -> None:
1510         for version, expected in [
1511             ("3.6", [TargetVersion.PY36]),
1512             ("3.11.0rc1", [TargetVersion.PY311]),
1513             (">=3.10", [TargetVersion.PY310, TargetVersion.PY311, TargetVersion.PY312]),
1514             (
1515                 ">=3.10.6",
1516                 [TargetVersion.PY310, TargetVersion.PY311, TargetVersion.PY312],
1517             ),
1518             ("<3.6", [TargetVersion.PY33, TargetVersion.PY34, TargetVersion.PY35]),
1519             (">3.7,<3.10", [TargetVersion.PY38, TargetVersion.PY39]),
1520             (
1521                 ">3.7,!=3.8,!=3.9",
1522                 [TargetVersion.PY310, TargetVersion.PY311, TargetVersion.PY312],
1523             ),
1524             (
1525                 "> 3.9.4, != 3.10.3",
1526                 [
1527                     TargetVersion.PY39,
1528                     TargetVersion.PY310,
1529                     TargetVersion.PY311,
1530                     TargetVersion.PY312,
1531                 ],
1532             ),
1533             (
1534                 "!=3.3,!=3.4",
1535                 [
1536                     TargetVersion.PY35,
1537                     TargetVersion.PY36,
1538                     TargetVersion.PY37,
1539                     TargetVersion.PY38,
1540                     TargetVersion.PY39,
1541                     TargetVersion.PY310,
1542                     TargetVersion.PY311,
1543                     TargetVersion.PY312,
1544                 ],
1545             ),
1546             (
1547                 "==3.*",
1548                 [
1549                     TargetVersion.PY33,
1550                     TargetVersion.PY34,
1551                     TargetVersion.PY35,
1552                     TargetVersion.PY36,
1553                     TargetVersion.PY37,
1554                     TargetVersion.PY38,
1555                     TargetVersion.PY39,
1556                     TargetVersion.PY310,
1557                     TargetVersion.PY311,
1558                     TargetVersion.PY312,
1559                 ],
1560             ),
1561             ("==3.8.*", [TargetVersion.PY38]),
1562             (None, None),
1563             ("", None),
1564             ("invalid", None),
1565             ("==invalid", None),
1566             (">3.9,!=invalid", None),
1567             ("3", None),
1568             ("3.2", None),
1569             ("2.7.18", None),
1570             ("==2.7", None),
1571             (">3.10,<3.11", None),
1572         ]:
1573             test_toml = {"project": {"requires-python": version}}
1574             result = black.files.infer_target_version(test_toml)
1575             self.assertEqual(result, expected)
1576
1577     def test_read_pyproject_toml(self) -> None:
1578         test_toml_file = THIS_DIR / "test.toml"
1579         fake_ctx = FakeContext()
1580         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1581         config = fake_ctx.default_map
1582         self.assertEqual(config["verbose"], "1")
1583         self.assertEqual(config["check"], "no")
1584         self.assertEqual(config["diff"], "y")
1585         self.assertEqual(config["color"], "True")
1586         self.assertEqual(config["line_length"], "79")
1587         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1588         self.assertEqual(config["exclude"], r"\.pyi?$")
1589         self.assertEqual(config["include"], r"\.py?$")
1590
1591     def test_read_pyproject_toml_from_stdin(self) -> None:
1592         with TemporaryDirectory() as workspace:
1593             root = Path(workspace)
1594
1595             src_dir = root / "src"
1596             src_dir.mkdir()
1597
1598             src_pyproject = src_dir / "pyproject.toml"
1599             src_pyproject.touch()
1600
1601             test_toml_content = (THIS_DIR / "test.toml").read_text(encoding="utf-8")
1602             src_pyproject.write_text(test_toml_content, encoding="utf-8")
1603
1604             src_python = src_dir / "foo.py"
1605             src_python.touch()
1606
1607             fake_ctx = FakeContext()
1608             fake_ctx.params["src"] = ("-",)
1609             fake_ctx.params["stdin_filename"] = str(src_python)
1610
1611             with change_directory(root):
1612                 black.read_pyproject_toml(fake_ctx, FakeParameter(), None)
1613
1614             config = fake_ctx.default_map
1615             self.assertEqual(config["verbose"], "1")
1616             self.assertEqual(config["check"], "no")
1617             self.assertEqual(config["diff"], "y")
1618             self.assertEqual(config["color"], "True")
1619             self.assertEqual(config["line_length"], "79")
1620             self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1621             self.assertEqual(config["exclude"], r"\.pyi?$")
1622             self.assertEqual(config["include"], r"\.py?$")
1623
1624     @pytest.mark.incompatible_with_mypyc
1625     def test_find_project_root(self) -> None:
1626         with TemporaryDirectory() as workspace:
1627             root = Path(workspace)
1628             test_dir = root / "test"
1629             test_dir.mkdir()
1630
1631             src_dir = root / "src"
1632             src_dir.mkdir()
1633
1634             root_pyproject = root / "pyproject.toml"
1635             root_pyproject.touch()
1636             src_pyproject = src_dir / "pyproject.toml"
1637             src_pyproject.touch()
1638             src_python = src_dir / "foo.py"
1639             src_python.touch()
1640
1641             self.assertEqual(
1642                 black.find_project_root((src_dir, test_dir)),
1643                 (root.resolve(), "pyproject.toml"),
1644             )
1645             self.assertEqual(
1646                 black.find_project_root((src_dir,)),
1647                 (src_dir.resolve(), "pyproject.toml"),
1648             )
1649             self.assertEqual(
1650                 black.find_project_root((src_python,)),
1651                 (src_dir.resolve(), "pyproject.toml"),
1652             )
1653
1654             with change_directory(test_dir):
1655                 self.assertEqual(
1656                     black.find_project_root(("-",), stdin_filename="../src/a.py"),
1657                     (src_dir.resolve(), "pyproject.toml"),
1658                 )
1659
1660     @patch(
1661         "black.files.find_user_pyproject_toml",
1662     )
1663     def test_find_pyproject_toml(self, find_user_pyproject_toml: MagicMock) -> None:
1664         find_user_pyproject_toml.side_effect = RuntimeError()
1665
1666         with redirect_stderr(io.StringIO()) as stderr:
1667             result = black.files.find_pyproject_toml(
1668                 path_search_start=(str(Path.cwd().root),)
1669             )
1670
1671         assert result is None
1672         err = stderr.getvalue()
1673         assert "Ignoring user configuration" in err
1674
1675     @patch(
1676         "black.files.find_user_pyproject_toml",
1677         black.files.find_user_pyproject_toml.__wrapped__,
1678     )
1679     def test_find_user_pyproject_toml_linux(self) -> None:
1680         if system() == "Windows":
1681             return
1682
1683         # Test if XDG_CONFIG_HOME is checked
1684         with TemporaryDirectory() as workspace:
1685             tmp_user_config = Path(workspace) / "black"
1686             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1687                 self.assertEqual(
1688                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
1689                 )
1690
1691         # Test fallback for XDG_CONFIG_HOME
1692         with patch.dict("os.environ"):
1693             os.environ.pop("XDG_CONFIG_HOME", None)
1694             fallback_user_config = Path("~/.config").expanduser() / "black"
1695             self.assertEqual(
1696                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
1697             )
1698
1699     def test_find_user_pyproject_toml_windows(self) -> None:
1700         if system() != "Windows":
1701             return
1702
1703         user_config_path = Path.home() / ".black"
1704         self.assertEqual(
1705             black.files.find_user_pyproject_toml(), user_config_path.resolve()
1706         )
1707
1708     def test_bpo_33660_workaround(self) -> None:
1709         if system() == "Windows":
1710             return
1711
1712         # https://bugs.python.org/issue33660
1713         root = Path("/")
1714         with change_directory(root):
1715             path = Path("workspace") / "project"
1716             report = black.Report(verbose=True)
1717             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
1718             self.assertEqual(normalized_path, "workspace/project")
1719
1720     def test_normalize_path_ignore_windows_junctions_outside_of_root(self) -> None:
1721         if system() != "Windows":
1722             return
1723
1724         with TemporaryDirectory() as workspace:
1725             root = Path(workspace)
1726             junction_dir = root / "junction"
1727             junction_target_outside_of_root = root / ".."
1728             os.system(f"mklink /J {junction_dir} {junction_target_outside_of_root}")
1729
1730             report = black.Report(verbose=True)
1731             normalized_path = black.normalize_path_maybe_ignore(
1732                 junction_dir, root, report
1733             )
1734             # Manually delete for Python < 3.8
1735             os.system(f"rmdir {junction_dir}")
1736
1737             self.assertEqual(normalized_path, None)
1738
1739     def test_newline_comment_interaction(self) -> None:
1740         source = "class A:\\\r\n# type: ignore\n pass\n"
1741         output = black.format_str(source, mode=DEFAULT_MODE)
1742         black.assert_stable(source, output, mode=DEFAULT_MODE)
1743
1744     def test_bpo_2142_workaround(self) -> None:
1745         # https://bugs.python.org/issue2142
1746
1747         source, _ = read_data("miscellaneous", "missing_final_newline")
1748         # read_data adds a trailing newline
1749         source = source.rstrip()
1750         expected, _ = read_data("miscellaneous", "missing_final_newline.diff")
1751         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
1752         diff_header = re.compile(
1753             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
1754             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d\+\d\d:\d\d"
1755         )
1756         try:
1757             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
1758             self.assertEqual(result.exit_code, 0)
1759         finally:
1760             os.unlink(tmp_file)
1761         actual = result.output
1762         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
1763         self.assertEqual(actual, expected)
1764
1765     @staticmethod
1766     def compare_results(
1767         result: click.testing.Result, expected_value: str, expected_exit_code: int
1768     ) -> None:
1769         """Helper method to test the value and exit code of a click Result."""
1770         assert (
1771             result.output == expected_value
1772         ), "The output did not match the expected value."
1773         assert result.exit_code == expected_exit_code, "The exit code is incorrect."
1774
1775     def test_code_option(self) -> None:
1776         """Test the code option with no changes."""
1777         code = 'print("Hello world")\n'
1778         args = ["--code", code]
1779         result = CliRunner().invoke(black.main, args)
1780
1781         self.compare_results(result, code, 0)
1782
1783     def test_code_option_changed(self) -> None:
1784         """Test the code option when changes are required."""
1785         code = "print('hello world')"
1786         formatted = black.format_str(code, mode=DEFAULT_MODE)
1787
1788         args = ["--code", code]
1789         result = CliRunner().invoke(black.main, args)
1790
1791         self.compare_results(result, formatted, 0)
1792
1793     def test_code_option_check(self) -> None:
1794         """Test the code option when check is passed."""
1795         args = ["--check", "--code", 'print("Hello world")\n']
1796         result = CliRunner().invoke(black.main, args)
1797         self.compare_results(result, "", 0)
1798
1799     def test_code_option_check_changed(self) -> None:
1800         """Test the code option when changes are required, and check is passed."""
1801         args = ["--check", "--code", "print('hello world')"]
1802         result = CliRunner().invoke(black.main, args)
1803         self.compare_results(result, "", 1)
1804
1805     def test_code_option_diff(self) -> None:
1806         """Test the code option when diff is passed."""
1807         code = "print('hello world')"
1808         formatted = black.format_str(code, mode=DEFAULT_MODE)
1809         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1810
1811         args = ["--diff", "--code", code]
1812         result = CliRunner().invoke(black.main, args)
1813
1814         # Remove time from diff
1815         output = DIFF_TIME.sub("", result.output)
1816
1817         assert output == result_diff, "The output did not match the expected value."
1818         assert result.exit_code == 0, "The exit code is incorrect."
1819
1820     def test_code_option_color_diff(self) -> None:
1821         """Test the code option when color and diff are passed."""
1822         code = "print('hello world')"
1823         formatted = black.format_str(code, mode=DEFAULT_MODE)
1824
1825         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1826         result_diff = color_diff(result_diff)
1827
1828         args = ["--diff", "--color", "--code", code]
1829         result = CliRunner().invoke(black.main, args)
1830
1831         # Remove time from diff
1832         output = DIFF_TIME.sub("", result.output)
1833
1834         assert output == result_diff, "The output did not match the expected value."
1835         assert result.exit_code == 0, "The exit code is incorrect."
1836
1837     @pytest.mark.incompatible_with_mypyc
1838     def test_code_option_safe(self) -> None:
1839         """Test that the code option throws an error when the sanity checks fail."""
1840         # Patch black.assert_equivalent to ensure the sanity checks fail
1841         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1842             code = 'print("Hello world")'
1843             error_msg = f"{code}\nerror: cannot format <string>: \n"
1844
1845             args = ["--safe", "--code", code]
1846             result = CliRunner().invoke(black.main, args)
1847
1848             self.compare_results(result, error_msg, 123)
1849
1850     def test_code_option_fast(self) -> None:
1851         """Test that the code option ignores errors when the sanity checks fail."""
1852         # Patch black.assert_equivalent to ensure the sanity checks fail
1853         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1854             code = 'print("Hello world")'
1855             formatted = black.format_str(code, mode=DEFAULT_MODE)
1856
1857             args = ["--fast", "--code", code]
1858             result = CliRunner().invoke(black.main, args)
1859
1860             self.compare_results(result, formatted, 0)
1861
1862     @pytest.mark.incompatible_with_mypyc
1863     def test_code_option_config(self) -> None:
1864         """
1865         Test that the code option finds the pyproject.toml in the current directory.
1866         """
1867         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1868             args = ["--code", "print"]
1869             # This is the only directory known to contain a pyproject.toml
1870             with change_directory(PROJECT_ROOT):
1871                 CliRunner().invoke(black.main, args)
1872                 pyproject_path = Path(Path.cwd(), "pyproject.toml").resolve()
1873
1874             assert (
1875                 len(parse.mock_calls) >= 1
1876             ), "Expected config parse to be called with the current directory."
1877
1878             _, call_args, _ = parse.mock_calls[0]
1879             assert (
1880                 call_args[0].lower() == str(pyproject_path).lower()
1881             ), "Incorrect config loaded."
1882
1883     @pytest.mark.incompatible_with_mypyc
1884     def test_code_option_parent_config(self) -> None:
1885         """
1886         Test that the code option finds the pyproject.toml in the parent directory.
1887         """
1888         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1889             with change_directory(THIS_DIR):
1890                 args = ["--code", "print"]
1891                 CliRunner().invoke(black.main, args)
1892
1893                 pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
1894                 assert (
1895                     len(parse.mock_calls) >= 1
1896                 ), "Expected config parse to be called with the current directory."
1897
1898                 _, call_args, _ = parse.mock_calls[0]
1899                 assert (
1900                     call_args[0].lower() == str(pyproject_path).lower()
1901                 ), "Incorrect config loaded."
1902
1903     def test_for_handled_unexpected_eof_error(self) -> None:
1904         """
1905         Test that an unexpected EOF SyntaxError is nicely presented.
1906         """
1907         with pytest.raises(black.parsing.InvalidInput) as exc_info:
1908             black.lib2to3_parse("print(", {})
1909
1910         exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")
1911
1912     def test_equivalency_ast_parse_failure_includes_error(self) -> None:
1913         with pytest.raises(AssertionError) as err:
1914             black.assert_equivalent("a«»a  = 1", "a«»a  = 1")
1915
1916         err.match("--safe")
1917         # Unfortunately the SyntaxError message has changed in newer versions so we
1918         # can't match it directly.
1919         err.match("invalid character")
1920         err.match(r"\(<unknown>, line 1\)")
1921
1922
class TestCaching:
    """Tests for black's on-disk formatting cache: location, reads and writes."""

    def test_get_cache_dir(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        default_dir = tmp_path / "ws1"
        override_dir = tmp_path / "ws2"
        for candidate in (default_dir, override_dir):
            candidate.mkdir()

        # Without BLACK_CACHE_DIR, user_cache_dir decides; it is patched here to
        # return a known temporary directory.
        monkeypatch.delenv("BLACK_CACHE_DIR", raising=False)
        with patch(
            target="black.cache.user_cache_dir",
            autospec=True,
            return_value=str(default_dir),
        ):
            assert get_cache_dir() == default_dir

        # When BLACK_CACHE_DIR is set, its value takes precedence.
        monkeypatch.setenv("BLACK_CACHE_DIR", str(override_dir))
        assert get_cache_dir() == override_dir

    def test_cache_broken_file(self) -> None:
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            # An unreadable cache file reads back as an empty cache...
            get_cache_file(mode).write_text("this is not a pickle", encoding="utf-8")
            assert black.read_cache(mode) == {}
            # ...and the next run still records the file it formatted.
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')", encoding="utf-8")
            invokeBlack([str(src)])
            assert str(src) in black.read_cache(mode)

    def test_cache_single_file_already_cached(self) -> None:
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            unformatted = "print('hello')"
            src = (workspace / "test.py").resolve()
            src.write_text(unformatted, encoding="utf-8")
            # Seed the cache so black considers the file up to date.
            black.write_cache({}, [src], mode)
            invokeBlack([str(src)])
            # The cached file must be left exactly as written.
            assert src.read_text(encoding="utf-8") == unformatted

    @event_loop()
    def test_cache_multiple_files(self) -> None:
        mode = DEFAULT_MODE
        with cache_dir() as workspace, patch(
            "concurrent.futures.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            one = (workspace / "one.py").resolve()
            two = (workspace / "two.py").resolve()
            for src in (one, two):
                src.write_text("print('hello')", encoding="utf-8")
            # Only `one` is pre-cached, so only `two` gets reformatted.
            black.write_cache({}, [one], mode)
            invokeBlack([str(workspace)])
            assert one.read_text(encoding="utf-8") == "print('hello')"
            assert two.read_text(encoding="utf-8") == 'print("hello")\n'
            cache = black.read_cache(mode)
            assert str(one) in cache
            assert str(two) in cache

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    def test_no_cache_when_writeback_diff(self, color: bool) -> None:
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')", encoding="utf-8")
            with patch("black.read_cache") as read_cache, patch(
                "black.write_cache"
            ) as write_cache:
                cmd = [str(src), "--diff", *(["--color"] if color else [])]
                invokeBlack(cmd)
                # With --diff the cache is neither read nor written.
                assert not get_cache_file(mode).exists()
                write_cache.assert_not_called()
                read_cache.assert_not_called()

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    @event_loop()
    def test_output_locking_when_writeback_diff(self, color: bool) -> None:
        with cache_dir() as workspace:
            for index in range(4):
                src = (workspace / f"test{index}.py").resolve()
                src.write_text("print('hello')", encoding="utf-8")
            with patch(
                "black.concurrency.Manager", wraps=multiprocessing.Manager
            ) as mgr:
                cmd = ["--diff", str(workspace)]
                if color:
                    cmd.append("--color")
                invokeBlack(cmd, exit_code=0)
                # this isn't quite doing what we want, but if it _isn't_
                # called then we cannot be using the lock it provides
                mgr.assert_called()

    def test_no_cache_when_stdin(self) -> None:
        mode = DEFAULT_MODE
        with cache_dir():
            result = CliRunner().invoke(
                black.main, ["-"], input=BytesIO(b"print('hello')")
            )
            assert not result.exit_code
            assert not get_cache_file(mode).exists()

    def test_read_cache_no_cachefile(self) -> None:
        mode = DEFAULT_MODE
        with cache_dir():
            assert black.read_cache(mode) == {}

    def test_write_cache_read_cache(self) -> None:
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.touch()
            black.write_cache({}, [src], mode)
            cache = black.read_cache(mode)
            key = str(src)
            assert key in cache
            assert cache[key] == black.get_cache_info(src)

    def test_filter_cached(self) -> None:
        with TemporaryDirectory() as workspace:
            base = Path(workspace)
            uncached = (base / "uncached").resolve()
            cached = (base / "cached").resolve()
            cached_but_changed = (base / "changed").resolve()
            for path in (uncached, cached, cached_but_changed):
                path.touch()
            # `cached` has an accurate entry; `changed` has a stale (0.0, 0) one.
            cache = {
                str(cached): black.get_cache_info(cached),
                str(cached_but_changed): (0.0, 0),
            }
            todo, done = black.cache.filter_cached(
                cache, {uncached, cached, cached_but_changed}
            )
            assert todo == {uncached, cached_but_changed}
            assert done == {cached}

    def test_write_cache_creates_directory_if_needed(self) -> None:
        mode = DEFAULT_MODE
        with cache_dir(exists=False) as workspace:
            assert not workspace.exists()
            black.write_cache({}, [], mode)
            assert workspace.exists()

    @event_loop()
    def test_failed_formatting_does_not_get_cached(self) -> None:
        mode = DEFAULT_MODE
        with cache_dir() as workspace, patch(
            "concurrent.futures.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            failing = (workspace / "failing.py").resolve()
            failing.write_text("not actually python", encoding="utf-8")
            clean = (workspace / "clean.py").resolve()
            clean.write_text('print("hello")\n', encoding="utf-8")
            invokeBlack([str(workspace)], exit_code=123)
            cache = black.read_cache(mode)
            # Only the successfully formatted file is recorded.
            assert str(failing) not in cache
            assert str(clean) in cache

    def test_write_cache_write_fail(self) -> None:
        mode = DEFAULT_MODE
        # Path.open is made to raise OSError; write_cache must still return.
        with cache_dir(), patch.object(Path, "open", side_effect=OSError):
            black.write_cache({}, [], mode)

    def test_read_cache_line_lengths(self) -> None:
        mode = DEFAULT_MODE
        short_mode = replace(DEFAULT_MODE, line_length=1)
        with cache_dir() as workspace:
            path = (workspace / "file.py").resolve()
            path.touch()
            black.write_cache({}, [path], mode)
            # Entries written for one mode are not visible under another mode.
            assert str(path) in black.read_cache(mode)
            assert str(path) not in black.read_cache(short_mode)
2109
2110
def assert_collected_sources(
    src: Sequence[Union[str, Path]],
    expected: Sequence[Union[str, Path]],
    *,
    ctx: Optional[FakeContext] = None,
    exclude: Optional[str] = None,
    include: Optional[str] = None,
    extend_exclude: Optional[str] = None,
    force_exclude: Optional[str] = None,
    stdin_filename: Optional[str] = None,
) -> None:
    """Assert that `black.get_sources` collects exactly `expected` from `src`.

    Each regex option is compiled with `compile_pattern` when given. When
    omitted, `include` falls back to `DEFAULT_INCLUDE`, while `exclude`,
    `extend_exclude` and `force_exclude` are passed on as `None`. A fresh
    `FakeContext` is used when `ctx` is falsy. Ordering is ignored.
    """

    def optional_pattern(regex: Optional[str]) -> "Optional[re.Pattern[str]]":
        # Keep "not given" (None) distinct from an explicit empty regex ("").
        if regex is None:
            return None
        return compile_pattern(regex)

    include_pattern = DEFAULT_INCLUDE if include is None else compile_pattern(include)
    collected = black.get_sources(
        ctx=ctx or FakeContext(),
        src=tuple(str(Path(entry)) for entry in src),
        quiet=False,
        verbose=False,
        include=include_pattern,
        exclude=optional_pattern(exclude),
        extend_exclude=optional_pattern(extend_exclude),
        force_exclude=optional_pattern(force_exclude),
        report=black.Report(),
        stdin_filename=stdin_filename,
    )
    assert sorted(collected) == sorted(Path(entry) for entry in expected)
2143
2144
class TestFileCollection:
    """Source discovery: include/exclude regexes, .gitignore handling, stdin."""

    def test_include_exclude(self) -> None:
        path = THIS_DIR / "data" / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        assert_collected_sources(
            src,
            expected,
            include=r"\.pyi?$",
            exclude=r"/exclude/|/\.definitely_exclude/",
        )

    def test_gitignore_used_as_default(self) -> None:
        base = DATA_DIR / "include_exclude_tests"
        expected = [
            base / "b/.definitely_exclude/a.py",
            base / "b/.definitely_exclude/a.pyi",
        ]
        src = [base / "b/"]
        ctx = FakeContext()
        ctx.obj["root"] = base
        assert_collected_sources(src, expected, ctx=ctx, extend_exclude=r"/exclude/")

    def test_gitignore_used_on_multiple_sources(self) -> None:
        root = DATA_DIR / "gitignore_used_on_multiple_sources"
        expected = [
            root / "dir1" / "b.py",
            root / "dir2" / "b.py",
        ]
        ctx = FakeContext()
        ctx.obj["root"] = root
        src = [root / "dir1", root / "dir2"]
        assert_collected_sources(src, expected, ctx=ctx)

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_exclude_for_issue_1572(self) -> None:
        # Exclude shouldn't touch files that were explicitly given to Black through the
        # CLI. Exclude is supposed to only apply to the recursive discovery of files.
        # https://github.com/psf/black/issues/1572
        path = DATA_DIR / "include_exclude_tests"
        src = [path / "b/exclude/a.py"]
        expected = [path / "b/exclude/a.py"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    def test_gitignore_exclude(self) -> None:
        path = THIS_DIR / "data" / "include_exclude_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        report = black.Report()
        gitignore = PathSpec.from_lines(
            "gitwildmatch", ["exclude/", ".definitely_exclude"]
        )
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        this_abs = THIS_DIR.resolve()
        sources = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                {path: gitignore},
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_nested_gitignore(self) -> None:
        path = THIS_DIR / "data" / "nested_gitignore_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        root_gitignore = black.files.get_gitignore(path)
        report = black.Report()
        expected: List[Path] = [
            path / "x.py",
            path / "root/b.py",
            path / "root/c.py",
            path / "root/child/c.py",
        ]
        this_abs = THIS_DIR.resolve()
        sources = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                {path: root_gitignore},
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_nested_gitignore_directly_in_source_directory(self) -> None:
        # https://github.com/psf/black/issues/2598
        path = DATA_DIR / "nested_gitignore_tests"
        src = path / "root" / "child"
        expected = [src / "a.py", src / "c.py"]
        assert_collected_sources([src], expected)

    def test_invalid_gitignore(self) -> None:
        path = THIS_DIR / "data" / "invalid_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_invalid_nested_gitignore(self) -> None:
        path = THIS_DIR / "data" / "invalid_nested_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / "a" / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_gitignore_that_ignores_subfolders(self) -> None:
        # If gitignore with */* is in root
        root = DATA_DIR / "ignore_subfolders_gitignore_tests" / "subdir"
        expected = [root / "b.py"]
        ctx = FakeContext()
        ctx.obj["root"] = root
        assert_collected_sources([root], expected, ctx=ctx)

        # If .gitignore with */* is nested
        root = DATA_DIR / "ignore_subfolders_gitignore_tests"
        expected = [
            root / "a.py",
            root / "subdir" / "b.py",
        ]
        ctx = FakeContext()
        ctx.obj["root"] = root
        assert_collected_sources([root], expected, ctx=ctx)

        # If command is executed from outer dir
        root = DATA_DIR / "ignore_subfolders_gitignore_tests"
        target = root / "subdir"
        expected = [target / "b.py"]
        ctx = FakeContext()
        ctx.obj["root"] = root
        assert_collected_sources([target], expected, ctx=ctx)

    def test_empty_include(self) -> None:
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.pie",
            path / "b/exclude/a.py",
            path / "b/exclude/a.pyi",
            path / "b/dont_exclude/a.pie",
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
            path / "b/.definitely_exclude/a.pie",
            path / "b/.definitely_exclude/a.py",
            path / "b/.definitely_exclude/a.pyi",
            path / ".gitignore",
            path / "pyproject.toml",
        ]
        # Setting exclude explicitly to an empty string to block .gitignore usage.
        assert_collected_sources(src, expected, include="", exclude="")

    def test_extend_exclude(self) -> None:
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.py",
            path / "b/dont_exclude/a.py",
        ]
        assert_collected_sources(
            src, expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude"
        )

    @pytest.mark.incompatible_with_mypyc
    def test_symlink_out_of_root_directory(self) -> None:
        path = MagicMock()
        root = THIS_DIR.resolve()
        child = MagicMock()
        include = re.compile(black.DEFAULT_INCLUDES)
        exclude = re.compile(black.DEFAULT_EXCLUDES)
        report = black.Report()
        gitignore = PathSpec.from_lines("gitwildmatch", [])
        # `child` should behave like a symlink which resolved path is clearly
        # outside of the `root` directory.
        path.iterdir.return_value = [child]
        child.resolve.return_value = Path("/a/b/c")
        child.as_posix.return_value = "/a/b/c"
        try:
            list(
                black.gen_python_files(
                    path.iterdir(),
                    root,
                    include,
                    exclude,
                    None,
                    None,
                    report,
                    {path: gitignore},
                    verbose=False,
                    quiet=False,
                )
            )
        except ValueError as ve:
            # Name the function actually under test so failures are traceable.
            pytest.fail(f"`gen_python_files()` failed: {ve}")
        path.iterdir.assert_called_once()
        child.resolve.assert_called_once()

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin(self) -> None:
        src = ["-"]
        expected = ["-"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename(self) -> None:
        src = ["-"]
        stdin_filename = str(THIS_DIR / "data/collections.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
        # Exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        path = DATA_DIR / "include_exclude_tests"
        src = ["-"]
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
        # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        src = ["-"]
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            extend_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
        # Force exclude should exclude the file when passing it through
        # stdin_filename
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        assert_collected_sources(
            src=["-"],
            expected=[],
            force_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )
2434
2435
# Source lines of black/__init__.py, consumed by `tracefunc` to print the
# signature line of each called function. Bound up front so the name always
# exists: when black is compiled (black.COMPILED), reading black.__file__ as
# UTF-8 text presumably fails because it is a binary extension, and the list
# then stays empty instead of being left undefined.
black_source_lines: List[str] = []
try:
    with open(black.__file__, "r", encoding="utf-8") as _bf:
        black_source_lines = _bf.readlines()
except UnicodeDecodeError:
    if not black.COMPILED:
        raise
2442
2443
def tracefunc(
    frame: types.FrameType, event: str, arg: Any
) -> Callable[[types.FrameType, str, Any], Any]:
    """Show function calls `from black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    For every "call" event whose code lives in ``black/__init__.py`` this
    prints ``<lineno>:<source line>``, indented by the current stack depth,
    where the source line is the called function's first non-decorator line.
    All other events and frames are ignored.

    Args:
        frame: the frame being traced.
        event: the trace event name ("call", "line", "return", ...).
        arg: event-specific argument; unused.

    Returns:
        This function, so it remains installed as the trace function.
    """
    if event != "call":
        return tracefunc

    # Filter by file first: `black_source_lines` only holds black/__init__.py,
    # so indexing it with a line number from another module is meaningless and
    # can raise IndexError (an exception makes sys.settrace drop the tracer).
    # as_posix() normalizes Windows path separators for the substring check.
    filename = Path(frame.f_code.co_filename).as_posix()
    if "black/__init__.py" not in filename:
        return tracefunc

    # inspect.stack() is expensive, so only compute depth for printed frames.
    # NOTE(review): 19 looks like a fixed offset for the frames beneath the
    # code of interest (test harness etc.) — TODO confirm; kept unchanged.
    stack = (len(inspect.stack()) - 19) * 2
    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1
    funcname = "<unknown>"
    # Skip decorator lines to reach the `def` line, never running off the end.
    while 0 <= func_sig_lineno < len(black_source_lines):
        funcname = black_source_lines[func_sig_lineno].strip()
        if not funcname.startswith("@"):
            break
        func_sig_lineno += 1
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc