git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send patches to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Fix CI for Click typing issue (#3770)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import inspect
5 import io
6 import logging
7 import multiprocessing
8 import os
9 import re
10 import sys
11 import types
12 import unittest
13 from concurrent.futures import ThreadPoolExecutor
14 from contextlib import contextmanager, redirect_stderr
15 from dataclasses import replace
16 from io import BytesIO
17 from pathlib import Path
18 from platform import system
19 from tempfile import TemporaryDirectory
20 from typing import (
21     Any,
22     Callable,
23     Dict,
24     Iterator,
25     List,
26     Optional,
27     Sequence,
28     Type,
29     TypeVar,
30     Union,
31 )
32 from unittest.mock import MagicMock, patch
33
34 import click
35 import pytest
36 from click import unstyle
37 from click.testing import CliRunner
38 from pathspec import PathSpec
39
40 import black
41 import black.files
42 from black import Feature, TargetVersion
43 from black import re_compile_maybe_verbose as compile_pattern
44 from black.cache import get_cache_dir, get_cache_file
45 from black.debug import DebugVisitor
46 from black.output import color_diff, diff
47 from black.report import Report
48
49 # Import other test classes
50 from tests.util import (
51     DATA_DIR,
52     DEFAULT_MODE,
53     DETERMINISTIC_HEADER,
54     PROJECT_ROOT,
55     PY36_VERSIONS,
56     THIS_DIR,
57     BlackBaseTestCase,
58     assert_format,
59     change_directory,
60     dump_to_stderr,
61     ff,
62     fs,
63     get_case_path,
64     read_data,
65     read_data_from_file,
66 )
67
# Path of this test module.
THIS_FILE = Path(__file__)
# Config file passed via ``--config=`` by many CLI tests below.
EMPTY_CONFIG = THIS_DIR / "data" / "empty_pyproject.toml"
# One ``--target-version=<name>`` CLI flag per version in PY36_VERSIONS.
PY36_ARGS = [f"--target-version={version.name.lower()}" for version in PY36_VERSIONS]
# Black's default exclude/include patterns, compiled with re_compile_maybe_verbose.
DEFAULT_EXCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_EXCLUDES)
DEFAULT_INCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_INCLUDES)
# Generic type variables.
T = TypeVar("T")
R = TypeVar("R")

# Match the time output in a diff, but nothing else
DIFF_TIME = re.compile(r"\t[\d\-:+\. ]+")
78
79
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point ``black.cache.CACHE_DIR`` at a throwaway directory.

    Yields the patched cache path. With ``exists=False`` the yielded path is a
    not-yet-created ``new`` subdirectory of the temporary workspace. The whole
    workspace is removed when the context exits.
    """
    with TemporaryDirectory() as workspace:
        workspace_path = Path(workspace)
        target = workspace_path if exists else workspace_path / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
88
89
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a brand-new asyncio event loop for the duration of the block.

    The loop comes from the current event-loop policy and is made the current
    loop. It is closed on exit, even if the body raises. The closed loop is not
    unset afterwards and remains registered as the current loop.
    """
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        fresh_loop.close()
100
101
class FakeContext(click.Context):
    """A fake click Context for when calling functions that need it.

    ``__init__`` does not call ``click.Context.__init__``; only the three
    attributes assigned below exist on an instance.
    """

    def __init__(self) -> None:
        # Start with empty option defaults and parameter values.
        self.default_map: Dict[str, Any] = {}
        self.params: Dict[str, Any] = {}
        # Dummy root, since most of the tests don't care about it
        self.obj: Dict[str, Any] = {"root": PROJECT_ROOT}
110
111
class FakeParameter(click.Parameter):
    """A fake click Parameter for when calling functions that need it.

    ``__init__`` is a no-op and does not call ``click.Parameter.__init__``, so
    an instance carries none of the attributes that initializer would set.
    """

    def __init__(self) -> None:
        # Skip click.Parameter.__init__ entirely.
        pass
117
118
class BlackRunner(CliRunner):
    """Make sure STDOUT and STDERR are kept separate when testing Black via its CLI."""

    def __init__(self) -> None:
        # mix_stderr=False captures stderr separately from stdout; invokeBlack
        # relies on this to read ``result.stderr_bytes``.
        # NOTE(review): newer Click releases may no longer accept ``mix_stderr``
        # — confirm against the Click version this project pins.
        super().__init__(mix_stderr=False)
124
125
def invokeBlack(
    args: List[str], exit_code: int = 0, ignore_config: bool = True
) -> None:
    """Run Black's CLI in-process with ``args`` and assert its exit code.

    When ``ignore_config`` is true, ``--verbose`` plus ``--config`` pointing at
    ``empty.toml`` in this test directory are prepended, presumably so that no
    project configuration is picked up. Exceptions are not caught by the runner.
    If the exit code differs from ``exit_code``, the assertion message includes
    the arguments, the decoded stdout/stderr, and the captured exception.
    """
    if ignore_config:
        args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
    result = BlackRunner().invoke(black.main, args, catch_exceptions=False)
    stdout_bytes = result.stdout_bytes
    stderr_bytes = result.stderr_bytes
    assert stdout_bytes is not None
    assert stderr_bytes is not None
    failure_message = "\n".join(
        [
            f"Failed with args: {args}",
            f"stdout: {stdout_bytes.decode()!r}",
            f"stderr: {stderr_bytes.decode()!r}",
            f"exception: {result.exception}",
        ]
    )
    assert result.exit_code == exit_code, failure_message
142
143
144 class BlackTestCase(BlackBaseTestCase):
    # Expose the module-level helper so tests can call ``self.invokeBlack(...)``.
    invokeBlack = staticmethod(invokeBlack)
146
147     def test_empty_ff(self) -> None:
148         expected = ""
149         tmp_file = Path(black.dump_to_file())
150         try:
151             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
152             actual = tmp_file.read_text(encoding="utf-8")
153         finally:
154             os.unlink(tmp_file)
155         self.assertFormatEqual(expected, actual)
156
157     @patch("black.dump_to_file", dump_to_stderr)
158     def test_one_empty_line(self) -> None:
159         mode = black.Mode(preview=True)
160         for nl in ["\n", "\r\n"]:
161             source = expected = nl
162             assert_format(source, expected, mode=mode)
163
164     def test_one_empty_line_ff(self) -> None:
165         mode = black.Mode(preview=True)
166         for nl in ["\n", "\r\n"]:
167             expected = nl
168             tmp_file = Path(black.dump_to_file(nl))
169             if system() == "Windows":
170                 # Writing files in text mode automatically uses the system newline,
171                 # but in this case we don't want this for testing reasons. See:
172                 # https://github.com/psf/black/pull/3348
173                 with open(tmp_file, "wb") as f:
174                     f.write(nl.encode("utf-8"))
175             try:
176                 self.assertFalse(
177                     ff(tmp_file, mode=mode, write_back=black.WriteBack.YES)
178                 )
179                 with open(tmp_file, "rb") as f:
180                     actual = f.read().decode("utf-8")
181             finally:
182                 os.unlink(tmp_file)
183             self.assertFormatEqual(expected, actual)
184
185     def test_experimental_string_processing_warns(self) -> None:
186         self.assertWarns(
187             black.mode.Deprecated, black.Mode, experimental_string_processing=True
188         )
189
190     def test_piping(self) -> None:
191         source, expected = read_data_from_file(PROJECT_ROOT / "src/black/__init__.py")
192         result = BlackRunner().invoke(
193             black.main,
194             [
195                 "-",
196                 "--fast",
197                 f"--line-length={black.DEFAULT_LINE_LENGTH}",
198                 f"--config={EMPTY_CONFIG}",
199             ],
200             input=BytesIO(source.encode("utf-8")),
201         )
202         self.assertEqual(result.exit_code, 0)
203         self.assertFormatEqual(expected, result.output)
204         if source != result.output:
205             black.assert_equivalent(source, result.output)
206             black.assert_stable(source, result.output, DEFAULT_MODE)
207
208     def test_piping_diff(self) -> None:
209         diff_header = re.compile(
210             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d"
211             r"\+\d\d:\d\d"
212         )
213         source, _ = read_data("simple_cases", "expression.py")
214         expected, _ = read_data("simple_cases", "expression.diff")
215         args = [
216             "-",
217             "--fast",
218             f"--line-length={black.DEFAULT_LINE_LENGTH}",
219             "--diff",
220             f"--config={EMPTY_CONFIG}",
221         ]
222         result = BlackRunner().invoke(
223             black.main, args, input=BytesIO(source.encode("utf-8"))
224         )
225         self.assertEqual(result.exit_code, 0)
226         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
227         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
228         self.assertEqual(expected, actual)
229
230     def test_piping_diff_with_color(self) -> None:
231         source, _ = read_data("simple_cases", "expression.py")
232         args = [
233             "-",
234             "--fast",
235             f"--line-length={black.DEFAULT_LINE_LENGTH}",
236             "--diff",
237             "--color",
238             f"--config={EMPTY_CONFIG}",
239         ]
240         result = BlackRunner().invoke(
241             black.main, args, input=BytesIO(source.encode("utf-8"))
242         )
243         actual = result.output
244         # Again, the contents are checked in a different test, so only look for colors.
245         self.assertIn("\033[1m", actual)
246         self.assertIn("\033[36m", actual)
247         self.assertIn("\033[32m", actual)
248         self.assertIn("\033[31m", actual)
249         self.assertIn("\033[0m", actual)
250
251     @patch("black.dump_to_file", dump_to_stderr)
252     def _test_wip(self) -> None:
253         source, expected = read_data("miscellaneous", "wip")
254         sys.settrace(tracefunc)
255         mode = replace(
256             DEFAULT_MODE,
257             experimental_string_processing=False,
258             target_versions={black.TargetVersion.PY38},
259         )
260         actual = fs(source, mode=mode)
261         sys.settrace(None)
262         self.assertFormatEqual(expected, actual)
263         black.assert_equivalent(source, actual)
264         black.assert_stable(source, actual, black.FileMode())
265
266     def test_pep_572_version_detection(self) -> None:
267         source, _ = read_data("py_38", "pep_572")
268         root = black.lib2to3_parse(source)
269         features = black.get_features_used(root)
270         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
271         versions = black.detect_target_versions(root)
272         self.assertIn(black.TargetVersion.PY38, versions)
273
274     def test_pep_695_version_detection(self) -> None:
275         for file in ("type_aliases", "type_params"):
276             source, _ = read_data("py_312", file)
277             root = black.lib2to3_parse(source)
278             features = black.get_features_used(root)
279             self.assertIn(black.Feature.TYPE_PARAMS, features)
280             versions = black.detect_target_versions(root)
281             self.assertIn(black.TargetVersion.PY312, versions)
282
283     def test_expression_ff(self) -> None:
284         source, expected = read_data("simple_cases", "expression.py")
285         tmp_file = Path(black.dump_to_file(source))
286         try:
287             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
288             actual = tmp_file.read_text(encoding="utf-8")
289         finally:
290             os.unlink(tmp_file)
291         self.assertFormatEqual(expected, actual)
292         with patch("black.dump_to_file", dump_to_stderr):
293             black.assert_equivalent(source, actual)
294             black.assert_stable(source, actual, DEFAULT_MODE)
295
296     def test_expression_diff(self) -> None:
297         source, _ = read_data("simple_cases", "expression.py")
298         expected, _ = read_data("simple_cases", "expression.diff")
299         tmp_file = Path(black.dump_to_file(source))
300         diff_header = re.compile(
301             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
302             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d\+\d\d:\d\d"
303         )
304         try:
305             result = BlackRunner().invoke(
306                 black.main, ["--diff", str(tmp_file), f"--config={EMPTY_CONFIG}"]
307             )
308             self.assertEqual(result.exit_code, 0)
309         finally:
310             os.unlink(tmp_file)
311         actual = result.output
312         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
313         if expected != actual:
314             dump = black.dump_to_file(actual)
315             msg = (
316                 "Expected diff isn't equal to the actual. If you made changes to"
317                 " expression.py and this is an anticipated difference, overwrite"
318                 f" tests/data/expression.diff with {dump}"
319             )
320             self.assertEqual(expected, actual, msg)
321
322     def test_expression_diff_with_color(self) -> None:
323         source, _ = read_data("simple_cases", "expression.py")
324         expected, _ = read_data("simple_cases", "expression.diff")
325         tmp_file = Path(black.dump_to_file(source))
326         try:
327             result = BlackRunner().invoke(
328                 black.main,
329                 ["--diff", "--color", str(tmp_file), f"--config={EMPTY_CONFIG}"],
330             )
331         finally:
332             os.unlink(tmp_file)
333         actual = result.output
334         # We check the contents of the diff in `test_expression_diff`. All
335         # we need to check here is that color codes exist in the result.
336         self.assertIn("\033[1m", actual)
337         self.assertIn("\033[36m", actual)
338         self.assertIn("\033[32m", actual)
339         self.assertIn("\033[31m", actual)
340         self.assertIn("\033[0m", actual)
341
342     def test_detect_pos_only_arguments(self) -> None:
343         source, _ = read_data("py_38", "pep_570")
344         root = black.lib2to3_parse(source)
345         features = black.get_features_used(root)
346         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
347         versions = black.detect_target_versions(root)
348         self.assertIn(black.TargetVersion.PY38, versions)
349
350     def test_detect_debug_f_strings(self) -> None:
351         root = black.lib2to3_parse("""f"{x=}" """)
352         features = black.get_features_used(root)
353         self.assertIn(black.Feature.DEBUG_F_STRINGS, features)
354         versions = black.detect_target_versions(root)
355         self.assertIn(black.TargetVersion.PY38, versions)
356
357         root = black.lib2to3_parse(
358             """f"{x}"\nf'{"="}'\nf'{(x:=5)}'\nf'{f(a="3=")}'\nf'{x:=10}'\n"""
359         )
360         features = black.get_features_used(root)
361         self.assertNotIn(black.Feature.DEBUG_F_STRINGS, features)
362
363         # We don't yet support feature version detection in nested f-strings
364         root = black.lib2to3_parse(
365             """f"heard a rumour that { f'{1+1=}' } ... seems like it could be true" """
366         )
367         features = black.get_features_used(root)
368         self.assertNotIn(black.Feature.DEBUG_F_STRINGS, features)
369
370     @patch("black.dump_to_file", dump_to_stderr)
371     def test_string_quotes(self) -> None:
372         source, expected = read_data("miscellaneous", "string_quotes")
373         mode = black.Mode(preview=True)
374         assert_format(source, expected, mode)
375         mode = replace(mode, string_normalization=False)
376         not_normalized = fs(source, mode=mode)
377         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
378         black.assert_equivalent(source, not_normalized)
379         black.assert_stable(source, not_normalized, mode=mode)
380
381     def test_skip_source_first_line(self) -> None:
382         source, _ = read_data("miscellaneous", "invalid_header")
383         tmp_file = Path(black.dump_to_file(source))
384         # Full source should fail (invalid syntax at header)
385         self.invokeBlack([str(tmp_file), "--diff", "--check"], exit_code=123)
386         # So, skipping the first line should work
387         result = BlackRunner().invoke(
388             black.main, [str(tmp_file), "-x", f"--config={EMPTY_CONFIG}"]
389         )
390         self.assertEqual(result.exit_code, 0)
391         actual = tmp_file.read_text(encoding="utf-8")
392         self.assertFormatEqual(source, actual)
393
394     def test_skip_source_first_line_when_mixing_newlines(self) -> None:
395         code_mixing_newlines = b"Header will be skipped\r\ni = [1,2,3]\nj = [1,2,3]\n"
396         expected = b"Header will be skipped\r\ni = [1, 2, 3]\nj = [1, 2, 3]\n"
397         with TemporaryDirectory() as workspace:
398             test_file = Path(workspace) / "skip_header.py"
399             test_file.write_bytes(code_mixing_newlines)
400             mode = replace(DEFAULT_MODE, skip_source_first_line=True)
401             ff(test_file, mode=mode, write_back=black.WriteBack.YES)
402             self.assertEqual(test_file.read_bytes(), expected)
403
404     def test_skip_magic_trailing_comma(self) -> None:
405         source, _ = read_data("simple_cases", "expression")
406         expected, _ = read_data(
407             "miscellaneous", "expression_skip_magic_trailing_comma.diff"
408         )
409         tmp_file = Path(black.dump_to_file(source))
410         diff_header = re.compile(
411             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
412             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d\+\d\d:\d\d"
413         )
414         try:
415             result = BlackRunner().invoke(
416                 black.main, ["-C", "--diff", str(tmp_file), f"--config={EMPTY_CONFIG}"]
417             )
418             self.assertEqual(result.exit_code, 0)
419         finally:
420             os.unlink(tmp_file)
421         actual = result.output
422         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
423         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
424         if expected != actual:
425             dump = black.dump_to_file(actual)
426             msg = (
427                 "Expected diff isn't equal to the actual. If you made changes to"
428                 " expression.py and this is an anticipated difference, overwrite"
429                 " tests/data/miscellaneous/expression_skip_magic_trailing_comma.diff"
430                 f" with {dump}"
431             )
432             self.assertEqual(expected, actual, msg)
433
434     @patch("black.dump_to_file", dump_to_stderr)
435     def test_async_as_identifier(self) -> None:
436         source_path = get_case_path("miscellaneous", "async_as_identifier")
437         source, expected = read_data_from_file(source_path)
438         actual = fs(source)
439         self.assertFormatEqual(expected, actual)
440         major, minor = sys.version_info[:2]
441         if major < 3 or (major <= 3 and minor < 7):
442             black.assert_equivalent(source, actual)
443         black.assert_stable(source, actual, DEFAULT_MODE)
444         # ensure black can parse this when the target is 3.6
445         self.invokeBlack([str(source_path), "--target-version", "py36"])
446         # but not on 3.7, because async/await is no longer an identifier
447         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
448
449     @patch("black.dump_to_file", dump_to_stderr)
450     def test_python37(self) -> None:
451         source_path = get_case_path("py_37", "python37")
452         source, expected = read_data_from_file(source_path)
453         actual = fs(source)
454         self.assertFormatEqual(expected, actual)
455         major, minor = sys.version_info[:2]
456         if major > 3 or (major == 3 and minor >= 7):
457             black.assert_equivalent(source, actual)
458         black.assert_stable(source, actual, DEFAULT_MODE)
459         # ensure black can parse this when the target is 3.7
460         self.invokeBlack([str(source_path), "--target-version", "py37"])
461         # but not on 3.6, because we use async as a reserved keyword
462         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
463
464     def test_tab_comment_indentation(self) -> None:
465         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
466         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
467         self.assertFormatEqual(contents_spc, fs(contents_spc))
468         self.assertFormatEqual(contents_spc, fs(contents_tab))
469
470         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
471         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
472         self.assertFormatEqual(contents_spc, fs(contents_spc))
473         self.assertFormatEqual(contents_spc, fs(contents_tab))
474
475         # mixed tabs and spaces (valid Python 2 code)
476         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
477         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
478         self.assertFormatEqual(contents_spc, fs(contents_spc))
479         self.assertFormatEqual(contents_spc, fs(contents_tab))
480
481         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
482         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
483         self.assertFormatEqual(contents_spc, fs(contents_spc))
484         self.assertFormatEqual(contents_spc, fs(contents_tab))
485
    def test_false_positive_symlink_output_issue_3384(self) -> None:
        """Regression test for issue 3384.

        Files under ``./child`` must not be reported as "a symbolic link that
        points outside" the root. The only ignored path should be ``child/b.py``,
        which is ignored because it matches a .gitignore entry.
        """
        # Emulate the behavior when using the CLI (`black ./child  --verbose`), which
        # involves patching some `pathlib.Path` methods. In particular, `is_dir` is
        # patched only on its first call: when checking if "./child" is a directory it
        # should return True. The "./child" folder exists relative to the cwd when
        # running from CLI, but fails when running the tests because cwd is different
        project_root = Path(THIS_DIR / "data" / "nested_gitignore_tests")
        working_directory = project_root / "root"
        target_abspath = working_directory / "child"
        # Entries of child/ expressed relative to working_directory. This is a
        # generator, so it can only be iterated once.
        target_contents = (
            src.relative_to(working_directory) for src in target_abspath.iterdir()
        )

        def mock_n_calls(responses: List[bool]) -> Callable[[], bool]:
            # Build a side_effect that pops queued responses in order, then
            # returns False once the queue is exhausted.
            def _mocked_calls() -> bool:
                if responses:
                    return responses.pop(0)
                return False

            return _mocked_calls

        with patch("pathlib.Path.iterdir", return_value=target_contents), patch(
            "pathlib.Path.cwd", return_value=working_directory
        ), patch("pathlib.Path.is_dir", side_effect=mock_n_calls([True])):
            ctx = FakeContext()
            ctx.obj["root"] = project_root
            report = MagicMock(verbose=True)
            black.get_sources(
                ctx=ctx,
                src=("./child",),
                quiet=False,
                verbose=True,
                include=DEFAULT_INCLUDE,
                exclude=None,
                report=report,
                extend_exclude=None,
                force_exclude=None,
                stdin_filename=None,
            )
        # mock_calls entries unpack as (name, args, kwargs); args[1] is the reason.
        assert not any(
            mock_args[1].startswith("is a symbolic link that points outside")
            for _, mock_args, _ in report.path_ignored.mock_calls
        ), "A symbolic link was reported."
        report.path_ignored.assert_called_once_with(
            Path("child", "b.py"), "matches a .gitignore file content"
        )
532
    def test_report_verbose(self) -> None:
        """Walk a verbose Report through each outcome and check its output.

        ``black.output._out`` and ``black.output._err`` are patched to record
        messages. After each event the test checks the number of stdout/stderr
        lines, the latest message, the unstyled summary, and ``return_code``.
        """
        report = Report(verbose=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            # In verbose mode every done() call writes one stdout line,
            # including unchanged and cached files.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled, a reformatted file makes return_code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures go to stderr and make return_code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are logged to stdout but do not change the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2"
                " files would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2"
                " files would fail to reformat.",
            )
617                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
618                 " reformat.",
619             )
620             self.assertEqual(report.return_code, 123)
621             report.check = True
622             self.assertEqual(
623                 unstyle(str(report)),
624                 "2 files would be reformatted, 3 files would be left unchanged, 2"
625                 " files would fail to reformat.",
626             )
627             report.check = False
628             report.diff = True
629             self.assertEqual(
630                 unstyle(str(report)),
631                 "2 files would be reformatted, 3 files would be left unchanged, 2"
632                 " files would fail to reformat.",
633             )
634
635     def test_report_quiet(self) -> None:
636         report = Report(quiet=True)
637         out_lines = []
638         err_lines = []
639
640         def out(msg: str, **kwargs: Any) -> None:
641             out_lines.append(msg)
642
643         def err(msg: str, **kwargs: Any) -> None:
644             err_lines.append(msg)
645
646         with patch("black.output._out", out), patch("black.output._err", err):
647             report.done(Path("f1"), black.Changed.NO)
648             self.assertEqual(len(out_lines), 0)
649             self.assertEqual(len(err_lines), 0)
650             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
651             self.assertEqual(report.return_code, 0)
652             report.done(Path("f2"), black.Changed.YES)
653             self.assertEqual(len(out_lines), 0)
654             self.assertEqual(len(err_lines), 0)
655             self.assertEqual(
656                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
657             )
658             report.done(Path("f3"), black.Changed.CACHED)
659             self.assertEqual(len(out_lines), 0)
660             self.assertEqual(len(err_lines), 0)
661             self.assertEqual(
662                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
663             )
664             self.assertEqual(report.return_code, 0)
665             report.check = True
666             self.assertEqual(report.return_code, 1)
667             report.check = False
668             report.failed(Path("e1"), "boom")
669             self.assertEqual(len(out_lines), 0)
670             self.assertEqual(len(err_lines), 1)
671             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
672             self.assertEqual(
673                 unstyle(str(report)),
674                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
675                 " reformat.",
676             )
677             self.assertEqual(report.return_code, 123)
678             report.done(Path("f3"), black.Changed.YES)
679             self.assertEqual(len(out_lines), 0)
680             self.assertEqual(len(err_lines), 1)
681             self.assertEqual(
682                 unstyle(str(report)),
683                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
684                 " reformat.",
685             )
686             self.assertEqual(report.return_code, 123)
687             report.failed(Path("e2"), "boom")
688             self.assertEqual(len(out_lines), 0)
689             self.assertEqual(len(err_lines), 2)
690             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
691             self.assertEqual(
692                 unstyle(str(report)),
693                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
694                 " reformat.",
695             )
696             self.assertEqual(report.return_code, 123)
697             report.path_ignored(Path("wat"), "no match")
698             self.assertEqual(len(out_lines), 0)
699             self.assertEqual(len(err_lines), 2)
700             self.assertEqual(
701                 unstyle(str(report)),
702                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
703                 " reformat.",
704             )
705             self.assertEqual(report.return_code, 123)
706             report.done(Path("f4"), black.Changed.NO)
707             self.assertEqual(len(out_lines), 0)
708             self.assertEqual(len(err_lines), 2)
709             self.assertEqual(
710                 unstyle(str(report)),
711                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
712                 " reformat.",
713             )
714             self.assertEqual(report.return_code, 123)
715             report.check = True
716             self.assertEqual(
717                 unstyle(str(report)),
718                 "2 files would be reformatted, 3 files would be left unchanged, 2"
719                 " files would fail to reformat.",
720             )
721             report.check = False
722             report.diff = True
723             self.assertEqual(
724                 unstyle(str(report)),
725                 "2 files would be reformatted, 3 files would be left unchanged, 2"
726                 " files would fail to reformat.",
727             )
728
729     def test_report_normal(self) -> None:
730         report = black.Report()
731         out_lines = []
732         err_lines = []
733
734         def out(msg: str, **kwargs: Any) -> None:
735             out_lines.append(msg)
736
737         def err(msg: str, **kwargs: Any) -> None:
738             err_lines.append(msg)
739
740         with patch("black.output._out", out), patch("black.output._err", err):
741             report.done(Path("f1"), black.Changed.NO)
742             self.assertEqual(len(out_lines), 0)
743             self.assertEqual(len(err_lines), 0)
744             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
745             self.assertEqual(report.return_code, 0)
746             report.done(Path("f2"), black.Changed.YES)
747             self.assertEqual(len(out_lines), 1)
748             self.assertEqual(len(err_lines), 0)
749             self.assertEqual(out_lines[-1], "reformatted f2")
750             self.assertEqual(
751                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
752             )
753             report.done(Path("f3"), black.Changed.CACHED)
754             self.assertEqual(len(out_lines), 1)
755             self.assertEqual(len(err_lines), 0)
756             self.assertEqual(out_lines[-1], "reformatted f2")
757             self.assertEqual(
758                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
759             )
760             self.assertEqual(report.return_code, 0)
761             report.check = True
762             self.assertEqual(report.return_code, 1)
763             report.check = False
764             report.failed(Path("e1"), "boom")
765             self.assertEqual(len(out_lines), 1)
766             self.assertEqual(len(err_lines), 1)
767             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
768             self.assertEqual(
769                 unstyle(str(report)),
770                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
771                 " reformat.",
772             )
773             self.assertEqual(report.return_code, 123)
774             report.done(Path("f3"), black.Changed.YES)
775             self.assertEqual(len(out_lines), 2)
776             self.assertEqual(len(err_lines), 1)
777             self.assertEqual(out_lines[-1], "reformatted f3")
778             self.assertEqual(
779                 unstyle(str(report)),
780                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
781                 " reformat.",
782             )
783             self.assertEqual(report.return_code, 123)
784             report.failed(Path("e2"), "boom")
785             self.assertEqual(len(out_lines), 2)
786             self.assertEqual(len(err_lines), 2)
787             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
788             self.assertEqual(
789                 unstyle(str(report)),
790                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
791                 " reformat.",
792             )
793             self.assertEqual(report.return_code, 123)
794             report.path_ignored(Path("wat"), "no match")
795             self.assertEqual(len(out_lines), 2)
796             self.assertEqual(len(err_lines), 2)
797             self.assertEqual(
798                 unstyle(str(report)),
799                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
800                 " reformat.",
801             )
802             self.assertEqual(report.return_code, 123)
803             report.done(Path("f4"), black.Changed.NO)
804             self.assertEqual(len(out_lines), 2)
805             self.assertEqual(len(err_lines), 2)
806             self.assertEqual(
807                 unstyle(str(report)),
808                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
809                 " reformat.",
810             )
811             self.assertEqual(report.return_code, 123)
812             report.check = True
813             self.assertEqual(
814                 unstyle(str(report)),
815                 "2 files would be reformatted, 3 files would be left unchanged, 2"
816                 " files would fail to reformat.",
817             )
818             report.check = False
819             report.diff = True
820             self.assertEqual(
821                 unstyle(str(report)),
822                 "2 files would be reformatted, 3 files would be left unchanged, 2"
823                 " files would fail to reformat.",
824             )
825
826     def test_lib2to3_parse(self) -> None:
827         with self.assertRaises(black.InvalidInput):
828             black.lib2to3_parse("invalid syntax")
829
830         straddling = "x + y"
831         black.lib2to3_parse(straddling)
832         black.lib2to3_parse(straddling, {TargetVersion.PY36})
833
834         py2_only = "print x"
835         with self.assertRaises(black.InvalidInput):
836             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
837
838         py3_only = "exec(x, end=y)"
839         black.lib2to3_parse(py3_only)
840         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
841
842     def test_get_features_used_decorator(self) -> None:
843         # Test the feature detection of new decorator syntax
844         # since this makes some test cases of test_get_features_used()
845         # fails if it fails, this is tested first so that a useful case
846         # is identified
847         simples, relaxed = read_data("miscellaneous", "decorators")
848         # skip explanation comments at the top of the file
849         for simple_test in simples.split("##")[1:]:
850             node = black.lib2to3_parse(simple_test)
851             decorator = str(node.children[0].children[0]).strip()
852             self.assertNotIn(
853                 Feature.RELAXED_DECORATORS,
854                 black.get_features_used(node),
855                 msg=(
856                     f"decorator '{decorator}' follows python<=3.8 syntax"
857                     "but is detected as 3.9+"
858                     # f"The full node is\n{node!r}"
859                 ),
860             )
861         # skip the '# output' comment at the top of the output part
862         for relaxed_test in relaxed.split("##")[1:]:
863             node = black.lib2to3_parse(relaxed_test)
864             decorator = str(node.children[0].children[0]).strip()
865             self.assertIn(
866                 Feature.RELAXED_DECORATORS,
867                 black.get_features_used(node),
868                 msg=(
869                     f"decorator '{decorator}' uses python3.9+ syntax"
870                     "but is detected as python<=3.8"
871                     # f"The full node is\n{node!r}"
872                 ),
873             )
874
875     def test_get_features_used(self) -> None:
876         node = black.lib2to3_parse("def f(*, arg): ...\n")
877         self.assertEqual(black.get_features_used(node), set())
878         node = black.lib2to3_parse("def f(*, arg,): ...\n")
879         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
880         node = black.lib2to3_parse("f(*arg,)\n")
881         self.assertEqual(
882             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
883         )
884         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
885         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
886         node = black.lib2to3_parse("123_456\n")
887         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
888         node = black.lib2to3_parse("123456\n")
889         self.assertEqual(black.get_features_used(node), set())
890         source, expected = read_data("simple_cases", "function")
891         node = black.lib2to3_parse(source)
892         expected_features = {
893             Feature.TRAILING_COMMA_IN_CALL,
894             Feature.TRAILING_COMMA_IN_DEF,
895             Feature.F_STRINGS,
896         }
897         self.assertEqual(black.get_features_used(node), expected_features)
898         node = black.lib2to3_parse(expected)
899         self.assertEqual(black.get_features_used(node), expected_features)
900         source, expected = read_data("simple_cases", "expression")
901         node = black.lib2to3_parse(source)
902         self.assertEqual(black.get_features_used(node), set())
903         node = black.lib2to3_parse(expected)
904         self.assertEqual(black.get_features_used(node), set())
905         node = black.lib2to3_parse("lambda a, /, b: ...")
906         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
907         node = black.lib2to3_parse("def fn(a, /, b): ...")
908         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
909         node = black.lib2to3_parse("def fn(): yield a, b")
910         self.assertEqual(black.get_features_used(node), set())
911         node = black.lib2to3_parse("def fn(): return a, b")
912         self.assertEqual(black.get_features_used(node), set())
913         node = black.lib2to3_parse("def fn(): yield *b, c")
914         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
915         node = black.lib2to3_parse("def fn(): return a, *b, c")
916         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
917         node = black.lib2to3_parse("x = a, *b, c")
918         self.assertEqual(black.get_features_used(node), set())
919         node = black.lib2to3_parse("x: Any = regular")
920         self.assertEqual(black.get_features_used(node), set())
921         node = black.lib2to3_parse("x: Any = (regular, regular)")
922         self.assertEqual(black.get_features_used(node), set())
923         node = black.lib2to3_parse("x: Any = Complex(Type(1))[something]")
924         self.assertEqual(black.get_features_used(node), set())
925         node = black.lib2to3_parse("x: Tuple[int, ...] = a, b, c")
926         self.assertEqual(
927             black.get_features_used(node), {Feature.ANN_ASSIGN_EXTENDED_RHS}
928         )
929         node = black.lib2to3_parse("try: pass\nexcept Something: pass")
930         self.assertEqual(black.get_features_used(node), set())
931         node = black.lib2to3_parse("try: pass\nexcept (*Something,): pass")
932         self.assertEqual(black.get_features_used(node), set())
933         node = black.lib2to3_parse("try: pass\nexcept *Group: pass")
934         self.assertEqual(black.get_features_used(node), {Feature.EXCEPT_STAR})
935         node = black.lib2to3_parse("a[*b]")
936         self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS})
937         node = black.lib2to3_parse("a[x, *y(), z] = t")
938         self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS})
939         node = black.lib2to3_parse("def fn(*args: *T): pass")
940         self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS})
941
942     def test_get_features_used_for_future_flags(self) -> None:
943         for src, features in [
944             ("from __future__ import annotations", {Feature.FUTURE_ANNOTATIONS}),
945             (
946                 "from __future__ import (other, annotations)",
947                 {Feature.FUTURE_ANNOTATIONS},
948             ),
949             ("a = 1 + 2\nfrom something import annotations", set()),
950             ("from __future__ import x, y", set()),
951         ]:
952             with self.subTest(src=src, features=features):
953                 node = black.lib2to3_parse(src)
954                 future_imports = black.get_future_imports(node)
955                 self.assertEqual(
956                     black.get_features_used(node, future_imports=future_imports),
957                     features,
958                 )
959
960     def test_get_future_imports(self) -> None:
961         node = black.lib2to3_parse("\n")
962         self.assertEqual(set(), black.get_future_imports(node))
963         node = black.lib2to3_parse("from __future__ import black\n")
964         self.assertEqual({"black"}, black.get_future_imports(node))
965         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
966         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
967         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
968         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
969         node = black.lib2to3_parse(
970             "from __future__ import multiple\nfrom __future__ import imports\n"
971         )
972         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
973         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
974         self.assertEqual({"black"}, black.get_future_imports(node))
975         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
976         self.assertEqual({"black"}, black.get_future_imports(node))
977         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
978         self.assertEqual(set(), black.get_future_imports(node))
979         node = black.lib2to3_parse("from some.module import black\n")
980         self.assertEqual(set(), black.get_future_imports(node))
981         node = black.lib2to3_parse(
982             "from __future__ import unicode_literals as _unicode_literals"
983         )
984         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
985         node = black.lib2to3_parse(
986             "from __future__ import unicode_literals as _lol, print"
987         )
988         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
989
990     @pytest.mark.incompatible_with_mypyc
991     def test_debug_visitor(self) -> None:
992         source, _ = read_data("miscellaneous", "debug_visitor")
993         expected, _ = read_data("miscellaneous", "debug_visitor.out")
994         out_lines = []
995         err_lines = []
996
997         def out(msg: str, **kwargs: Any) -> None:
998             out_lines.append(msg)
999
1000         def err(msg: str, **kwargs: Any) -> None:
1001             err_lines.append(msg)
1002
1003         with patch("black.debug.out", out):
1004             DebugVisitor.show(source)
1005         actual = "\n".join(out_lines) + "\n"
1006         log_name = ""
1007         if expected != actual:
1008             log_name = black.dump_to_file(*out_lines)
1009         self.assertEqual(
1010             expected,
1011             actual,
1012             f"AST print out is different. Actual version dumped to {log_name}",
1013         )
1014
1015     def test_format_file_contents(self) -> None:
1016         mode = DEFAULT_MODE
1017         empty = ""
1018         with self.assertRaises(black.NothingChanged):
1019             black.format_file_contents(empty, mode=mode, fast=False)
1020         just_nl = "\n"
1021         with self.assertRaises(black.NothingChanged):
1022             black.format_file_contents(just_nl, mode=mode, fast=False)
1023         same = "j = [1, 2, 3]\n"
1024         with self.assertRaises(black.NothingChanged):
1025             black.format_file_contents(same, mode=mode, fast=False)
1026         different = "j = [1,2,3]"
1027         expected = same
1028         actual = black.format_file_contents(different, mode=mode, fast=False)
1029         self.assertEqual(expected, actual)
1030         invalid = "return if you can"
1031         with self.assertRaises(black.InvalidInput) as e:
1032             black.format_file_contents(invalid, mode=mode, fast=False)
1033         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1034
1035         mode = black.Mode(preview=True)
1036         just_crlf = "\r\n"
1037         with self.assertRaises(black.NothingChanged):
1038             black.format_file_contents(just_crlf, mode=mode, fast=False)
1039         just_whitespace_nl = "\n\t\n \n\t \n \t\n\n"
1040         actual = black.format_file_contents(just_whitespace_nl, mode=mode, fast=False)
1041         self.assertEqual("\n", actual)
1042         just_whitespace_crlf = "\r\n\t\r\n \r\n\t \r\n \t\r\n\r\n"
1043         actual = black.format_file_contents(just_whitespace_crlf, mode=mode, fast=False)
1044         self.assertEqual("\r\n", actual)
1045
1046     def test_endmarker(self) -> None:
1047         n = black.lib2to3_parse("\n")
1048         self.assertEqual(n.type, black.syms.file_input)
1049         self.assertEqual(len(n.children), 1)
1050         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1051
1052     @pytest.mark.incompatible_with_mypyc
1053     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1054     def test_assertFormatEqual(self) -> None:
1055         out_lines = []
1056         err_lines = []
1057
1058         def out(msg: str, **kwargs: Any) -> None:
1059             out_lines.append(msg)
1060
1061         def err(msg: str, **kwargs: Any) -> None:
1062             err_lines.append(msg)
1063
1064         with patch("black.output._out", out), patch("black.output._err", err):
1065             with self.assertRaises(AssertionError):
1066                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1067
1068         out_str = "".join(out_lines)
1069         self.assertIn("Expected tree:", out_str)
1070         self.assertIn("Actual tree:", out_str)
1071         self.assertEqual("".join(err_lines), "")
1072
1073     @event_loop()
1074     @patch("concurrent.futures.ProcessPoolExecutor", MagicMock(side_effect=OSError))
1075     def test_works_in_mono_process_only_environment(self) -> None:
1076         with cache_dir() as workspace:
1077             for f in [
1078                 (workspace / "one.py").resolve(),
1079                 (workspace / "two.py").resolve(),
1080             ]:
1081                 f.write_text('print("hello")\n', encoding="utf-8")
1082             self.invokeBlack([str(workspace)])
1083
1084     @event_loop()
1085     def test_check_diff_use_together(self) -> None:
1086         with cache_dir():
1087             # Files which will be reformatted.
1088             src1 = get_case_path("miscellaneous", "string_quotes")
1089             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1090             # Files which will not be reformatted.
1091             src2 = get_case_path("simple_cases", "composition")
1092             self.invokeBlack([str(src2), "--diff", "--check"])
1093             # Multi file command.
1094             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1095
1096     def test_no_src_fails(self) -> None:
1097         with cache_dir():
1098             self.invokeBlack([], exit_code=1)
1099
1100     def test_src_and_code_fails(self) -> None:
1101         with cache_dir():
1102             self.invokeBlack([".", "-c", "0"], exit_code=1)
1103
1104     def test_broken_symlink(self) -> None:
1105         with cache_dir() as workspace:
1106             symlink = workspace / "broken_link.py"
1107             try:
1108                 symlink.symlink_to("nonexistent.py")
1109             except (OSError, NotImplementedError) as e:
1110                 self.skipTest(f"Can't create symlinks: {e}")
1111             self.invokeBlack([str(workspace.resolve())])
1112
1113     def test_single_file_force_pyi(self) -> None:
1114         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1115         contents, expected = read_data("miscellaneous", "force_pyi")
1116         with cache_dir() as workspace:
1117             path = (workspace / "file.py").resolve()
1118             path.write_text(contents, encoding="utf-8")
1119             self.invokeBlack([str(path), "--pyi"])
1120             actual = path.read_text(encoding="utf-8")
1121             # verify cache with --pyi is separate
1122             pyi_cache = black.read_cache(pyi_mode)
1123             self.assertIn(str(path), pyi_cache)
1124             normal_cache = black.read_cache(DEFAULT_MODE)
1125             self.assertNotIn(str(path), normal_cache)
1126         self.assertFormatEqual(expected, actual)
1127         black.assert_equivalent(contents, actual)
1128         black.assert_stable(contents, actual, pyi_mode)
1129
1130     @event_loop()
1131     def test_multi_file_force_pyi(self) -> None:
1132         reg_mode = DEFAULT_MODE
1133         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1134         contents, expected = read_data("miscellaneous", "force_pyi")
1135         with cache_dir() as workspace:
1136             paths = [
1137                 (workspace / "file1.py").resolve(),
1138                 (workspace / "file2.py").resolve(),
1139             ]
1140             for path in paths:
1141                 path.write_text(contents, encoding="utf-8")
1142             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1143             for path in paths:
1144                 actual = path.read_text(encoding="utf-8")
1145                 self.assertEqual(actual, expected)
1146             # verify cache with --pyi is separate
1147             pyi_cache = black.read_cache(pyi_mode)
1148             normal_cache = black.read_cache(reg_mode)
1149             for path in paths:
1150                 self.assertIn(str(path), pyi_cache)
1151                 self.assertNotIn(str(path), normal_cache)
1152
1153     def test_pipe_force_pyi(self) -> None:
1154         source, expected = read_data("miscellaneous", "force_pyi")
1155         result = CliRunner().invoke(
1156             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf-8"))
1157         )
1158         self.assertEqual(result.exit_code, 0)
1159         actual = result.output
1160         self.assertFormatEqual(actual, expected)
1161
1162     def test_single_file_force_py36(self) -> None:
1163         reg_mode = DEFAULT_MODE
1164         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1165         source, expected = read_data("miscellaneous", "force_py36")
1166         with cache_dir() as workspace:
1167             path = (workspace / "file.py").resolve()
1168             path.write_text(source, encoding="utf-8")
1169             self.invokeBlack([str(path), *PY36_ARGS])
1170             actual = path.read_text(encoding="utf-8")
1171             # verify cache with --target-version is separate
1172             py36_cache = black.read_cache(py36_mode)
1173             self.assertIn(str(path), py36_cache)
1174             normal_cache = black.read_cache(reg_mode)
1175             self.assertNotIn(str(path), normal_cache)
1176         self.assertEqual(actual, expected)
1177
1178     @event_loop()
1179     def test_multi_file_force_py36(self) -> None:
1180         reg_mode = DEFAULT_MODE
1181         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1182         source, expected = read_data("miscellaneous", "force_py36")
1183         with cache_dir() as workspace:
1184             paths = [
1185                 (workspace / "file1.py").resolve(),
1186                 (workspace / "file2.py").resolve(),
1187             ]
1188             for path in paths:
1189                 path.write_text(source, encoding="utf-8")
1190             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1191             for path in paths:
1192                 actual = path.read_text(encoding="utf-8")
1193                 self.assertEqual(actual, expected)
1194             # verify cache with --target-version is separate
1195             pyi_cache = black.read_cache(py36_mode)
1196             normal_cache = black.read_cache(reg_mode)
1197             for path in paths:
1198                 self.assertIn(str(path), pyi_cache)
1199                 self.assertNotIn(str(path), normal_cache)
1200
1201     def test_pipe_force_py36(self) -> None:
1202         source, expected = read_data("miscellaneous", "force_py36")
1203         result = CliRunner().invoke(
1204             black.main,
1205             ["-", "-q", "--target-version=py36"],
1206             input=BytesIO(source.encode("utf-8")),
1207         )
1208         self.assertEqual(result.exit_code, 0)
1209         actual = result.output
1210         self.assertFormatEqual(actual, expected)
1211
1212     @pytest.mark.incompatible_with_mypyc
1213     def test_reformat_one_with_stdin(self) -> None:
1214         with patch(
1215             "black.format_stdin_to_stdout",
1216             return_value=lambda *args, **kwargs: black.Changed.YES,
1217         ) as fsts:
1218             report = MagicMock()
1219             path = Path("-")
1220             black.reformat_one(
1221                 path,
1222                 fast=True,
1223                 write_back=black.WriteBack.YES,
1224                 mode=DEFAULT_MODE,
1225                 report=report,
1226             )
1227             fsts.assert_called_once()
1228             report.done.assert_called_with(path, black.Changed.YES)
1229
1230     @pytest.mark.incompatible_with_mypyc
1231     def test_reformat_one_with_stdin_filename(self) -> None:
1232         with patch(
1233             "black.format_stdin_to_stdout",
1234             return_value=lambda *args, **kwargs: black.Changed.YES,
1235         ) as fsts:
1236             report = MagicMock()
1237             p = "foo.py"
1238             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1239             expected = Path(p)
1240             black.reformat_one(
1241                 path,
1242                 fast=True,
1243                 write_back=black.WriteBack.YES,
1244                 mode=DEFAULT_MODE,
1245                 report=report,
1246             )
1247             fsts.assert_called_once_with(
1248                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1249             )
1250             # __BLACK_STDIN_FILENAME__ should have been stripped
1251             report.done.assert_called_with(expected, black.Changed.YES)
1252
1253     @pytest.mark.incompatible_with_mypyc
1254     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1255         with patch(
1256             "black.format_stdin_to_stdout",
1257             return_value=lambda *args, **kwargs: black.Changed.YES,
1258         ) as fsts:
1259             report = MagicMock()
1260             p = "foo.pyi"
1261             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1262             expected = Path(p)
1263             black.reformat_one(
1264                 path,
1265                 fast=True,
1266                 write_back=black.WriteBack.YES,
1267                 mode=DEFAULT_MODE,
1268                 report=report,
1269             )
1270             fsts.assert_called_once_with(
1271                 fast=True,
1272                 write_back=black.WriteBack.YES,
1273                 mode=replace(DEFAULT_MODE, is_pyi=True),
1274             )
1275             # __BLACK_STDIN_FILENAME__ should have been stripped
1276             report.done.assert_called_with(expected, black.Changed.YES)
1277
1278     @pytest.mark.incompatible_with_mypyc
1279     def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
1280         with patch(
1281             "black.format_stdin_to_stdout",
1282             return_value=lambda *args, **kwargs: black.Changed.YES,
1283         ) as fsts:
1284             report = MagicMock()
1285             p = "foo.ipynb"
1286             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1287             expected = Path(p)
1288             black.reformat_one(
1289                 path,
1290                 fast=True,
1291                 write_back=black.WriteBack.YES,
1292                 mode=DEFAULT_MODE,
1293                 report=report,
1294             )
1295             fsts.assert_called_once_with(
1296                 fast=True,
1297                 write_back=black.WriteBack.YES,
1298                 mode=replace(DEFAULT_MODE, is_ipynb=True),
1299             )
1300             # __BLACK_STDIN_FILENAME__ should have been stripped
1301             report.done.assert_called_with(expected, black.Changed.YES)
1302
1303     @pytest.mark.incompatible_with_mypyc
1304     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1305         with patch(
1306             "black.format_stdin_to_stdout",
1307             return_value=lambda *args, **kwargs: black.Changed.YES,
1308         ) as fsts:
1309             report = MagicMock()
1310             # Even with an existing file, since we are forcing stdin, black
1311             # should output to stdout and not modify the file inplace
1312             p = THIS_DIR / "data" / "simple_cases" / "collections.py"
1313             # Make sure is_file actually returns True
1314             self.assertTrue(p.is_file())
1315             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1316             expected = Path(p)
1317             black.reformat_one(
1318                 path,
1319                 fast=True,
1320                 write_back=black.WriteBack.YES,
1321                 mode=DEFAULT_MODE,
1322                 report=report,
1323             )
1324             fsts.assert_called_once()
1325             # __BLACK_STDIN_FILENAME__ should have been stripped
1326             report.done.assert_called_with(expected, black.Changed.YES)
1327
1328     def test_reformat_one_with_stdin_empty(self) -> None:
1329         cases = [
1330             ("", ""),
1331             ("\n", "\n"),
1332             ("\r\n", "\r\n"),
1333             (" \t", ""),
1334             (" \t\n\t ", "\n"),
1335             (" \t\r\n\t ", "\r\n"),
1336         ]
1337
1338         def _new_wrapper(
1339             output: io.StringIO, io_TextIOWrapper: Type[io.TextIOWrapper]
1340         ) -> Callable[[Any, Any], io.TextIOWrapper]:
1341             def get_output(*args: Any, **kwargs: Any) -> io.TextIOWrapper:
1342                 if args == (sys.stdout.buffer,):
1343                     # It's `format_stdin_to_stdout()` calling `io.TextIOWrapper()`,
1344                     # return our mock object.
1345                     return output
1346                 # It's something else (i.e. `decode_bytes()`) calling
1347                 # `io.TextIOWrapper()`, pass through to the original implementation.
1348                 # See discussion in https://github.com/psf/black/pull/2489
1349                 return io_TextIOWrapper(*args, **kwargs)
1350
1351             return get_output
1352
1353         mode = black.Mode(preview=True)
1354         for content, expected in cases:
1355             output = io.StringIO()
1356             io_TextIOWrapper = io.TextIOWrapper
1357
1358             with patch("io.TextIOWrapper", _new_wrapper(output, io_TextIOWrapper)):
1359                 try:
1360                     black.format_stdin_to_stdout(
1361                         fast=True,
1362                         content=content,
1363                         write_back=black.WriteBack.YES,
1364                         mode=mode,
1365                     )
1366                 except io.UnsupportedOperation:
1367                     pass  # StringIO does not support detach
1368                 assert output.getvalue() == expected
1369
1370         # An empty string is the only test case for `preview=False`
1371         output = io.StringIO()
1372         io_TextIOWrapper = io.TextIOWrapper
1373         with patch("io.TextIOWrapper", _new_wrapper(output, io_TextIOWrapper)):
1374             try:
1375                 black.format_stdin_to_stdout(
1376                     fast=True,
1377                     content="",
1378                     write_back=black.WriteBack.YES,
1379                     mode=DEFAULT_MODE,
1380                 )
1381             except io.UnsupportedOperation:
1382                 pass  # StringIO does not support detach
1383             assert output.getvalue() == ""
1384
1385     def test_invalid_cli_regex(self) -> None:
1386         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1387             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1388
1389     def test_required_version_matches_version(self) -> None:
1390         self.invokeBlack(
1391             ["--required-version", black.__version__, "-c", "0"],
1392             exit_code=0,
1393             ignore_config=True,
1394         )
1395
1396     def test_required_version_matches_partial_version(self) -> None:
1397         self.invokeBlack(
1398             ["--required-version", black.__version__.split(".")[0], "-c", "0"],
1399             exit_code=0,
1400             ignore_config=True,
1401         )
1402
1403     def test_required_version_does_not_match_on_minor_version(self) -> None:
1404         self.invokeBlack(
1405             ["--required-version", black.__version__.split(".")[0] + ".999", "-c", "0"],
1406             exit_code=1,
1407             ignore_config=True,
1408         )
1409
1410     def test_required_version_does_not_match_version(self) -> None:
1411         result = BlackRunner().invoke(
1412             black.main,
1413             ["--required-version", "20.99b", "-c", "0"],
1414         )
1415         self.assertEqual(result.exit_code, 1)
1416         self.assertIn("required version", result.stderr)
1417
1418     def test_preserves_line_endings(self) -> None:
1419         with TemporaryDirectory() as workspace:
1420             test_file = Path(workspace) / "test.py"
1421             for nl in ["\n", "\r\n"]:
1422                 contents = nl.join(["def f(  ):", "    pass"])
1423                 test_file.write_bytes(contents.encode())
1424                 ff(test_file, write_back=black.WriteBack.YES)
1425                 updated_contents: bytes = test_file.read_bytes()
1426                 self.assertIn(nl.encode(), updated_contents)
1427                 if nl == "\n":
1428                     self.assertNotIn(b"\r\n", updated_contents)
1429
1430     def test_preserves_line_endings_via_stdin(self) -> None:
1431         for nl in ["\n", "\r\n"]:
1432             contents = nl.join(["def f(  ):", "    pass"])
1433             runner = BlackRunner()
1434             result = runner.invoke(
1435                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf-8"))
1436             )
1437             self.assertEqual(result.exit_code, 0)
1438             output = result.stdout_bytes
1439             self.assertIn(nl.encode("utf-8"), output)
1440             if nl == "\n":
1441                 self.assertNotIn(b"\r\n", output)
1442
1443     def test_normalize_line_endings(self) -> None:
1444         with TemporaryDirectory() as workspace:
1445             test_file = Path(workspace) / "test.py"
1446             for data, expected in (
1447                 (b"c\r\nc\n ", b"c\r\nc\r\n"),
1448                 (b"l\nl\r\n ", b"l\nl\n"),
1449             ):
1450                 test_file.write_bytes(data)
1451                 ff(test_file, write_back=black.WriteBack.YES)
1452                 self.assertEqual(test_file.read_bytes(), expected)
1453
1454     def test_assert_equivalent_different_asts(self) -> None:
1455         with self.assertRaises(AssertionError):
1456             black.assert_equivalent("{}", "None")
1457
1458     def test_shhh_click(self) -> None:
1459         try:
1460             from click import _unicodefun  # type: ignore
1461         except ImportError:
1462             self.skipTest("Incompatible Click version")
1463
1464         if not hasattr(_unicodefun, "_verify_python_env"):
1465             self.skipTest("Incompatible Click version")
1466
1467         # First, let's see if Click is crashing with a preferred ASCII charset.
1468         with patch("locale.getpreferredencoding") as gpe:
1469             gpe.return_value = "ASCII"
1470             with self.assertRaises(RuntimeError):
1471                 _unicodefun._verify_python_env()
1472         # Now, let's silence Click...
1473         black.patch_click()
1474         # ...and confirm it's silent.
1475         with patch("locale.getpreferredencoding") as gpe:
1476             gpe.return_value = "ASCII"
1477             try:
1478                 _unicodefun._verify_python_env()
1479             except RuntimeError as re:
1480                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1481
1482     def test_root_logger_not_used_directly(self) -> None:
1483         def fail(*args: Any, **kwargs: Any) -> None:
1484             self.fail("Record created with root logger")
1485
1486         with patch.multiple(
1487             logging.root,
1488             debug=fail,
1489             info=fail,
1490             warning=fail,
1491             error=fail,
1492             critical=fail,
1493             log=fail,
1494         ):
1495             ff(THIS_DIR / "util.py")
1496
1497     def test_invalid_config_return_code(self) -> None:
1498         tmp_file = Path(black.dump_to_file())
1499         try:
1500             tmp_config = Path(black.dump_to_file())
1501             tmp_config.unlink()
1502             args = ["--config", str(tmp_config), str(tmp_file)]
1503             self.invokeBlack(args, exit_code=2, ignore_config=False)
1504         finally:
1505             tmp_file.unlink()
1506
1507     def test_parse_pyproject_toml(self) -> None:
1508         test_toml_file = THIS_DIR / "test.toml"
1509         config = black.parse_pyproject_toml(str(test_toml_file))
1510         self.assertEqual(config["verbose"], 1)
1511         self.assertEqual(config["check"], "no")
1512         self.assertEqual(config["diff"], "y")
1513         self.assertEqual(config["color"], True)
1514         self.assertEqual(config["line_length"], 79)
1515         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1516         self.assertEqual(config["python_cell_magics"], ["custom1", "custom2"])
1517         self.assertEqual(config["exclude"], r"\.pyi?$")
1518         self.assertEqual(config["include"], r"\.py?$")
1519
1520     def test_parse_pyproject_toml_project_metadata(self) -> None:
1521         for test_toml, expected in [
1522             ("only_black_pyproject.toml", ["py310"]),
1523             ("only_metadata_pyproject.toml", ["py37", "py38", "py39", "py310"]),
1524             ("neither_pyproject.toml", None),
1525             ("both_pyproject.toml", ["py310"]),
1526         ]:
1527             test_toml_file = THIS_DIR / "data" / "project_metadata" / test_toml
1528             config = black.parse_pyproject_toml(str(test_toml_file))
1529             self.assertEqual(config.get("target_version"), expected)
1530
1531     def test_infer_target_version(self) -> None:
1532         for version, expected in [
1533             ("3.6", [TargetVersion.PY36]),
1534             ("3.11.0rc1", [TargetVersion.PY311]),
1535             (">=3.10", [TargetVersion.PY310, TargetVersion.PY311, TargetVersion.PY312]),
1536             (
1537                 ">=3.10.6",
1538                 [TargetVersion.PY310, TargetVersion.PY311, TargetVersion.PY312],
1539             ),
1540             ("<3.6", [TargetVersion.PY33, TargetVersion.PY34, TargetVersion.PY35]),
1541             (">3.7,<3.10", [TargetVersion.PY38, TargetVersion.PY39]),
1542             (
1543                 ">3.7,!=3.8,!=3.9",
1544                 [TargetVersion.PY310, TargetVersion.PY311, TargetVersion.PY312],
1545             ),
1546             (
1547                 "> 3.9.4, != 3.10.3",
1548                 [
1549                     TargetVersion.PY39,
1550                     TargetVersion.PY310,
1551                     TargetVersion.PY311,
1552                     TargetVersion.PY312,
1553                 ],
1554             ),
1555             (
1556                 "!=3.3,!=3.4",
1557                 [
1558                     TargetVersion.PY35,
1559                     TargetVersion.PY36,
1560                     TargetVersion.PY37,
1561                     TargetVersion.PY38,
1562                     TargetVersion.PY39,
1563                     TargetVersion.PY310,
1564                     TargetVersion.PY311,
1565                     TargetVersion.PY312,
1566                 ],
1567             ),
1568             (
1569                 "==3.*",
1570                 [
1571                     TargetVersion.PY33,
1572                     TargetVersion.PY34,
1573                     TargetVersion.PY35,
1574                     TargetVersion.PY36,
1575                     TargetVersion.PY37,
1576                     TargetVersion.PY38,
1577                     TargetVersion.PY39,
1578                     TargetVersion.PY310,
1579                     TargetVersion.PY311,
1580                     TargetVersion.PY312,
1581                 ],
1582             ),
1583             ("==3.8.*", [TargetVersion.PY38]),
1584             (None, None),
1585             ("", None),
1586             ("invalid", None),
1587             ("==invalid", None),
1588             (">3.9,!=invalid", None),
1589             ("3", None),
1590             ("3.2", None),
1591             ("2.7.18", None),
1592             ("==2.7", None),
1593             (">3.10,<3.11", None),
1594         ]:
1595             test_toml = {"project": {"requires-python": version}}
1596             result = black.files.infer_target_version(test_toml)
1597             self.assertEqual(result, expected)
1598
1599     def test_read_pyproject_toml(self) -> None:
1600         test_toml_file = THIS_DIR / "test.toml"
1601         fake_ctx = FakeContext()
1602         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1603         config = fake_ctx.default_map
1604         self.assertEqual(config["verbose"], "1")
1605         self.assertEqual(config["check"], "no")
1606         self.assertEqual(config["diff"], "y")
1607         self.assertEqual(config["color"], "True")
1608         self.assertEqual(config["line_length"], "79")
1609         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1610         self.assertEqual(config["exclude"], r"\.pyi?$")
1611         self.assertEqual(config["include"], r"\.py?$")
1612
1613     def test_read_pyproject_toml_from_stdin(self) -> None:
1614         with TemporaryDirectory() as workspace:
1615             root = Path(workspace)
1616
1617             src_dir = root / "src"
1618             src_dir.mkdir()
1619
1620             src_pyproject = src_dir / "pyproject.toml"
1621             src_pyproject.touch()
1622
1623             test_toml_content = (THIS_DIR / "test.toml").read_text(encoding="utf-8")
1624             src_pyproject.write_text(test_toml_content, encoding="utf-8")
1625
1626             src_python = src_dir / "foo.py"
1627             src_python.touch()
1628
1629             fake_ctx = FakeContext()
1630             fake_ctx.params["src"] = ("-",)
1631             fake_ctx.params["stdin_filename"] = str(src_python)
1632
1633             with change_directory(root):
1634                 black.read_pyproject_toml(fake_ctx, FakeParameter(), None)
1635
1636             config = fake_ctx.default_map
1637             self.assertEqual(config["verbose"], "1")
1638             self.assertEqual(config["check"], "no")
1639             self.assertEqual(config["diff"], "y")
1640             self.assertEqual(config["color"], "True")
1641             self.assertEqual(config["line_length"], "79")
1642             self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1643             self.assertEqual(config["exclude"], r"\.pyi?$")
1644             self.assertEqual(config["include"], r"\.py?$")
1645
1646     @pytest.mark.incompatible_with_mypyc
1647     def test_find_project_root(self) -> None:
1648         with TemporaryDirectory() as workspace:
1649             root = Path(workspace)
1650             test_dir = root / "test"
1651             test_dir.mkdir()
1652
1653             src_dir = root / "src"
1654             src_dir.mkdir()
1655
1656             root_pyproject = root / "pyproject.toml"
1657             root_pyproject.touch()
1658             src_pyproject = src_dir / "pyproject.toml"
1659             src_pyproject.touch()
1660             src_python = src_dir / "foo.py"
1661             src_python.touch()
1662
1663             self.assertEqual(
1664                 black.find_project_root((src_dir, test_dir)),
1665                 (root.resolve(), "pyproject.toml"),
1666             )
1667             self.assertEqual(
1668                 black.find_project_root((src_dir,)),
1669                 (src_dir.resolve(), "pyproject.toml"),
1670             )
1671             self.assertEqual(
1672                 black.find_project_root((src_python,)),
1673                 (src_dir.resolve(), "pyproject.toml"),
1674             )
1675
1676             with change_directory(test_dir):
1677                 self.assertEqual(
1678                     black.find_project_root(("-",), stdin_filename="../src/a.py"),
1679                     (src_dir.resolve(), "pyproject.toml"),
1680                 )
1681
1682     @patch(
1683         "black.files.find_user_pyproject_toml",
1684     )
1685     def test_find_pyproject_toml(self, find_user_pyproject_toml: MagicMock) -> None:
1686         find_user_pyproject_toml.side_effect = RuntimeError()
1687
1688         with redirect_stderr(io.StringIO()) as stderr:
1689             result = black.files.find_pyproject_toml(
1690                 path_search_start=(str(Path.cwd().root),)
1691             )
1692
1693         assert result is None
1694         err = stderr.getvalue()
1695         assert "Ignoring user configuration" in err
1696
1697     @patch(
1698         "black.files.find_user_pyproject_toml",
1699         black.files.find_user_pyproject_toml.__wrapped__,
1700     )
1701     def test_find_user_pyproject_toml_linux(self) -> None:
1702         if system() == "Windows":
1703             return
1704
1705         # Test if XDG_CONFIG_HOME is checked
1706         with TemporaryDirectory() as workspace:
1707             tmp_user_config = Path(workspace) / "black"
1708             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1709                 self.assertEqual(
1710                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
1711                 )
1712
1713         # Test fallback for XDG_CONFIG_HOME
1714         with patch.dict("os.environ"):
1715             os.environ.pop("XDG_CONFIG_HOME", None)
1716             fallback_user_config = Path("~/.config").expanduser() / "black"
1717             self.assertEqual(
1718                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
1719             )
1720
1721     def test_find_user_pyproject_toml_windows(self) -> None:
1722         if system() != "Windows":
1723             return
1724
1725         user_config_path = Path.home() / ".black"
1726         self.assertEqual(
1727             black.files.find_user_pyproject_toml(), user_config_path.resolve()
1728         )
1729
1730     def test_bpo_33660_workaround(self) -> None:
1731         if system() == "Windows":
1732             return
1733
1734         # https://bugs.python.org/issue33660
1735         root = Path("/")
1736         with change_directory(root):
1737             path = Path("workspace") / "project"
1738             report = black.Report(verbose=True)
1739             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
1740             self.assertEqual(normalized_path, "workspace/project")
1741
1742     def test_normalize_path_ignore_windows_junctions_outside_of_root(self) -> None:
1743         if system() != "Windows":
1744             return
1745
1746         with TemporaryDirectory() as workspace:
1747             root = Path(workspace)
1748             junction_dir = root / "junction"
1749             junction_target_outside_of_root = root / ".."
1750             os.system(f"mklink /J {junction_dir} {junction_target_outside_of_root}")
1751
1752             report = black.Report(verbose=True)
1753             normalized_path = black.normalize_path_maybe_ignore(
1754                 junction_dir, root, report
1755             )
1756             # Manually delete for Python < 3.8
1757             os.system(f"rmdir {junction_dir}")
1758
1759             self.assertEqual(normalized_path, None)
1760
1761     def test_newline_comment_interaction(self) -> None:
1762         source = "class A:\\\r\n# type: ignore\n pass\n"
1763         output = black.format_str(source, mode=DEFAULT_MODE)
1764         black.assert_stable(source, output, mode=DEFAULT_MODE)
1765
1766     def test_bpo_2142_workaround(self) -> None:
1767         # https://bugs.python.org/issue2142
1768
1769         source, _ = read_data("miscellaneous", "missing_final_newline")
1770         # read_data adds a trailing newline
1771         source = source.rstrip()
1772         expected, _ = read_data("miscellaneous", "missing_final_newline.diff")
1773         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
1774         diff_header = re.compile(
1775             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
1776             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d\+\d\d:\d\d"
1777         )
1778         try:
1779             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
1780             self.assertEqual(result.exit_code, 0)
1781         finally:
1782             os.unlink(tmp_file)
1783         actual = result.output
1784         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
1785         self.assertEqual(actual, expected)
1786
1787     @staticmethod
1788     def compare_results(
1789         result: click.testing.Result, expected_value: str, expected_exit_code: int
1790     ) -> None:
1791         """Helper method to test the value and exit code of a click Result."""
1792         assert (
1793             result.output == expected_value
1794         ), "The output did not match the expected value."
1795         assert result.exit_code == expected_exit_code, "The exit code is incorrect."
1796
1797     def test_code_option(self) -> None:
1798         """Test the code option with no changes."""
1799         code = 'print("Hello world")\n'
1800         args = ["--code", code]
1801         result = CliRunner().invoke(black.main, args)
1802
1803         self.compare_results(result, code, 0)
1804
1805     def test_code_option_changed(self) -> None:
1806         """Test the code option when changes are required."""
1807         code = "print('hello world')"
1808         formatted = black.format_str(code, mode=DEFAULT_MODE)
1809
1810         args = ["--code", code]
1811         result = CliRunner().invoke(black.main, args)
1812
1813         self.compare_results(result, formatted, 0)
1814
1815     def test_code_option_check(self) -> None:
1816         """Test the code option when check is passed."""
1817         args = ["--check", "--code", 'print("Hello world")\n']
1818         result = CliRunner().invoke(black.main, args)
1819         self.compare_results(result, "", 0)
1820
1821     def test_code_option_check_changed(self) -> None:
1822         """Test the code option when changes are required, and check is passed."""
1823         args = ["--check", "--code", "print('hello world')"]
1824         result = CliRunner().invoke(black.main, args)
1825         self.compare_results(result, "", 1)
1826
1827     def test_code_option_diff(self) -> None:
1828         """Test the code option when diff is passed."""
1829         code = "print('hello world')"
1830         formatted = black.format_str(code, mode=DEFAULT_MODE)
1831         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1832
1833         args = ["--diff", "--code", code]
1834         result = CliRunner().invoke(black.main, args)
1835
1836         # Remove time from diff
1837         output = DIFF_TIME.sub("", result.output)
1838
1839         assert output == result_diff, "The output did not match the expected value."
1840         assert result.exit_code == 0, "The exit code is incorrect."
1841
1842     def test_code_option_color_diff(self) -> None:
1843         """Test the code option when color and diff are passed."""
1844         code = "print('hello world')"
1845         formatted = black.format_str(code, mode=DEFAULT_MODE)
1846
1847         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1848         result_diff = color_diff(result_diff)
1849
1850         args = ["--diff", "--color", "--code", code]
1851         result = CliRunner().invoke(black.main, args)
1852
1853         # Remove time from diff
1854         output = DIFF_TIME.sub("", result.output)
1855
1856         assert output == result_diff, "The output did not match the expected value."
1857         assert result.exit_code == 0, "The exit code is incorrect."
1858
1859     @pytest.mark.incompatible_with_mypyc
1860     def test_code_option_safe(self) -> None:
1861         """Test that the code option throws an error when the sanity checks fail."""
1862         # Patch black.assert_equivalent to ensure the sanity checks fail
1863         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1864             code = 'print("Hello world")'
1865             error_msg = f"{code}\nerror: cannot format <string>: \n"
1866
1867             args = ["--safe", "--code", code]
1868             result = CliRunner().invoke(black.main, args)
1869
1870             self.compare_results(result, error_msg, 123)
1871
1872     def test_code_option_fast(self) -> None:
1873         """Test that the code option ignores errors when the sanity checks fail."""
1874         # Patch black.assert_equivalent to ensure the sanity checks fail
1875         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1876             code = 'print("Hello world")'
1877             formatted = black.format_str(code, mode=DEFAULT_MODE)
1878
1879             args = ["--fast", "--code", code]
1880             result = CliRunner().invoke(black.main, args)
1881
1882             self.compare_results(result, formatted, 0)
1883
1884     @pytest.mark.incompatible_with_mypyc
1885     def test_code_option_config(self) -> None:
1886         """
1887         Test that the code option finds the pyproject.toml in the current directory.
1888         """
1889         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1890             args = ["--code", "print"]
1891             # This is the only directory known to contain a pyproject.toml
1892             with change_directory(PROJECT_ROOT):
1893                 CliRunner().invoke(black.main, args)
1894                 pyproject_path = Path(Path.cwd(), "pyproject.toml").resolve()
1895
1896             assert (
1897                 len(parse.mock_calls) >= 1
1898             ), "Expected config parse to be called with the current directory."
1899
1900             _, call_args, _ = parse.mock_calls[0]
1901             assert (
1902                 call_args[0].lower() == str(pyproject_path).lower()
1903             ), "Incorrect config loaded."
1904
1905     @pytest.mark.incompatible_with_mypyc
1906     def test_code_option_parent_config(self) -> None:
1907         """
1908         Test that the code option finds the pyproject.toml in the parent directory.
1909         """
1910         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1911             with change_directory(THIS_DIR):
1912                 args = ["--code", "print"]
1913                 CliRunner().invoke(black.main, args)
1914
1915                 pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
1916                 assert (
1917                     len(parse.mock_calls) >= 1
1918                 ), "Expected config parse to be called with the current directory."
1919
1920                 _, call_args, _ = parse.mock_calls[0]
1921                 assert (
1922                     call_args[0].lower() == str(pyproject_path).lower()
1923                 ), "Incorrect config loaded."
1924
1925     def test_for_handled_unexpected_eof_error(self) -> None:
1926         """
1927         Test that an unexpected EOF SyntaxError is nicely presented.
1928         """
1929         with pytest.raises(black.parsing.InvalidInput) as exc_info:
1930             black.lib2to3_parse("print(", {})
1931
1932         exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")
1933
1934     def test_equivalency_ast_parse_failure_includes_error(self) -> None:
1935         with pytest.raises(AssertionError) as err:
1936             black.assert_equivalent("a«»a  = 1", "a«»a  = 1")
1937
1938         err.match("--safe")
1939         # Unfortunately the SyntaxError message has changed in newer versions so we
1940         # can't match it directly.
1941         err.match("invalid character")
1942         err.match(r"\(<unknown>, line 1\)")
1943
1944
1945 class TestCaching:
1946     def test_get_cache_dir(
1947         self,
1948         tmp_path: Path,
1949         monkeypatch: pytest.MonkeyPatch,
1950     ) -> None:
1951         # Create multiple cache directories
1952         workspace1 = tmp_path / "ws1"
1953         workspace1.mkdir()
1954         workspace2 = tmp_path / "ws2"
1955         workspace2.mkdir()
1956
1957         # Force user_cache_dir to use the temporary directory for easier assertions
1958         patch_user_cache_dir = patch(
1959             target="black.cache.user_cache_dir",
1960             autospec=True,
1961             return_value=str(workspace1),
1962         )
1963
1964         # If BLACK_CACHE_DIR is not set, use user_cache_dir
1965         monkeypatch.delenv("BLACK_CACHE_DIR", raising=False)
1966         with patch_user_cache_dir:
1967             assert get_cache_dir() == workspace1
1968
1969         # If it is set, use the path provided in the env var.
1970         monkeypatch.setenv("BLACK_CACHE_DIR", str(workspace2))
1971         assert get_cache_dir() == workspace2
1972
1973     def test_cache_broken_file(self) -> None:
1974         mode = DEFAULT_MODE
1975         with cache_dir() as workspace:
1976             cache_file = get_cache_file(mode)
1977             cache_file.write_text("this is not a pickle", encoding="utf-8")
1978             assert black.read_cache(mode) == {}
1979             src = (workspace / "test.py").resolve()
1980             src.write_text("print('hello')", encoding="utf-8")
1981             invokeBlack([str(src)])
1982             cache = black.read_cache(mode)
1983             assert str(src) in cache
1984
1985     def test_cache_single_file_already_cached(self) -> None:
1986         mode = DEFAULT_MODE
1987         with cache_dir() as workspace:
1988             src = (workspace / "test.py").resolve()
1989             src.write_text("print('hello')", encoding="utf-8")
1990             black.write_cache({}, [src], mode)
1991             invokeBlack([str(src)])
1992             assert src.read_text(encoding="utf-8") == "print('hello')"
1993
1994     @event_loop()
1995     def test_cache_multiple_files(self) -> None:
1996         mode = DEFAULT_MODE
1997         with cache_dir() as workspace, patch(
1998             "concurrent.futures.ProcessPoolExecutor", new=ThreadPoolExecutor
1999         ):
2000             one = (workspace / "one.py").resolve()
2001             one.write_text("print('hello')", encoding="utf-8")
2002             two = (workspace / "two.py").resolve()
2003             two.write_text("print('hello')", encoding="utf-8")
2004             black.write_cache({}, [one], mode)
2005             invokeBlack([str(workspace)])
2006             assert one.read_text(encoding="utf-8") == "print('hello')"
2007             assert two.read_text(encoding="utf-8") == 'print("hello")\n'
2008             cache = black.read_cache(mode)
2009             assert str(one) in cache
2010             assert str(two) in cache
2011
2012     @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
2013     def test_no_cache_when_writeback_diff(self, color: bool) -> None:
2014         mode = DEFAULT_MODE
2015         with cache_dir() as workspace:
2016             src = (workspace / "test.py").resolve()
2017             src.write_text("print('hello')", encoding="utf-8")
2018             with patch("black.read_cache") as read_cache, patch(
2019                 "black.write_cache"
2020             ) as write_cache:
2021                 cmd = [str(src), "--diff"]
2022                 if color:
2023                     cmd.append("--color")
2024                 invokeBlack(cmd)
2025                 cache_file = get_cache_file(mode)
2026                 assert cache_file.exists() is False
2027                 write_cache.assert_not_called()
2028                 read_cache.assert_not_called()
2029
2030     @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
2031     @event_loop()
2032     def test_output_locking_when_writeback_diff(self, color: bool) -> None:
2033         with cache_dir() as workspace:
2034             for tag in range(0, 4):
2035                 src = (workspace / f"test{tag}.py").resolve()
2036                 src.write_text("print('hello')", encoding="utf-8")
2037             with patch(
2038                 "black.concurrency.Manager", wraps=multiprocessing.Manager
2039             ) as mgr:
2040                 cmd = ["--diff", str(workspace)]
2041                 if color:
2042                     cmd.append("--color")
2043                 invokeBlack(cmd, exit_code=0)
2044                 # this isn't quite doing what we want, but if it _isn't_
2045                 # called then we cannot be using the lock it provides
2046                 mgr.assert_called()
2047
2048     def test_no_cache_when_stdin(self) -> None:
2049         mode = DEFAULT_MODE
2050         with cache_dir():
2051             result = CliRunner().invoke(
2052                 black.main, ["-"], input=BytesIO(b"print('hello')")
2053             )
2054             assert not result.exit_code
2055             cache_file = get_cache_file(mode)
2056             assert not cache_file.exists()
2057
2058     def test_read_cache_no_cachefile(self) -> None:
2059         mode = DEFAULT_MODE
2060         with cache_dir():
2061             assert black.read_cache(mode) == {}
2062
2063     def test_write_cache_read_cache(self) -> None:
2064         mode = DEFAULT_MODE
2065         with cache_dir() as workspace:
2066             src = (workspace / "test.py").resolve()
2067             src.touch()
2068             black.write_cache({}, [src], mode)
2069             cache = black.read_cache(mode)
2070             assert str(src) in cache
2071             assert cache[str(src)] == black.get_cache_info(src)
2072
2073     def test_filter_cached(self) -> None:
2074         with TemporaryDirectory() as workspace:
2075             path = Path(workspace)
2076             uncached = (path / "uncached").resolve()
2077             cached = (path / "cached").resolve()
2078             cached_but_changed = (path / "changed").resolve()
2079             uncached.touch()
2080             cached.touch()
2081             cached_but_changed.touch()
2082             cache = {
2083                 str(cached): black.get_cache_info(cached),
2084                 str(cached_but_changed): (0.0, 0),
2085             }
2086             todo, done = black.cache.filter_cached(
2087                 cache, {uncached, cached, cached_but_changed}
2088             )
2089             assert todo == {uncached, cached_but_changed}
2090             assert done == {cached}
2091
2092     def test_write_cache_creates_directory_if_needed(self) -> None:
2093         mode = DEFAULT_MODE
2094         with cache_dir(exists=False) as workspace:
2095             assert not workspace.exists()
2096             black.write_cache({}, [], mode)
2097             assert workspace.exists()
2098
2099     @event_loop()
2100     def test_failed_formatting_does_not_get_cached(self) -> None:
2101         mode = DEFAULT_MODE
2102         with cache_dir() as workspace, patch(
2103             "concurrent.futures.ProcessPoolExecutor", new=ThreadPoolExecutor
2104         ):
2105             failing = (workspace / "failing.py").resolve()
2106             failing.write_text("not actually python", encoding="utf-8")
2107             clean = (workspace / "clean.py").resolve()
2108             clean.write_text('print("hello")\n', encoding="utf-8")
2109             invokeBlack([str(workspace)], exit_code=123)
2110             cache = black.read_cache(mode)
2111             assert str(failing) not in cache
2112             assert str(clean) in cache
2113
2114     def test_write_cache_write_fail(self) -> None:
2115         mode = DEFAULT_MODE
2116         with cache_dir(), patch.object(Path, "open") as mock:
2117             mock.side_effect = OSError
2118             black.write_cache({}, [], mode)
2119
2120     def test_read_cache_line_lengths(self) -> None:
2121         mode = DEFAULT_MODE
2122         short_mode = replace(DEFAULT_MODE, line_length=1)
2123         with cache_dir() as workspace:
2124             path = (workspace / "file.py").resolve()
2125             path.touch()
2126             black.write_cache({}, [path], mode)
2127             one = black.read_cache(mode)
2128             assert str(path) in one
2129             two = black.read_cache(short_mode)
2130             assert str(path) not in two
2131
2132
def assert_collected_sources(
    src: Sequence[Union[str, Path]],
    expected: Sequence[Union[str, Path]],
    *,
    ctx: Optional[FakeContext] = None,
    exclude: Optional[str] = None,
    include: Optional[str] = None,
    extend_exclude: Optional[str] = None,
    force_exclude: Optional[str] = None,
    stdin_filename: Optional[str] = None,
) -> None:
    """Assert that `black.get_sources` collects exactly `expected` from `src`.

    `exclude`, `extend_exclude` and `force_exclude` are compiled when given
    and passed as None otherwise. `include` falls back to `DEFAULT_INCLUDE`.
    A fresh `FakeContext` is used when no `ctx` is supplied. The comparison
    ignores ordering.
    """

    def _optional_pattern(pattern: Optional[str]) -> "Optional[re.Pattern[str]]":
        # None means "option not given", which get_sources also receives as None.
        return None if pattern is None else compile_pattern(pattern)

    collected = black.get_sources(
        ctx=ctx or FakeContext(),
        src=tuple(str(Path(entry)) for entry in src),
        quiet=False,
        verbose=False,
        include=DEFAULT_INCLUDE if include is None else compile_pattern(include),
        exclude=_optional_pattern(exclude),
        extend_exclude=_optional_pattern(extend_exclude),
        force_exclude=_optional_pattern(force_exclude),
        report=black.Report(),
        stdin_filename=stdin_filename,
    )
    assert sorted(collected) == sorted(Path(entry) for entry in expected)
2165
2166
class TestFileCollection:
    """Source discovery: include/exclude patterns, .gitignore handling,
    symlinks and stdin inputs."""

    def test_include_exclude(self) -> None:
        """`include` selects file names; `exclude` prunes matching paths."""
        path = THIS_DIR / "data" / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        assert_collected_sources(
            src,
            expected,
            include=r"\.pyi?$",
            exclude=r"/exclude/|/\.definitely_exclude/",
        )

    def test_gitignore_used_as_default(self) -> None:
        """Without an explicit `exclude`, the project's .gitignore applies."""
        base = DATA_DIR / "include_exclude_tests"
        expected = [
            base / "b/.definitely_exclude/a.py",
            base / "b/.definitely_exclude/a.pyi",
        ]
        src = [base / "b/"]
        ctx = FakeContext()
        ctx.obj["root"] = base
        assert_collected_sources(src, expected, ctx=ctx, extend_exclude=r"/exclude/")

    def test_gitignore_used_on_multiple_sources(self) -> None:
        """The root .gitignore is applied to every source directory given."""
        root = DATA_DIR / "gitignore_used_on_multiple_sources"
        expected = [
            root / "dir1" / "b.py",
            root / "dir2" / "b.py",
        ]
        ctx = FakeContext()
        ctx.obj["root"] = root
        src = [root / "dir1", root / "dir2"]
        assert_collected_sources(src, expected, ctx=ctx)

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_exclude_for_issue_1572(self) -> None:
        # Exclude shouldn't touch files that were explicitly given to Black through the
        # CLI. Exclude is supposed to only apply to the recursive discovery of files.
        # https://github.com/psf/black/issues/1572
        path = DATA_DIR / "include_exclude_tests"
        src = [path / "b/exclude/a.py"]
        expected = [path / "b/exclude/a.py"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    def test_gitignore_exclude(self) -> None:
        """`gen_python_files` skips paths matched by the supplied gitignore spec."""
        path = THIS_DIR / "data" / "include_exclude_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        report = black.Report()
        gitignore = PathSpec.from_lines(
            "gitwildmatch", ["exclude/", ".definitely_exclude"]
        )
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        this_abs = THIS_DIR.resolve()
        sources: List[Path] = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                {path: gitignore},
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_nested_gitignore(self) -> None:
        """A .gitignore inside a subdirectory is honoured with the root one."""
        path = THIS_DIR / "data" / "nested_gitignore_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        root_gitignore = black.files.get_gitignore(path)
        report = black.Report()
        expected: List[Path] = [
            path / "x.py",
            path / "root/b.py",
            path / "root/c.py",
            path / "root/child/c.py",
        ]
        this_abs = THIS_DIR.resolve()
        sources = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                {path: root_gitignore},
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_nested_gitignore_directly_in_source_directory(self) -> None:
        # https://github.com/psf/black/issues/2598
        path = DATA_DIR / "nested_gitignore_tests"
        src = path / "root" / "child"
        expected = [src / "a.py", src / "c.py"]
        assert_collected_sources([src], expected)

    def test_invalid_gitignore(self) -> None:
        """An unparsable root .gitignore aborts with exit code 1 and a message."""
        path = THIS_DIR / "data" / "invalid_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_invalid_nested_gitignore(self) -> None:
        """An unparsable nested .gitignore is reported the same way."""
        path = THIS_DIR / "data" / "invalid_nested_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / "a" / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_gitignore_that_ignores_subfolders(self) -> None:
        """A `*/*` .gitignore pattern, checked from three vantage points."""
        # If gitignore with */* is in root
        root = DATA_DIR / "ignore_subfolders_gitignore_tests" / "subdir"
        expected = [root / "b.py"]
        ctx = FakeContext()
        ctx.obj["root"] = root
        assert_collected_sources([root], expected, ctx=ctx)

        # If .gitignore with */* is nested
        root = DATA_DIR / "ignore_subfolders_gitignore_tests"
        expected = [
            root / "a.py",
            root / "subdir" / "b.py",
        ]
        ctx = FakeContext()
        ctx.obj["root"] = root
        assert_collected_sources([root], expected, ctx=ctx)

        # If command is executed from outer dir
        root = DATA_DIR / "ignore_subfolders_gitignore_tests"
        target = root / "subdir"
        expected = [target / "b.py"]
        ctx = FakeContext()
        ctx.obj["root"] = root
        assert_collected_sources([target], expected, ctx=ctx)

    def test_empty_include(self) -> None:
        """An empty `include` pattern collects every file, not just Python."""
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.pie",
            path / "b/exclude/a.py",
            path / "b/exclude/a.pyi",
            path / "b/dont_exclude/a.pie",
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
            path / "b/.definitely_exclude/a.pie",
            path / "b/.definitely_exclude/a.py",
            path / "b/.definitely_exclude/a.pyi",
            path / ".gitignore",
            path / "pyproject.toml",
        ]
        # Setting exclude explicitly to an empty string to block .gitignore usage.
        assert_collected_sources(src, expected, include="", exclude="")

    def test_extend_exclude(self) -> None:
        """`extend_exclude` is applied on top of `exclude`."""
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.py",
            path / "b/dont_exclude/a.py",
        ]
        assert_collected_sources(
            src, expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude"
        )

    @pytest.mark.incompatible_with_mypyc
    def test_symlink_out_of_root_directory(self) -> None:
        """A child resolving outside the root must not make discovery raise."""
        path = MagicMock()
        root = THIS_DIR.resolve()
        child = MagicMock()
        include = re.compile(black.DEFAULT_INCLUDES)
        exclude = re.compile(black.DEFAULT_EXCLUDES)
        report = black.Report()
        gitignore = PathSpec.from_lines("gitwildmatch", [])
        # `child` should behave like a symlink whose resolved path is clearly
        # outside of the `root` directory.
        path.iterdir.return_value = [child]
        child.resolve.return_value = Path("/a/b/c")
        child.as_posix.return_value = "/a/b/c"
        try:
            list(
                black.gen_python_files(
                    path.iterdir(),
                    root,
                    include,
                    exclude,
                    None,
                    None,
                    report,
                    {path: gitignore},
                    verbose=False,
                    quiet=False,
                )
            )
        except ValueError as ve:
            # Name the function actually under test in the failure message.
            pytest.fail(f"`gen_python_files()` failed: {ve}")
        path.iterdir.assert_called_once()
        child.resolve.assert_called_once()

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin(self) -> None:
        """`-` is passed through as-is, untouched by include/exclude."""
        src = ["-"]
        expected = ["-"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename(self) -> None:
        """With `stdin_filename`, `-` is reported under a prefixed file name."""
        src = ["-"]
        stdin_filename = str(THIS_DIR / "data/collections.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
        # Exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        path = DATA_DIR / "include_exclude_tests"
        src = ["-"]
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
        # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        src = ["-"]
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            extend_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
        # Force exclude should exclude the file when passing it through
        # stdin_filename
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        assert_collected_sources(
            src=["-"],
            expected=[],
            force_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )
2456
2457
# Source lines of black's main module, read so ``tracefunc`` can print the
# signature line of each traced call. Bound up front so the name always exists:
# if decoding fails on a compiled build (``black.COMPILED``), the except branch
# below swallows the error and would otherwise leave the name undefined.
black_source_lines: List[str] = []
try:
    with open(black.__file__, "r", encoding="utf-8") as _bf:
        black_source_lines = _bf.readlines()
except UnicodeDecodeError:
    # Tolerated only for compiled builds, where ``black.__file__`` presumably
    # points at a binary extension module rather than Python source.
    if not black.COMPILED:
        raise
2464
2465
def tracefunc(
    frame: types.FrameType, event: str, arg: Any
) -> Callable[[types.FrameType, str, Any], Any]:
    """Show function calls from `black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging. For each
    "call" event whose code lives in `black/__init__.py`, print the line number
    and the stripped source line of the function signature (decorator lines
    are skipped), indented according to the current stack depth.

    Args:
        frame: the frame being entered, as supplied by `sys.settrace`.
        event: the trace event name; only "call" is acted upon.
        arg: event-specific argument (unused).

    Returns:
        This function itself, so tracing continues in nested scopes.
    """
    if event != "call":
        return tracefunc

    filename = frame.f_code.co_filename
    # ``black_source_lines`` holds black/__init__.py only, so line numbers from
    # any other file are meaningless there (and may be out of range, which
    # would raise inside the tracer). Check the file first; normalising the
    # separator makes the match work on Windows paths too.
    if "black/__init__.py" not in filename.replace("\\", "/"):
        return tracefunc

    # NOTE(review): 19 looks like a hand-tuned count of harness frames to
    # discount; a negative result simply yields no indentation.
    stack = len(inspect.stack()) - 19
    stack *= 2
    lineno = frame.f_lineno
    lines = black_source_lines
    func_sig_lineno = lineno - 1  # f_lineno is 1-based; the list is 0-based.
    # Skip decorator lines to reach the ``def`` line itself.
    while (
        0 <= func_sig_lineno < len(lines)
        and lines[func_sig_lineno].strip().startswith("@")
    ):
        func_sig_lineno += 1
    if 0 <= func_sig_lineno < len(lines):
        funcname = lines[func_sig_lineno].strip()
    else:
        # Source unavailable (e.g. compiled build) or out of sync with the
        # running code: fall back to the code object's name instead of raising.
        funcname = frame.f_code.co_name
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc