git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

c665eee3a6c20644e72a9728309c4c0f95e2b24d
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import inspect
5 import io
6 import logging
7 import multiprocessing
8 import os
9 import re
10 import sys
11 import types
12 from concurrent.futures import ThreadPoolExecutor
13 from contextlib import contextmanager, redirect_stderr
14 from dataclasses import replace
15 from io import BytesIO
16 from pathlib import Path
17 from platform import system
18 from tempfile import TemporaryDirectory
19 from typing import (
20     Any,
21     Callable,
22     Dict,
23     Iterator,
24     List,
25     Optional,
26     Sequence,
27     Type,
28     TypeVar,
29     Union,
30 )
31 from unittest.mock import MagicMock, patch
32
33 import click
34 import pytest
35 from click import unstyle
36 from click.testing import CliRunner
37 from pathspec import PathSpec
38
39 import black
40 import black.files
41 from black import Feature, TargetVersion
42 from black import re_compile_maybe_verbose as compile_pattern
43 from black.cache import FileData, get_cache_dir, get_cache_file
44 from black.debug import DebugVisitor
45 from black.output import color_diff, diff
46 from black.report import Report
47
48 # Import other test classes
49 from tests.util import (
50     DATA_DIR,
51     DEFAULT_MODE,
52     DETERMINISTIC_HEADER,
53     PROJECT_ROOT,
54     PY36_VERSIONS,
55     THIS_DIR,
56     BlackBaseTestCase,
57     assert_format,
58     change_directory,
59     dump_to_stderr,
60     ff,
61     fs,
62     get_case_path,
63     read_data,
64     read_data_from_file,
65 )
66
# Paths and argument lists shared by many tests in this module.
THIS_FILE = Path(__file__)
EMPTY_CONFIG = THIS_DIR.joinpath("data", "empty_pyproject.toml")
# One ``--target-version=<name>`` CLI flag per entry of PY36_VERSIONS.
PY36_ARGS = ["--target-version=" + v.name.lower() for v in PY36_VERSIONS]
# Black's default include/exclude patterns, compiled the same way Black does.
DEFAULT_EXCLUDE = compile_pattern(black.const.DEFAULT_EXCLUDES)
DEFAULT_INCLUDE = compile_pattern(black.const.DEFAULT_INCLUDES)
T = TypeVar("T")
R = TypeVar("R")

# Match the time output in a diff, but nothing else
DIFF_TIME = re.compile(r"\t[\d\-:+\. ]+")
77
78
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Redirect ``black.cache.CACHE_DIR`` to a temporary location.

    The yielded path is the temporary workspace itself when ``exists`` is
    true; otherwise it is a ``new`` subdirectory of it that has not been
    created. The workspace is removed when the block exits.
    """
    with TemporaryDirectory() as workspace:
        base = Path(workspace)
        target = base if exists else base / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
87
88
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a brand-new asyncio event loop as the current loop for the block.

    On exit the loop is closed and then unregistered. Without that last step,
    later code in the same thread calling ``asyncio.get_event_loop()`` would be
    handed the already-closed loop.
    """
    policy = asyncio.get_event_loop_policy()
    loop = policy.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        yield
    finally:
        loop.close()
        # Don't leave a closed loop registered as the thread's current loop.
        asyncio.set_event_loop(None)
99
100
class FakeContext(click.Context):
    """Stand-in ``click.Context`` for tests that call context-taking helpers.

    ``click.Context.__init__`` is intentionally not invoked; only a handful of
    attributes are populated here.
    """

    def __init__(self) -> None:
        # Placeholder root: most tests never look at it.
        self.obj: Dict[str, Any] = {"root": PROJECT_ROOT}
        self.params: Dict[str, Any] = {}
        self.default_map: Dict[str, Any] = {}
109
110
class FakeParameter(click.Parameter):
    """Stand-in ``click.Parameter`` for tests that call parameter-taking helpers."""

    def __init__(self) -> None:
        """Deliberately skip ``click.Parameter.__init__``; no state is set up."""
116
117
class BlackRunner(CliRunner):
    """Make sure STDOUT and STDERR are kept separate when testing Black via its CLI."""

    def __init__(self) -> None:
        # Older Click merges stderr into stdout unless ``mix_stderr=False`` is
        # given. Click 8.2 removed that parameter (passing it raises TypeError)
        # and always exposes stderr separately, so only pass it when supported.
        # NOTE(review): on Click >= 8.2, ``Result.output`` is documented to also
        # interleave stderr; tests reading ``result.output`` may need
        # ``result.stdout`` there -- confirm against the pinned Click version.
        if "mix_stderr" in inspect.signature(CliRunner.__init__).parameters:
            super().__init__(mix_stderr=False)
        else:
            super().__init__()
123
124
def invokeBlack(
    args: List[str], exit_code: int = 0, ignore_config: bool = True
) -> None:
    """Run Black's CLI in-process and assert that it exits with ``exit_code``.

    When ``ignore_config`` is true, ``--verbose`` plus ``--config`` pointing at
    ``THIS_DIR / "empty.toml"`` are prepended to ``args``. The runner is called
    with ``catch_exceptions=False``. On a mismatching exit code, the assertion
    message lists the arguments, decoded stdout/stderr and any recorded exception.
    """
    if ignore_config:
        empty_config = str(THIS_DIR / "empty.toml")
        args = ["--verbose", "--config", empty_config, *args]
    result = BlackRunner().invoke(black.main, args, catch_exceptions=False)
    captured_out = result.stdout_bytes
    captured_err = result.stderr_bytes
    assert captured_out is not None
    assert captured_err is not None
    details = "\n".join(
        [
            f"Failed with args: {args}",
            f"stdout: {captured_out.decode()!r}",
            f"stderr: {captured_err.decode()!r}",
            f"exception: {result.exception}",
        ]
    )
    assert result.exit_code == exit_code, details
141
142
143 class BlackTestCase(BlackBaseTestCase):
144     invokeBlack = staticmethod(invokeBlack)
145
146     def test_empty_ff(self) -> None:
147         expected = ""
148         tmp_file = Path(black.dump_to_file())
149         try:
150             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
151             actual = tmp_file.read_text(encoding="utf-8")
152         finally:
153             os.unlink(tmp_file)
154         self.assertFormatEqual(expected, actual)
155
156     @patch("black.dump_to_file", dump_to_stderr)
157     def test_one_empty_line(self) -> None:
158         mode = black.Mode(preview=True)
159         for nl in ["\n", "\r\n"]:
160             source = expected = nl
161             assert_format(source, expected, mode=mode)
162
163     def test_one_empty_line_ff(self) -> None:
164         mode = black.Mode(preview=True)
165         for nl in ["\n", "\r\n"]:
166             expected = nl
167             tmp_file = Path(black.dump_to_file(nl))
168             if system() == "Windows":
169                 # Writing files in text mode automatically uses the system newline,
170                 # but in this case we don't want this for testing reasons. See:
171                 # https://github.com/psf/black/pull/3348
172                 with open(tmp_file, "wb") as f:
173                     f.write(nl.encode("utf-8"))
174             try:
175                 self.assertFalse(
176                     ff(tmp_file, mode=mode, write_back=black.WriteBack.YES)
177                 )
178                 with open(tmp_file, "rb") as f:
179                     actual = f.read().decode("utf-8")
180             finally:
181                 os.unlink(tmp_file)
182             self.assertFormatEqual(expected, actual)
183
184     def test_experimental_string_processing_warns(self) -> None:
185         self.assertWarns(
186             black.mode.Deprecated, black.Mode, experimental_string_processing=True
187         )
188
189     def test_piping(self) -> None:
190         source, expected = read_data_from_file(PROJECT_ROOT / "src/black/__init__.py")
191         result = BlackRunner().invoke(
192             black.main,
193             [
194                 "-",
195                 "--fast",
196                 f"--line-length={black.DEFAULT_LINE_LENGTH}",
197                 f"--config={EMPTY_CONFIG}",
198             ],
199             input=BytesIO(source.encode("utf-8")),
200         )
201         self.assertEqual(result.exit_code, 0)
202         self.assertFormatEqual(expected, result.output)
203         if source != result.output:
204             black.assert_equivalent(source, result.output)
205             black.assert_stable(source, result.output, DEFAULT_MODE)
206
207     def test_piping_diff(self) -> None:
208         diff_header = re.compile(
209             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d"
210             r"\+\d\d:\d\d"
211         )
212         source, _ = read_data("simple_cases", "expression.py")
213         expected, _ = read_data("simple_cases", "expression.diff")
214         args = [
215             "-",
216             "--fast",
217             f"--line-length={black.DEFAULT_LINE_LENGTH}",
218             "--diff",
219             f"--config={EMPTY_CONFIG}",
220         ]
221         result = BlackRunner().invoke(
222             black.main, args, input=BytesIO(source.encode("utf-8"))
223         )
224         self.assertEqual(result.exit_code, 0)
225         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
226         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
227         self.assertEqual(expected, actual)
228
229     def test_piping_diff_with_color(self) -> None:
230         source, _ = read_data("simple_cases", "expression.py")
231         args = [
232             "-",
233             "--fast",
234             f"--line-length={black.DEFAULT_LINE_LENGTH}",
235             "--diff",
236             "--color",
237             f"--config={EMPTY_CONFIG}",
238         ]
239         result = BlackRunner().invoke(
240             black.main, args, input=BytesIO(source.encode("utf-8"))
241         )
242         actual = result.output
243         # Again, the contents are checked in a different test, so only look for colors.
244         self.assertIn("\033[1m", actual)
245         self.assertIn("\033[36m", actual)
246         self.assertIn("\033[32m", actual)
247         self.assertIn("\033[31m", actual)
248         self.assertIn("\033[0m", actual)
249
250     @patch("black.dump_to_file", dump_to_stderr)
251     def _test_wip(self) -> None:
252         source, expected = read_data("miscellaneous", "wip")
253         sys.settrace(tracefunc)
254         mode = replace(
255             DEFAULT_MODE,
256             experimental_string_processing=False,
257             target_versions={black.TargetVersion.PY38},
258         )
259         actual = fs(source, mode=mode)
260         sys.settrace(None)
261         self.assertFormatEqual(expected, actual)
262         black.assert_equivalent(source, actual)
263         black.assert_stable(source, actual, black.FileMode())
264
265     def test_pep_572_version_detection(self) -> None:
266         source, _ = read_data("py_38", "pep_572")
267         root = black.lib2to3_parse(source)
268         features = black.get_features_used(root)
269         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
270         versions = black.detect_target_versions(root)
271         self.assertIn(black.TargetVersion.PY38, versions)
272
273     def test_pep_695_version_detection(self) -> None:
274         for file in ("type_aliases", "type_params"):
275             source, _ = read_data("py_312", file)
276             root = black.lib2to3_parse(source)
277             features = black.get_features_used(root)
278             self.assertIn(black.Feature.TYPE_PARAMS, features)
279             versions = black.detect_target_versions(root)
280             self.assertIn(black.TargetVersion.PY312, versions)
281
282     def test_expression_ff(self) -> None:
283         source, expected = read_data("simple_cases", "expression.py")
284         tmp_file = Path(black.dump_to_file(source))
285         try:
286             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
287             actual = tmp_file.read_text(encoding="utf-8")
288         finally:
289             os.unlink(tmp_file)
290         self.assertFormatEqual(expected, actual)
291         with patch("black.dump_to_file", dump_to_stderr):
292             black.assert_equivalent(source, actual)
293             black.assert_stable(source, actual, DEFAULT_MODE)
294
295     def test_expression_diff(self) -> None:
296         source, _ = read_data("simple_cases", "expression.py")
297         expected, _ = read_data("simple_cases", "expression.diff")
298         tmp_file = Path(black.dump_to_file(source))
299         diff_header = re.compile(
300             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
301             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d\+\d\d:\d\d"
302         )
303         try:
304             result = BlackRunner().invoke(
305                 black.main, ["--diff", str(tmp_file), f"--config={EMPTY_CONFIG}"]
306             )
307             self.assertEqual(result.exit_code, 0)
308         finally:
309             os.unlink(tmp_file)
310         actual = result.output
311         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
312         if expected != actual:
313             dump = black.dump_to_file(actual)
314             msg = (
315                 "Expected diff isn't equal to the actual. If you made changes to"
316                 " expression.py and this is an anticipated difference, overwrite"
317                 f" tests/data/expression.diff with {dump}"
318             )
319             self.assertEqual(expected, actual, msg)
320
321     def test_expression_diff_with_color(self) -> None:
322         source, _ = read_data("simple_cases", "expression.py")
323         expected, _ = read_data("simple_cases", "expression.diff")
324         tmp_file = Path(black.dump_to_file(source))
325         try:
326             result = BlackRunner().invoke(
327                 black.main,
328                 ["--diff", "--color", str(tmp_file), f"--config={EMPTY_CONFIG}"],
329             )
330         finally:
331             os.unlink(tmp_file)
332         actual = result.output
333         # We check the contents of the diff in `test_expression_diff`. All
334         # we need to check here is that color codes exist in the result.
335         self.assertIn("\033[1m", actual)
336         self.assertIn("\033[36m", actual)
337         self.assertIn("\033[32m", actual)
338         self.assertIn("\033[31m", actual)
339         self.assertIn("\033[0m", actual)
340
341     def test_detect_pos_only_arguments(self) -> None:
342         source, _ = read_data("py_38", "pep_570")
343         root = black.lib2to3_parse(source)
344         features = black.get_features_used(root)
345         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
346         versions = black.detect_target_versions(root)
347         self.assertIn(black.TargetVersion.PY38, versions)
348
349     def test_detect_debug_f_strings(self) -> None:
350         root = black.lib2to3_parse("""f"{x=}" """)
351         features = black.get_features_used(root)
352         self.assertIn(black.Feature.DEBUG_F_STRINGS, features)
353         versions = black.detect_target_versions(root)
354         self.assertIn(black.TargetVersion.PY38, versions)
355
356         root = black.lib2to3_parse(
357             """f"{x}"\nf'{"="}'\nf'{(x:=5)}'\nf'{f(a="3=")}'\nf'{x:=10}'\n"""
358         )
359         features = black.get_features_used(root)
360         self.assertNotIn(black.Feature.DEBUG_F_STRINGS, features)
361
362         # We don't yet support feature version detection in nested f-strings
363         root = black.lib2to3_parse(
364             """f"heard a rumour that { f'{1+1=}' } ... seems like it could be true" """
365         )
366         features = black.get_features_used(root)
367         self.assertNotIn(black.Feature.DEBUG_F_STRINGS, features)
368
369     @patch("black.dump_to_file", dump_to_stderr)
370     def test_string_quotes(self) -> None:
371         source, expected = read_data("miscellaneous", "string_quotes")
372         mode = black.Mode(preview=True)
373         assert_format(source, expected, mode)
374         mode = replace(mode, string_normalization=False)
375         not_normalized = fs(source, mode=mode)
376         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
377         black.assert_equivalent(source, not_normalized)
378         black.assert_stable(source, not_normalized, mode=mode)
379
380     def test_skip_source_first_line(self) -> None:
381         source, _ = read_data("miscellaneous", "invalid_header")
382         tmp_file = Path(black.dump_to_file(source))
383         # Full source should fail (invalid syntax at header)
384         self.invokeBlack([str(tmp_file), "--diff", "--check"], exit_code=123)
385         # So, skipping the first line should work
386         result = BlackRunner().invoke(
387             black.main, [str(tmp_file), "-x", f"--config={EMPTY_CONFIG}"]
388         )
389         self.assertEqual(result.exit_code, 0)
390         actual = tmp_file.read_text(encoding="utf-8")
391         self.assertFormatEqual(source, actual)
392
393     def test_skip_source_first_line_when_mixing_newlines(self) -> None:
394         code_mixing_newlines = b"Header will be skipped\r\ni = [1,2,3]\nj = [1,2,3]\n"
395         expected = b"Header will be skipped\r\ni = [1, 2, 3]\nj = [1, 2, 3]\n"
396         with TemporaryDirectory() as workspace:
397             test_file = Path(workspace) / "skip_header.py"
398             test_file.write_bytes(code_mixing_newlines)
399             mode = replace(DEFAULT_MODE, skip_source_first_line=True)
400             ff(test_file, mode=mode, write_back=black.WriteBack.YES)
401             self.assertEqual(test_file.read_bytes(), expected)
402
403     def test_skip_magic_trailing_comma(self) -> None:
404         source, _ = read_data("simple_cases", "expression")
405         expected, _ = read_data(
406             "miscellaneous", "expression_skip_magic_trailing_comma.diff"
407         )
408         tmp_file = Path(black.dump_to_file(source))
409         diff_header = re.compile(
410             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
411             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d\+\d\d:\d\d"
412         )
413         try:
414             result = BlackRunner().invoke(
415                 black.main, ["-C", "--diff", str(tmp_file), f"--config={EMPTY_CONFIG}"]
416             )
417             self.assertEqual(result.exit_code, 0)
418         finally:
419             os.unlink(tmp_file)
420         actual = result.output
421         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
422         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
423         if expected != actual:
424             dump = black.dump_to_file(actual)
425             msg = (
426                 "Expected diff isn't equal to the actual. If you made changes to"
427                 " expression.py and this is an anticipated difference, overwrite"
428                 " tests/data/miscellaneous/expression_skip_magic_trailing_comma.diff"
429                 f" with {dump}"
430             )
431             self.assertEqual(expected, actual, msg)
432
433     @patch("black.dump_to_file", dump_to_stderr)
434     def test_async_as_identifier(self) -> None:
435         source_path = get_case_path("miscellaneous", "async_as_identifier")
436         source, expected = read_data_from_file(source_path)
437         actual = fs(source)
438         self.assertFormatEqual(expected, actual)
439         major, minor = sys.version_info[:2]
440         if major < 3 or (major <= 3 and minor < 7):
441             black.assert_equivalent(source, actual)
442         black.assert_stable(source, actual, DEFAULT_MODE)
443         # ensure black can parse this when the target is 3.6
444         self.invokeBlack([str(source_path), "--target-version", "py36"])
445         # but not on 3.7, because async/await is no longer an identifier
446         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
447
448     @patch("black.dump_to_file", dump_to_stderr)
449     def test_python37(self) -> None:
450         source_path = get_case_path("py_37", "python37")
451         source, expected = read_data_from_file(source_path)
452         actual = fs(source)
453         self.assertFormatEqual(expected, actual)
454         major, minor = sys.version_info[:2]
455         if major > 3 or (major == 3 and minor >= 7):
456             black.assert_equivalent(source, actual)
457         black.assert_stable(source, actual, DEFAULT_MODE)
458         # ensure black can parse this when the target is 3.7
459         self.invokeBlack([str(source_path), "--target-version", "py37"])
460         # but not on 3.6, because we use async as a reserved keyword
461         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
462
463     def test_tab_comment_indentation(self) -> None:
464         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
465         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
466         self.assertFormatEqual(contents_spc, fs(contents_spc))
467         self.assertFormatEqual(contents_spc, fs(contents_tab))
468
469         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
470         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
471         self.assertFormatEqual(contents_spc, fs(contents_spc))
472         self.assertFormatEqual(contents_spc, fs(contents_tab))
473
474         # mixed tabs and spaces (valid Python 2 code)
475         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
476         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
477         self.assertFormatEqual(contents_spc, fs(contents_spc))
478         self.assertFormatEqual(contents_spc, fs(contents_tab))
479
480         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
481         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
482         self.assertFormatEqual(contents_spc, fs(contents_spc))
483         self.assertFormatEqual(contents_spc, fs(contents_tab))
484
    def test_false_positive_symlink_output_issue_3384(self) -> None:
        """Regression test for issue 3384.

        Collecting sources for the relative path "./child" must not report any
        path as "a symbolic link that points outside" the root. The only ignored
        path expected is ``root/child/b.py``, reported as matching .gitignore.
        """
        # Emulate the behavior when using the CLI (`black ./child --verbose`), which
        # involves patching some `pathlib.Path` methods. In particular, `is_dir` is
        # patched only on its first call: when checking if "./child" is a directory it
        # should return True. The "./child" folder exists relative to the cwd when
        # running from CLI, but fails when running the tests because cwd is different.
        project_root = Path(THIS_DIR / "data" / "nested_gitignore_tests")
        working_directory = project_root / "root"
        target_abspath = working_directory / "child"
        # Listed before Path.iterdir is patched, so the mock returns real entries.
        target_contents = list(target_abspath.iterdir())

        def mock_n_calls(responses: List[bool]) -> Callable[[], bool]:
            # Build a side effect returning `responses` in order, then False.
            def _mocked_calls() -> bool:
                if responses:
                    return responses.pop(0)
                return False

            return _mocked_calls

        with patch("pathlib.Path.iterdir", return_value=target_contents), patch(
            "pathlib.Path.cwd", return_value=working_directory
        ), patch("pathlib.Path.is_dir", side_effect=mock_n_calls([True])):
            # Note that the root folder (project_root) isn't the folder
            # named "root" (aka working_directory)
            report = MagicMock(verbose=True)
            black.get_sources(
                root=project_root,
                src=("./child",),
                quiet=False,
                verbose=True,
                include=DEFAULT_INCLUDE,
                exclude=None,
                report=report,
                extend_exclude=None,
                force_exclude=None,
                stdin_filename=None,
            )
        # No path_ignored call may carry the "points outside" symlink reason...
        assert not any(
            mock_args[1].startswith("is a symbolic link that points outside")
            for _, mock_args, _ in report.path_ignored.mock_calls
        ), "A symbolic link was reported."
        # ...and b.py, matched by .gitignore, must be the single ignored path.
        report.path_ignored.assert_called_once_with(
            Path("root", "child", "b.py"), "matches a .gitignore file content"
        )
529
    def test_report_verbose(self) -> None:
        """Check Report's messages, summary text and return codes in verbose mode.

        ``black.output._out`` / ``_err`` are patched to collect messages. After
        each report call, the test asserts how many lines went to each stream,
        the latest message, the unstyled ``str(report)`` summary and
        ``report.return_code``. Each step builds on the previous counts, so the
        order matters.
        """
        report = Report(verbose=True)
        # Messages captured from the patched black.output._out / _err.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A CACHED result is counted as "left unchanged" in the summary.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With ``check`` enabled, the same counts yield return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure goes to the error stream and the return code becomes 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Reporting f3 again is counted as a separate event.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are logged in verbose mode but not counted in the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both ``check`` and ``diff`` switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2"
                " files would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2"
                " files would fail to reformat.",
            )
631
    def test_report_quiet(self) -> None:
        """Check Report's summary text and return codes in quiet mode.

        ``black.output._out`` / ``_err`` are patched to collect messages. In
        quiet mode nothing reaches ``_out`` (``out_lines`` stays empty), while
        failures are still sent to ``_err``. Each step builds on the previous
        counts, so the order matters.
        """
        report = Report(quiet=True)
        # Messages captured from the patched black.output._out / _err.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A CACHED result is counted as "left unchanged" in the summary.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With ``check`` enabled, the same counts yield return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures still reach the error stream even in quiet mode.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Reporting f3 again is counted as a separate event.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are neither logged (quiet) nor counted.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both ``check`` and ``diff`` switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2"
                " files would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2"
                " files would fail to reformat.",
            )
725
    def test_report_normal(self) -> None:
        """Exercise black.Report at default verbosity.

        Messages sent through ``black.output._out`` / ``_err`` are captured in
        lists. After each ``done``/``failed``/``path_ignored`` call the test checks
        how many messages reached each stream, the latest message text, the
        unstyled summary ``str(report)`` and ``report.return_code``.
        """
        report = black.Report()
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            # Stand-in for black.output._out: record instead of printing.
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            # Stand-in for black.output._err: record instead of printing.
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            # An unchanged file is counted but prints nothing at this verbosity.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # A reformatted file is announced on the "out" stream.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # CACHED is summarised as "left unchanged" and emits no new message.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, having reformatted files turns the return code to 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure is reported on the "err" stream and sets return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Later successes do not clear the failure return code.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # At this verbosity an ignored path changes neither output nor summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # With check or diff enabled the summary uses conditional wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2"
                " files would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2"
                " files would fail to reformat.",
            )
822
823     def test_lib2to3_parse(self) -> None:
824         with self.assertRaises(black.InvalidInput):
825             black.lib2to3_parse("invalid syntax")
826
827         straddling = "x + y"
828         black.lib2to3_parse(straddling)
829         black.lib2to3_parse(straddling, {TargetVersion.PY36})
830
831         py2_only = "print x"
832         with self.assertRaises(black.InvalidInput):
833             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
834
835         py3_only = "exec(x, end=y)"
836         black.lib2to3_parse(py3_only)
837         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
838
839     def test_get_features_used_decorator(self) -> None:
840         # Test the feature detection of new decorator syntax
841         # since this makes some test cases of test_get_features_used()
842         # fails if it fails, this is tested first so that a useful case
843         # is identified
844         simples, relaxed = read_data("miscellaneous", "decorators")
845         # skip explanation comments at the top of the file
846         for simple_test in simples.split("##")[1:]:
847             node = black.lib2to3_parse(simple_test)
848             decorator = str(node.children[0].children[0]).strip()
849             self.assertNotIn(
850                 Feature.RELAXED_DECORATORS,
851                 black.get_features_used(node),
852                 msg=(
853                     f"decorator '{decorator}' follows python<=3.8 syntax"
854                     "but is detected as 3.9+"
855                     # f"The full node is\n{node!r}"
856                 ),
857             )
858         # skip the '# output' comment at the top of the output part
859         for relaxed_test in relaxed.split("##")[1:]:
860             node = black.lib2to3_parse(relaxed_test)
861             decorator = str(node.children[0].children[0]).strip()
862             self.assertIn(
863                 Feature.RELAXED_DECORATORS,
864                 black.get_features_used(node),
865                 msg=(
866                     f"decorator '{decorator}' uses python3.9+ syntax"
867                     "but is detected as python<=3.8"
868                     # f"The full node is\n{node!r}"
869                 ),
870             )
871
872     def test_get_features_used(self) -> None:
873         node = black.lib2to3_parse("def f(*, arg): ...\n")
874         self.assertEqual(black.get_features_used(node), set())
875         node = black.lib2to3_parse("def f(*, arg,): ...\n")
876         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
877         node = black.lib2to3_parse("f(*arg,)\n")
878         self.assertEqual(
879             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
880         )
881         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
882         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
883         node = black.lib2to3_parse("123_456\n")
884         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
885         node = black.lib2to3_parse("123456\n")
886         self.assertEqual(black.get_features_used(node), set())
887         source, expected = read_data("simple_cases", "function")
888         node = black.lib2to3_parse(source)
889         expected_features = {
890             Feature.TRAILING_COMMA_IN_CALL,
891             Feature.TRAILING_COMMA_IN_DEF,
892             Feature.F_STRINGS,
893         }
894         self.assertEqual(black.get_features_used(node), expected_features)
895         node = black.lib2to3_parse(expected)
896         self.assertEqual(black.get_features_used(node), expected_features)
897         source, expected = read_data("simple_cases", "expression")
898         node = black.lib2to3_parse(source)
899         self.assertEqual(black.get_features_used(node), set())
900         node = black.lib2to3_parse(expected)
901         self.assertEqual(black.get_features_used(node), set())
902         node = black.lib2to3_parse("lambda a, /, b: ...")
903         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
904         node = black.lib2to3_parse("def fn(a, /, b): ...")
905         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
906         node = black.lib2to3_parse("def fn(): yield a, b")
907         self.assertEqual(black.get_features_used(node), set())
908         node = black.lib2to3_parse("def fn(): return a, b")
909         self.assertEqual(black.get_features_used(node), set())
910         node = black.lib2to3_parse("def fn(): yield *b, c")
911         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
912         node = black.lib2to3_parse("def fn(): return a, *b, c")
913         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
914         node = black.lib2to3_parse("x = a, *b, c")
915         self.assertEqual(black.get_features_used(node), set())
916         node = black.lib2to3_parse("x: Any = regular")
917         self.assertEqual(black.get_features_used(node), set())
918         node = black.lib2to3_parse("x: Any = (regular, regular)")
919         self.assertEqual(black.get_features_used(node), set())
920         node = black.lib2to3_parse("x: Any = Complex(Type(1))[something]")
921         self.assertEqual(black.get_features_used(node), set())
922         node = black.lib2to3_parse("x: Tuple[int, ...] = a, b, c")
923         self.assertEqual(
924             black.get_features_used(node), {Feature.ANN_ASSIGN_EXTENDED_RHS}
925         )
926         node = black.lib2to3_parse("try: pass\nexcept Something: pass")
927         self.assertEqual(black.get_features_used(node), set())
928         node = black.lib2to3_parse("try: pass\nexcept (*Something,): pass")
929         self.assertEqual(black.get_features_used(node), set())
930         node = black.lib2to3_parse("try: pass\nexcept *Group: pass")
931         self.assertEqual(black.get_features_used(node), {Feature.EXCEPT_STAR})
932         node = black.lib2to3_parse("a[*b]")
933         self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS})
934         node = black.lib2to3_parse("a[x, *y(), z] = t")
935         self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS})
936         node = black.lib2to3_parse("def fn(*args: *T): pass")
937         self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS})
938
939     def test_get_features_used_for_future_flags(self) -> None:
940         for src, features in [
941             ("from __future__ import annotations", {Feature.FUTURE_ANNOTATIONS}),
942             (
943                 "from __future__ import (other, annotations)",
944                 {Feature.FUTURE_ANNOTATIONS},
945             ),
946             ("a = 1 + 2\nfrom something import annotations", set()),
947             ("from __future__ import x, y", set()),
948         ]:
949             with self.subTest(src=src, features=features):
950                 node = black.lib2to3_parse(src)
951                 future_imports = black.get_future_imports(node)
952                 self.assertEqual(
953                     black.get_features_used(node, future_imports=future_imports),
954                     features,
955                 )
956
957     def test_get_future_imports(self) -> None:
958         node = black.lib2to3_parse("\n")
959         self.assertEqual(set(), black.get_future_imports(node))
960         node = black.lib2to3_parse("from __future__ import black\n")
961         self.assertEqual({"black"}, black.get_future_imports(node))
962         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
963         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
964         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
965         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
966         node = black.lib2to3_parse(
967             "from __future__ import multiple\nfrom __future__ import imports\n"
968         )
969         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
970         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
971         self.assertEqual({"black"}, black.get_future_imports(node))
972         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
973         self.assertEqual({"black"}, black.get_future_imports(node))
974         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
975         self.assertEqual(set(), black.get_future_imports(node))
976         node = black.lib2to3_parse("from some.module import black\n")
977         self.assertEqual(set(), black.get_future_imports(node))
978         node = black.lib2to3_parse(
979             "from __future__ import unicode_literals as _unicode_literals"
980         )
981         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
982         node = black.lib2to3_parse(
983             "from __future__ import unicode_literals as _lol, print"
984         )
985         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
986
987     @pytest.mark.incompatible_with_mypyc
988     def test_debug_visitor(self) -> None:
989         source, _ = read_data("miscellaneous", "debug_visitor")
990         expected, _ = read_data("miscellaneous", "debug_visitor.out")
991         out_lines = []
992         err_lines = []
993
994         def out(msg: str, **kwargs: Any) -> None:
995             out_lines.append(msg)
996
997         def err(msg: str, **kwargs: Any) -> None:
998             err_lines.append(msg)
999
1000         with patch("black.debug.out", out):
1001             DebugVisitor.show(source)
1002         actual = "\n".join(out_lines) + "\n"
1003         log_name = ""
1004         if expected != actual:
1005             log_name = black.dump_to_file(*out_lines)
1006         self.assertEqual(
1007             expected,
1008             actual,
1009             f"AST print out is different. Actual version dumped to {log_name}",
1010         )
1011
1012     def test_format_file_contents(self) -> None:
1013         mode = DEFAULT_MODE
1014         empty = ""
1015         with self.assertRaises(black.NothingChanged):
1016             black.format_file_contents(empty, mode=mode, fast=False)
1017         just_nl = "\n"
1018         with self.assertRaises(black.NothingChanged):
1019             black.format_file_contents(just_nl, mode=mode, fast=False)
1020         same = "j = [1, 2, 3]\n"
1021         with self.assertRaises(black.NothingChanged):
1022             black.format_file_contents(same, mode=mode, fast=False)
1023         different = "j = [1,2,3]"
1024         expected = same
1025         actual = black.format_file_contents(different, mode=mode, fast=False)
1026         self.assertEqual(expected, actual)
1027         invalid = "return if you can"
1028         with self.assertRaises(black.InvalidInput) as e:
1029             black.format_file_contents(invalid, mode=mode, fast=False)
1030         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1031
1032         mode = black.Mode(preview=True)
1033         just_crlf = "\r\n"
1034         with self.assertRaises(black.NothingChanged):
1035             black.format_file_contents(just_crlf, mode=mode, fast=False)
1036         just_whitespace_nl = "\n\t\n \n\t \n \t\n\n"
1037         actual = black.format_file_contents(just_whitespace_nl, mode=mode, fast=False)
1038         self.assertEqual("\n", actual)
1039         just_whitespace_crlf = "\r\n\t\r\n \r\n\t \r\n \t\r\n\r\n"
1040         actual = black.format_file_contents(just_whitespace_crlf, mode=mode, fast=False)
1041         self.assertEqual("\r\n", actual)
1042
1043     def test_endmarker(self) -> None:
1044         n = black.lib2to3_parse("\n")
1045         self.assertEqual(n.type, black.syms.file_input)
1046         self.assertEqual(len(n.children), 1)
1047         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1048
1049     @patch("tests.conftest.PRINT_FULL_TREE", True)
1050     @patch("tests.conftest.PRINT_TREE_DIFF", False)
1051     @pytest.mark.incompatible_with_mypyc
1052     def test_assertFormatEqual_print_full_tree(self) -> None:
1053         out_lines = []
1054         err_lines = []
1055
1056         def out(msg: str, **kwargs: Any) -> None:
1057             out_lines.append(msg)
1058
1059         def err(msg: str, **kwargs: Any) -> None:
1060             err_lines.append(msg)
1061
1062         with patch("black.output._out", out), patch("black.output._err", err):
1063             with self.assertRaises(AssertionError):
1064                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1065
1066         out_str = "".join(out_lines)
1067         self.assertIn("Expected tree:", out_str)
1068         self.assertIn("Actual tree:", out_str)
1069         self.assertEqual("".join(err_lines), "")
1070
1071     @patch("tests.conftest.PRINT_FULL_TREE", False)
1072     @patch("tests.conftest.PRINT_TREE_DIFF", True)
1073     @pytest.mark.incompatible_with_mypyc
1074     def test_assertFormatEqual_print_tree_diff(self) -> None:
1075         out_lines = []
1076         err_lines = []
1077
1078         def out(msg: str, **kwargs: Any) -> None:
1079             out_lines.append(msg)
1080
1081         def err(msg: str, **kwargs: Any) -> None:
1082             err_lines.append(msg)
1083
1084         with patch("black.output._out", out), patch("black.output._err", err):
1085             with self.assertRaises(AssertionError):
1086                 self.assertFormatEqual("j = [1, 2, 3]\n", "j = [1, 2, 3,]\n")
1087
1088         out_str = "".join(out_lines)
1089         self.assertIn("Tree Diff:", out_str)
1090         self.assertIn("+          COMMA", out_str)
1091         self.assertIn("+ ','", out_str)
1092         self.assertEqual("".join(err_lines), "")
1093
1094     @event_loop()
1095     @patch("concurrent.futures.ProcessPoolExecutor", MagicMock(side_effect=OSError))
1096     def test_works_in_mono_process_only_environment(self) -> None:
1097         with cache_dir() as workspace:
1098             for f in [
1099                 (workspace / "one.py").resolve(),
1100                 (workspace / "two.py").resolve(),
1101             ]:
1102                 f.write_text('print("hello")\n', encoding="utf-8")
1103             self.invokeBlack([str(workspace)])
1104
1105     @event_loop()
1106     def test_check_diff_use_together(self) -> None:
1107         with cache_dir():
1108             # Files which will be reformatted.
1109             src1 = get_case_path("miscellaneous", "string_quotes")
1110             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1111             # Files which will not be reformatted.
1112             src2 = get_case_path("simple_cases", "composition")
1113             self.invokeBlack([str(src2), "--diff", "--check"])
1114             # Multi file command.
1115             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1116
1117     def test_no_src_fails(self) -> None:
1118         with cache_dir():
1119             self.invokeBlack([], exit_code=1)
1120
1121     def test_src_and_code_fails(self) -> None:
1122         with cache_dir():
1123             self.invokeBlack([".", "-c", "0"], exit_code=1)
1124
1125     def test_broken_symlink(self) -> None:
1126         with cache_dir() as workspace:
1127             symlink = workspace / "broken_link.py"
1128             try:
1129                 symlink.symlink_to("nonexistent.py")
1130             except (OSError, NotImplementedError) as e:
1131                 self.skipTest(f"Can't create symlinks: {e}")
1132             self.invokeBlack([str(workspace.resolve())])
1133
1134     def test_single_file_force_pyi(self) -> None:
1135         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1136         contents, expected = read_data("miscellaneous", "force_pyi")
1137         with cache_dir() as workspace:
1138             path = (workspace / "file.py").resolve()
1139             path.write_text(contents, encoding="utf-8")
1140             self.invokeBlack([str(path), "--pyi"])
1141             actual = path.read_text(encoding="utf-8")
1142             # verify cache with --pyi is separate
1143             pyi_cache = black.Cache.read(pyi_mode)
1144             assert not pyi_cache.is_changed(path)
1145             normal_cache = black.Cache.read(DEFAULT_MODE)
1146             assert normal_cache.is_changed(path)
1147         self.assertFormatEqual(expected, actual)
1148         black.assert_equivalent(contents, actual)
1149         black.assert_stable(contents, actual, pyi_mode)
1150
1151     @event_loop()
1152     def test_multi_file_force_pyi(self) -> None:
1153         reg_mode = DEFAULT_MODE
1154         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1155         contents, expected = read_data("miscellaneous", "force_pyi")
1156         with cache_dir() as workspace:
1157             paths = [
1158                 (workspace / "file1.py").resolve(),
1159                 (workspace / "file2.py").resolve(),
1160             ]
1161             for path in paths:
1162                 path.write_text(contents, encoding="utf-8")
1163             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1164             for path in paths:
1165                 actual = path.read_text(encoding="utf-8")
1166                 self.assertEqual(actual, expected)
1167             # verify cache with --pyi is separate
1168             pyi_cache = black.Cache.read(pyi_mode)
1169             normal_cache = black.Cache.read(reg_mode)
1170             for path in paths:
1171                 assert not pyi_cache.is_changed(path)
1172                 assert normal_cache.is_changed(path)
1173
1174     def test_pipe_force_pyi(self) -> None:
1175         source, expected = read_data("miscellaneous", "force_pyi")
1176         result = CliRunner().invoke(
1177             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf-8"))
1178         )
1179         self.assertEqual(result.exit_code, 0)
1180         actual = result.output
1181         self.assertFormatEqual(actual, expected)
1182
1183     def test_single_file_force_py36(self) -> None:
1184         reg_mode = DEFAULT_MODE
1185         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1186         source, expected = read_data("miscellaneous", "force_py36")
1187         with cache_dir() as workspace:
1188             path = (workspace / "file.py").resolve()
1189             path.write_text(source, encoding="utf-8")
1190             self.invokeBlack([str(path), *PY36_ARGS])
1191             actual = path.read_text(encoding="utf-8")
1192             # verify cache with --target-version is separate
1193             py36_cache = black.Cache.read(py36_mode)
1194             assert not py36_cache.is_changed(path)
1195             normal_cache = black.Cache.read(reg_mode)
1196             assert normal_cache.is_changed(path)
1197         self.assertEqual(actual, expected)
1198
1199     @event_loop()
1200     def test_multi_file_force_py36(self) -> None:
1201         reg_mode = DEFAULT_MODE
1202         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1203         source, expected = read_data("miscellaneous", "force_py36")
1204         with cache_dir() as workspace:
1205             paths = [
1206                 (workspace / "file1.py").resolve(),
1207                 (workspace / "file2.py").resolve(),
1208             ]
1209             for path in paths:
1210                 path.write_text(source, encoding="utf-8")
1211             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1212             for path in paths:
1213                 actual = path.read_text(encoding="utf-8")
1214                 self.assertEqual(actual, expected)
1215             # verify cache with --target-version is separate
1216             pyi_cache = black.Cache.read(py36_mode)
1217             normal_cache = black.Cache.read(reg_mode)
1218             for path in paths:
1219                 assert not pyi_cache.is_changed(path)
1220                 assert normal_cache.is_changed(path)
1221
1222     def test_pipe_force_py36(self) -> None:
1223         source, expected = read_data("miscellaneous", "force_py36")
1224         result = CliRunner().invoke(
1225             black.main,
1226             ["-", "-q", "--target-version=py36"],
1227             input=BytesIO(source.encode("utf-8")),
1228         )
1229         self.assertEqual(result.exit_code, 0)
1230         actual = result.output
1231         self.assertFormatEqual(actual, expected)
1232
1233     @pytest.mark.incompatible_with_mypyc
1234     def test_reformat_one_with_stdin(self) -> None:
1235         with patch(
1236             "black.format_stdin_to_stdout",
1237             return_value=lambda *args, **kwargs: black.Changed.YES,
1238         ) as fsts:
1239             report = MagicMock()
1240             path = Path("-")
1241             black.reformat_one(
1242                 path,
1243                 fast=True,
1244                 write_back=black.WriteBack.YES,
1245                 mode=DEFAULT_MODE,
1246                 report=report,
1247             )
1248             fsts.assert_called_once()
1249             report.done.assert_called_with(path, black.Changed.YES)
1250
1251     @pytest.mark.incompatible_with_mypyc
1252     def test_reformat_one_with_stdin_filename(self) -> None:
1253         with patch(
1254             "black.format_stdin_to_stdout",
1255             return_value=lambda *args, **kwargs: black.Changed.YES,
1256         ) as fsts:
1257             report = MagicMock()
1258             p = "foo.py"
1259             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1260             expected = Path(p)
1261             black.reformat_one(
1262                 path,
1263                 fast=True,
1264                 write_back=black.WriteBack.YES,
1265                 mode=DEFAULT_MODE,
1266                 report=report,
1267             )
1268             fsts.assert_called_once_with(
1269                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1270             )
1271             # __BLACK_STDIN_FILENAME__ should have been stripped
1272             report.done.assert_called_with(expected, black.Changed.YES)
1273
1274     @pytest.mark.incompatible_with_mypyc
1275     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1276         with patch(
1277             "black.format_stdin_to_stdout",
1278             return_value=lambda *args, **kwargs: black.Changed.YES,
1279         ) as fsts:
1280             report = MagicMock()
1281             p = "foo.pyi"
1282             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1283             expected = Path(p)
1284             black.reformat_one(
1285                 path,
1286                 fast=True,
1287                 write_back=black.WriteBack.YES,
1288                 mode=DEFAULT_MODE,
1289                 report=report,
1290             )
1291             fsts.assert_called_once_with(
1292                 fast=True,
1293                 write_back=black.WriteBack.YES,
1294                 mode=replace(DEFAULT_MODE, is_pyi=True),
1295             )
1296             # __BLACK_STDIN_FILENAME__ should have been stripped
1297             report.done.assert_called_with(expected, black.Changed.YES)
1298
1299     @pytest.mark.incompatible_with_mypyc
1300     def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
1301         with patch(
1302             "black.format_stdin_to_stdout",
1303             return_value=lambda *args, **kwargs: black.Changed.YES,
1304         ) as fsts:
1305             report = MagicMock()
1306             p = "foo.ipynb"
1307             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1308             expected = Path(p)
1309             black.reformat_one(
1310                 path,
1311                 fast=True,
1312                 write_back=black.WriteBack.YES,
1313                 mode=DEFAULT_MODE,
1314                 report=report,
1315             )
1316             fsts.assert_called_once_with(
1317                 fast=True,
1318                 write_back=black.WriteBack.YES,
1319                 mode=replace(DEFAULT_MODE, is_ipynb=True),
1320             )
1321             # __BLACK_STDIN_FILENAME__ should have been stripped
1322             report.done.assert_called_with(expected, black.Changed.YES)
1323
1324     @pytest.mark.incompatible_with_mypyc
1325     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1326         with patch(
1327             "black.format_stdin_to_stdout",
1328             return_value=lambda *args, **kwargs: black.Changed.YES,
1329         ) as fsts:
1330             report = MagicMock()
1331             # Even with an existing file, since we are forcing stdin, black
1332             # should output to stdout and not modify the file inplace
1333             p = THIS_DIR / "data" / "simple_cases" / "collections.py"
1334             # Make sure is_file actually returns True
1335             self.assertTrue(p.is_file())
1336             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1337             expected = Path(p)
1338             black.reformat_one(
1339                 path,
1340                 fast=True,
1341                 write_back=black.WriteBack.YES,
1342                 mode=DEFAULT_MODE,
1343                 report=report,
1344             )
1345             fsts.assert_called_once()
1346             # __BLACK_STDIN_FILENAME__ should have been stripped
1347             report.done.assert_called_with(expected, black.Changed.YES)
1348
1349     def test_reformat_one_with_stdin_empty(self) -> None:
1350         cases = [
1351             ("", ""),
1352             ("\n", "\n"),
1353             ("\r\n", "\r\n"),
1354             (" \t", ""),
1355             (" \t\n\t ", "\n"),
1356             (" \t\r\n\t ", "\r\n"),
1357         ]
1358
1359         def _new_wrapper(
1360             output: io.StringIO, io_TextIOWrapper: Type[io.TextIOWrapper]
1361         ) -> Callable[[Any, Any], io.TextIOWrapper]:
1362             def get_output(*args: Any, **kwargs: Any) -> io.TextIOWrapper:
1363                 if args == (sys.stdout.buffer,):
1364                     # It's `format_stdin_to_stdout()` calling `io.TextIOWrapper()`,
1365                     # return our mock object.
1366                     return output
1367                 # It's something else (i.e. `decode_bytes()`) calling
1368                 # `io.TextIOWrapper()`, pass through to the original implementation.
1369                 # See discussion in https://github.com/psf/black/pull/2489
1370                 return io_TextIOWrapper(*args, **kwargs)
1371
1372             return get_output
1373
1374         mode = black.Mode(preview=True)
1375         for content, expected in cases:
1376             output = io.StringIO()
1377             io_TextIOWrapper = io.TextIOWrapper
1378
1379             with patch("io.TextIOWrapper", _new_wrapper(output, io_TextIOWrapper)):
1380                 try:
1381                     black.format_stdin_to_stdout(
1382                         fast=True,
1383                         content=content,
1384                         write_back=black.WriteBack.YES,
1385                         mode=mode,
1386                     )
1387                 except io.UnsupportedOperation:
1388                     pass  # StringIO does not support detach
1389                 assert output.getvalue() == expected
1390
1391         # An empty string is the only test case for `preview=False`
1392         output = io.StringIO()
1393         io_TextIOWrapper = io.TextIOWrapper
1394         with patch("io.TextIOWrapper", _new_wrapper(output, io_TextIOWrapper)):
1395             try:
1396                 black.format_stdin_to_stdout(
1397                     fast=True,
1398                     content="",
1399                     write_back=black.WriteBack.YES,
1400                     mode=DEFAULT_MODE,
1401                 )
1402             except io.UnsupportedOperation:
1403                 pass  # StringIO does not support detach
1404             assert output.getvalue() == ""
1405
1406     def test_invalid_cli_regex(self) -> None:
1407         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1408             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1409
1410     def test_required_version_matches_version(self) -> None:
1411         self.invokeBlack(
1412             ["--required-version", black.__version__, "-c", "0"],
1413             exit_code=0,
1414             ignore_config=True,
1415         )
1416
1417     def test_required_version_matches_partial_version(self) -> None:
1418         self.invokeBlack(
1419             ["--required-version", black.__version__.split(".")[0], "-c", "0"],
1420             exit_code=0,
1421             ignore_config=True,
1422         )
1423
1424     def test_required_version_does_not_match_on_minor_version(self) -> None:
1425         self.invokeBlack(
1426             ["--required-version", black.__version__.split(".")[0] + ".999", "-c", "0"],
1427             exit_code=1,
1428             ignore_config=True,
1429         )
1430
1431     def test_required_version_does_not_match_version(self) -> None:
1432         result = BlackRunner().invoke(
1433             black.main,
1434             ["--required-version", "20.99b", "-c", "0"],
1435         )
1436         self.assertEqual(result.exit_code, 1)
1437         self.assertIn("required version", result.stderr)
1438
1439     def test_preserves_line_endings(self) -> None:
1440         with TemporaryDirectory() as workspace:
1441             test_file = Path(workspace) / "test.py"
1442             for nl in ["\n", "\r\n"]:
1443                 contents = nl.join(["def f(  ):", "    pass"])
1444                 test_file.write_bytes(contents.encode())
1445                 ff(test_file, write_back=black.WriteBack.YES)
1446                 updated_contents: bytes = test_file.read_bytes()
1447                 self.assertIn(nl.encode(), updated_contents)
1448                 if nl == "\n":
1449                     self.assertNotIn(b"\r\n", updated_contents)
1450
1451     def test_preserves_line_endings_via_stdin(self) -> None:
1452         for nl in ["\n", "\r\n"]:
1453             contents = nl.join(["def f(  ):", "    pass"])
1454             runner = BlackRunner()
1455             result = runner.invoke(
1456                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf-8"))
1457             )
1458             self.assertEqual(result.exit_code, 0)
1459             output = result.stdout_bytes
1460             self.assertIn(nl.encode("utf-8"), output)
1461             if nl == "\n":
1462                 self.assertNotIn(b"\r\n", output)
1463
1464     def test_normalize_line_endings(self) -> None:
1465         with TemporaryDirectory() as workspace:
1466             test_file = Path(workspace) / "test.py"
1467             for data, expected in (
1468                 (b"c\r\nc\n ", b"c\r\nc\r\n"),
1469                 (b"l\nl\r\n ", b"l\nl\n"),
1470             ):
1471                 test_file.write_bytes(data)
1472                 ff(test_file, write_back=black.WriteBack.YES)
1473                 self.assertEqual(test_file.read_bytes(), expected)
1474
1475     def test_assert_equivalent_different_asts(self) -> None:
1476         with self.assertRaises(AssertionError):
1477             black.assert_equivalent("{}", "None")
1478
1479     def test_root_logger_not_used_directly(self) -> None:
1480         def fail(*args: Any, **kwargs: Any) -> None:
1481             self.fail("Record created with root logger")
1482
1483         with patch.multiple(
1484             logging.root,
1485             debug=fail,
1486             info=fail,
1487             warning=fail,
1488             error=fail,
1489             critical=fail,
1490             log=fail,
1491         ):
1492             ff(THIS_DIR / "util.py")
1493
1494     def test_invalid_config_return_code(self) -> None:
1495         tmp_file = Path(black.dump_to_file())
1496         try:
1497             tmp_config = Path(black.dump_to_file())
1498             tmp_config.unlink()
1499             args = ["--config", str(tmp_config), str(tmp_file)]
1500             self.invokeBlack(args, exit_code=2, ignore_config=False)
1501         finally:
1502             tmp_file.unlink()
1503
1504     def test_parse_pyproject_toml(self) -> None:
1505         test_toml_file = THIS_DIR / "test.toml"
1506         config = black.parse_pyproject_toml(str(test_toml_file))
1507         self.assertEqual(config["verbose"], 1)
1508         self.assertEqual(config["check"], "no")
1509         self.assertEqual(config["diff"], "y")
1510         self.assertEqual(config["color"], True)
1511         self.assertEqual(config["line_length"], 79)
1512         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1513         self.assertEqual(config["python_cell_magics"], ["custom1", "custom2"])
1514         self.assertEqual(config["exclude"], r"\.pyi?$")
1515         self.assertEqual(config["include"], r"\.py?$")
1516
1517     def test_parse_pyproject_toml_project_metadata(self) -> None:
1518         for test_toml, expected in [
1519             ("only_black_pyproject.toml", ["py310"]),
1520             ("only_metadata_pyproject.toml", ["py37", "py38", "py39", "py310"]),
1521             ("neither_pyproject.toml", None),
1522             ("both_pyproject.toml", ["py310"]),
1523         ]:
1524             test_toml_file = THIS_DIR / "data" / "project_metadata" / test_toml
1525             config = black.parse_pyproject_toml(str(test_toml_file))
1526             self.assertEqual(config.get("target_version"), expected)
1527
    def test_infer_target_version(self) -> None:
        """Check ``black.files.infer_target_version`` against requires-python specs.

        Each case pairs a ``project.requires-python`` value with the exact list of
        ``TargetVersion`` members expected back, or ``None`` when nothing should
        be inferred.
        """
        for version, expected in [
            # A bare version (including a pre-release) maps to its own target.
            ("3.6", [TargetVersion.PY36]),
            ("3.11.0rc1", [TargetVersion.PY311]),
            # Ranges, exclusions and wildcards expand to every matching target.
            (">=3.10", [TargetVersion.PY310, TargetVersion.PY311, TargetVersion.PY312]),
            (
                ">=3.10.6",
                [TargetVersion.PY310, TargetVersion.PY311, TargetVersion.PY312],
            ),
            ("<3.6", [TargetVersion.PY33, TargetVersion.PY34, TargetVersion.PY35]),
            (">3.7,<3.10", [TargetVersion.PY38, TargetVersion.PY39]),
            (
                ">3.7,!=3.8,!=3.9",
                [TargetVersion.PY310, TargetVersion.PY311, TargetVersion.PY312],
            ),
            (
                "> 3.9.4, != 3.10.3",
                [
                    TargetVersion.PY39,
                    TargetVersion.PY310,
                    TargetVersion.PY311,
                    TargetVersion.PY312,
                ],
            ),
            (
                "!=3.3,!=3.4",
                [
                    TargetVersion.PY35,
                    TargetVersion.PY36,
                    TargetVersion.PY37,
                    TargetVersion.PY38,
                    TargetVersion.PY39,
                    TargetVersion.PY310,
                    TargetVersion.PY311,
                    TargetVersion.PY312,
                ],
            ),
            (
                "==3.*",
                [
                    TargetVersion.PY33,
                    TargetVersion.PY34,
                    TargetVersion.PY35,
                    TargetVersion.PY36,
                    TargetVersion.PY37,
                    TargetVersion.PY38,
                    TargetVersion.PY39,
                    TargetVersion.PY310,
                    TargetVersion.PY311,
                    TargetVersion.PY312,
                ],
            ),
            ("==3.8.*", [TargetVersion.PY38]),
            # Missing, empty or unparseable values, and specifiers matching no
            # listed TargetVersion (e.g. Python 2 or 3.2), all give None.
            (None, None),
            ("", None),
            ("invalid", None),
            ("==invalid", None),
            (">3.9,!=invalid", None),
            ("3", None),
            ("3.2", None),
            ("2.7.18", None),
            ("==2.7", None),
            (">3.10,<3.11", None),
        ]:
            # Minimal pyproject-shaped mapping carrying only requires-python.
            test_toml = {"project": {"requires-python": version}}
            result = black.files.infer_target_version(test_toml)
            self.assertEqual(result, expected)
1595
1596     def test_read_pyproject_toml(self) -> None:
1597         test_toml_file = THIS_DIR / "test.toml"
1598         fake_ctx = FakeContext()
1599         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1600         config = fake_ctx.default_map
1601         self.assertEqual(config["verbose"], "1")
1602         self.assertEqual(config["check"], "no")
1603         self.assertEqual(config["diff"], "y")
1604         self.assertEqual(config["color"], "True")
1605         self.assertEqual(config["line_length"], "79")
1606         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1607         self.assertEqual(config["exclude"], r"\.pyi?$")
1608         self.assertEqual(config["include"], r"\.py?$")
1609
1610     def test_read_pyproject_toml_from_stdin(self) -> None:
1611         with TemporaryDirectory() as workspace:
1612             root = Path(workspace)
1613
1614             src_dir = root / "src"
1615             src_dir.mkdir()
1616
1617             src_pyproject = src_dir / "pyproject.toml"
1618             src_pyproject.touch()
1619
1620             test_toml_content = (THIS_DIR / "test.toml").read_text(encoding="utf-8")
1621             src_pyproject.write_text(test_toml_content, encoding="utf-8")
1622
1623             src_python = src_dir / "foo.py"
1624             src_python.touch()
1625
1626             fake_ctx = FakeContext()
1627             fake_ctx.params["src"] = ("-",)
1628             fake_ctx.params["stdin_filename"] = str(src_python)
1629
1630             with change_directory(root):
1631                 black.read_pyproject_toml(fake_ctx, FakeParameter(), None)
1632
1633             config = fake_ctx.default_map
1634             self.assertEqual(config["verbose"], "1")
1635             self.assertEqual(config["check"], "no")
1636             self.assertEqual(config["diff"], "y")
1637             self.assertEqual(config["color"], "True")
1638             self.assertEqual(config["line_length"], "79")
1639             self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1640             self.assertEqual(config["exclude"], r"\.pyi?$")
1641             self.assertEqual(config["include"], r"\.py?$")
1642
1643     @pytest.mark.incompatible_with_mypyc
1644     def test_find_project_root(self) -> None:
1645         with TemporaryDirectory() as workspace:
1646             root = Path(workspace)
1647             test_dir = root / "test"
1648             test_dir.mkdir()
1649
1650             src_dir = root / "src"
1651             src_dir.mkdir()
1652
1653             root_pyproject = root / "pyproject.toml"
1654             root_pyproject.touch()
1655             src_pyproject = src_dir / "pyproject.toml"
1656             src_pyproject.touch()
1657             src_python = src_dir / "foo.py"
1658             src_python.touch()
1659
1660             self.assertEqual(
1661                 black.find_project_root((src_dir, test_dir)),
1662                 (root.resolve(), "pyproject.toml"),
1663             )
1664             self.assertEqual(
1665                 black.find_project_root((src_dir,)),
1666                 (src_dir.resolve(), "pyproject.toml"),
1667             )
1668             self.assertEqual(
1669                 black.find_project_root((src_python,)),
1670                 (src_dir.resolve(), "pyproject.toml"),
1671             )
1672
1673             with change_directory(test_dir):
1674                 self.assertEqual(
1675                     black.find_project_root(("-",), stdin_filename="../src/a.py"),
1676                     (src_dir.resolve(), "pyproject.toml"),
1677                 )
1678
1679     @patch(
1680         "black.files.find_user_pyproject_toml",
1681     )
1682     def test_find_pyproject_toml(self, find_user_pyproject_toml: MagicMock) -> None:
1683         find_user_pyproject_toml.side_effect = RuntimeError()
1684
1685         with redirect_stderr(io.StringIO()) as stderr:
1686             result = black.files.find_pyproject_toml(
1687                 path_search_start=(str(Path.cwd().root),)
1688             )
1689
1690         assert result is None
1691         err = stderr.getvalue()
1692         assert "Ignoring user configuration" in err
1693
1694     @patch(
1695         "black.files.find_user_pyproject_toml",
1696         black.files.find_user_pyproject_toml.__wrapped__,
1697     )
1698     def test_find_user_pyproject_toml_linux(self) -> None:
1699         if system() == "Windows":
1700             return
1701
1702         # Test if XDG_CONFIG_HOME is checked
1703         with TemporaryDirectory() as workspace:
1704             tmp_user_config = Path(workspace) / "black"
1705             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1706                 self.assertEqual(
1707                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
1708                 )
1709
1710         # Test fallback for XDG_CONFIG_HOME
1711         with patch.dict("os.environ"):
1712             os.environ.pop("XDG_CONFIG_HOME", None)
1713             fallback_user_config = Path("~/.config").expanduser() / "black"
1714             self.assertEqual(
1715                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
1716             )
1717
1718     def test_find_user_pyproject_toml_windows(self) -> None:
1719         if system() != "Windows":
1720             return
1721
1722         user_config_path = Path.home() / ".black"
1723         self.assertEqual(
1724             black.files.find_user_pyproject_toml(), user_config_path.resolve()
1725         )
1726
1727     def test_bpo_33660_workaround(self) -> None:
1728         if system() == "Windows":
1729             return
1730
1731         # https://bugs.python.org/issue33660
1732         root = Path("/")
1733         with change_directory(root):
1734             path = Path("workspace") / "project"
1735             report = black.Report(verbose=True)
1736             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
1737             self.assertEqual(normalized_path, "workspace/project")
1738
1739     def test_normalize_path_ignore_windows_junctions_outside_of_root(self) -> None:
1740         if system() != "Windows":
1741             return
1742
1743         with TemporaryDirectory() as workspace:
1744             root = Path(workspace)
1745             junction_dir = root / "junction"
1746             junction_target_outside_of_root = root / ".."
1747             os.system(f"mklink /J {junction_dir} {junction_target_outside_of_root}")
1748
1749             report = black.Report(verbose=True)
1750             normalized_path = black.normalize_path_maybe_ignore(
1751                 junction_dir, root, report
1752             )
1753             # Manually delete for Python < 3.8
1754             os.system(f"rmdir {junction_dir}")
1755
1756             self.assertEqual(normalized_path, None)
1757
1758     def test_newline_comment_interaction(self) -> None:
1759         source = "class A:\\\r\n# type: ignore\n pass\n"
1760         output = black.format_str(source, mode=DEFAULT_MODE)
1761         black.assert_stable(source, output, mode=DEFAULT_MODE)
1762
1763     def test_bpo_2142_workaround(self) -> None:
1764         # https://bugs.python.org/issue2142
1765
1766         source, _ = read_data("miscellaneous", "missing_final_newline")
1767         # read_data adds a trailing newline
1768         source = source.rstrip()
1769         expected, _ = read_data("miscellaneous", "missing_final_newline.diff")
1770         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
1771         diff_header = re.compile(
1772             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
1773             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d\+\d\d:\d\d"
1774         )
1775         try:
1776             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
1777             self.assertEqual(result.exit_code, 0)
1778         finally:
1779             os.unlink(tmp_file)
1780         actual = result.output
1781         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
1782         self.assertEqual(actual, expected)
1783
1784     @staticmethod
1785     def compare_results(
1786         result: click.testing.Result, expected_value: str, expected_exit_code: int
1787     ) -> None:
1788         """Helper method to test the value and exit code of a click Result."""
1789         assert (
1790             result.output == expected_value
1791         ), "The output did not match the expected value."
1792         assert result.exit_code == expected_exit_code, "The exit code is incorrect."
1793
1794     def test_code_option(self) -> None:
1795         """Test the code option with no changes."""
1796         code = 'print("Hello world")\n'
1797         args = ["--code", code]
1798         result = CliRunner().invoke(black.main, args)
1799
1800         self.compare_results(result, code, 0)
1801
1802     def test_code_option_changed(self) -> None:
1803         """Test the code option when changes are required."""
1804         code = "print('hello world')"
1805         formatted = black.format_str(code, mode=DEFAULT_MODE)
1806
1807         args = ["--code", code]
1808         result = CliRunner().invoke(black.main, args)
1809
1810         self.compare_results(result, formatted, 0)
1811
1812     def test_code_option_check(self) -> None:
1813         """Test the code option when check is passed."""
1814         args = ["--check", "--code", 'print("Hello world")\n']
1815         result = CliRunner().invoke(black.main, args)
1816         self.compare_results(result, "", 0)
1817
1818     def test_code_option_check_changed(self) -> None:
1819         """Test the code option when changes are required, and check is passed."""
1820         args = ["--check", "--code", "print('hello world')"]
1821         result = CliRunner().invoke(black.main, args)
1822         self.compare_results(result, "", 1)
1823
1824     def test_code_option_diff(self) -> None:
1825         """Test the code option when diff is passed."""
1826         code = "print('hello world')"
1827         formatted = black.format_str(code, mode=DEFAULT_MODE)
1828         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1829
1830         args = ["--diff", "--code", code]
1831         result = CliRunner().invoke(black.main, args)
1832
1833         # Remove time from diff
1834         output = DIFF_TIME.sub("", result.output)
1835
1836         assert output == result_diff, "The output did not match the expected value."
1837         assert result.exit_code == 0, "The exit code is incorrect."
1838
1839     def test_code_option_color_diff(self) -> None:
1840         """Test the code option when color and diff are passed."""
1841         code = "print('hello world')"
1842         formatted = black.format_str(code, mode=DEFAULT_MODE)
1843
1844         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1845         result_diff = color_diff(result_diff)
1846
1847         args = ["--diff", "--color", "--code", code]
1848         result = CliRunner().invoke(black.main, args)
1849
1850         # Remove time from diff
1851         output = DIFF_TIME.sub("", result.output)
1852
1853         assert output == result_diff, "The output did not match the expected value."
1854         assert result.exit_code == 0, "The exit code is incorrect."
1855
1856     @pytest.mark.incompatible_with_mypyc
1857     def test_code_option_safe(self) -> None:
1858         """Test that the code option throws an error when the sanity checks fail."""
1859         # Patch black.assert_equivalent to ensure the sanity checks fail
1860         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1861             code = 'print("Hello world")'
1862             error_msg = f"{code}\nerror: cannot format <string>: \n"
1863
1864             args = ["--safe", "--code", code]
1865             result = CliRunner().invoke(black.main, args)
1866
1867             self.compare_results(result, error_msg, 123)
1868
1869     def test_code_option_fast(self) -> None:
1870         """Test that the code option ignores errors when the sanity checks fail."""
1871         # Patch black.assert_equivalent to ensure the sanity checks fail
1872         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1873             code = 'print("Hello world")'
1874             formatted = black.format_str(code, mode=DEFAULT_MODE)
1875
1876             args = ["--fast", "--code", code]
1877             result = CliRunner().invoke(black.main, args)
1878
1879             self.compare_results(result, formatted, 0)
1880
1881     @pytest.mark.incompatible_with_mypyc
1882     def test_code_option_config(self) -> None:
1883         """
1884         Test that the code option finds the pyproject.toml in the current directory.
1885         """
1886         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1887             args = ["--code", "print"]
1888             # This is the only directory known to contain a pyproject.toml
1889             with change_directory(PROJECT_ROOT):
1890                 CliRunner().invoke(black.main, args)
1891                 pyproject_path = Path(Path.cwd(), "pyproject.toml").resolve()
1892
1893             assert (
1894                 len(parse.mock_calls) >= 1
1895             ), "Expected config parse to be called with the current directory."
1896
1897             _, call_args, _ = parse.mock_calls[0]
1898             assert (
1899                 call_args[0].lower() == str(pyproject_path).lower()
1900             ), "Incorrect config loaded."
1901
1902     @pytest.mark.incompatible_with_mypyc
1903     def test_code_option_parent_config(self) -> None:
1904         """
1905         Test that the code option finds the pyproject.toml in the parent directory.
1906         """
1907         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1908             with change_directory(THIS_DIR):
1909                 args = ["--code", "print"]
1910                 CliRunner().invoke(black.main, args)
1911
1912                 pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
1913                 assert (
1914                     len(parse.mock_calls) >= 1
1915                 ), "Expected config parse to be called with the current directory."
1916
1917                 _, call_args, _ = parse.mock_calls[0]
1918                 assert (
1919                     call_args[0].lower() == str(pyproject_path).lower()
1920                 ), "Incorrect config loaded."
1921
1922     def test_for_handled_unexpected_eof_error(self) -> None:
1923         """
1924         Test that an unexpected EOF SyntaxError is nicely presented.
1925         """
1926         with pytest.raises(black.parsing.InvalidInput) as exc_info:
1927             black.lib2to3_parse("print(", {})
1928
1929         exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")
1930
1931     def test_equivalency_ast_parse_failure_includes_error(self) -> None:
1932         with pytest.raises(AssertionError) as err:
1933             black.assert_equivalent("a«»a  = 1", "a«»a  = 1")
1934
1935         err.match("--safe")
1936         # Unfortunately the SyntaxError message has changed in newer versions so we
1937         # can't match it directly.
1938         err.match("invalid character")
1939         err.match(r"\(<unknown>, line 1\)")
1940
1941
1942 class TestCaching:
1943     def test_get_cache_dir(
1944         self,
1945         tmp_path: Path,
1946         monkeypatch: pytest.MonkeyPatch,
1947     ) -> None:
1948         # Create multiple cache directories
1949         workspace1 = tmp_path / "ws1"
1950         workspace1.mkdir()
1951         workspace2 = tmp_path / "ws2"
1952         workspace2.mkdir()
1953
1954         # Force user_cache_dir to use the temporary directory for easier assertions
1955         patch_user_cache_dir = patch(
1956             target="black.cache.user_cache_dir",
1957             autospec=True,
1958             return_value=str(workspace1),
1959         )
1960
1961         # If BLACK_CACHE_DIR is not set, use user_cache_dir
1962         monkeypatch.delenv("BLACK_CACHE_DIR", raising=False)
1963         with patch_user_cache_dir:
1964             assert get_cache_dir() == workspace1
1965
1966         # If it is set, use the path provided in the env var.
1967         monkeypatch.setenv("BLACK_CACHE_DIR", str(workspace2))
1968         assert get_cache_dir() == workspace2
1969
1970     def test_cache_broken_file(self) -> None:
1971         mode = DEFAULT_MODE
1972         with cache_dir() as workspace:
1973             cache_file = get_cache_file(mode)
1974             cache_file.write_text("this is not a pickle", encoding="utf-8")
1975             assert black.Cache.read(mode).file_data == {}
1976             src = (workspace / "test.py").resolve()
1977             src.write_text("print('hello')", encoding="utf-8")
1978             invokeBlack([str(src)])
1979             cache = black.Cache.read(mode)
1980             assert not cache.is_changed(src)
1981
1982     def test_cache_single_file_already_cached(self) -> None:
1983         mode = DEFAULT_MODE
1984         with cache_dir() as workspace:
1985             src = (workspace / "test.py").resolve()
1986             src.write_text("print('hello')", encoding="utf-8")
1987             cache = black.Cache.read(mode)
1988             cache.write([src])
1989             invokeBlack([str(src)])
1990             assert src.read_text(encoding="utf-8") == "print('hello')"
1991
1992     @event_loop()
1993     def test_cache_multiple_files(self) -> None:
1994         mode = DEFAULT_MODE
1995         with cache_dir() as workspace, patch(
1996             "concurrent.futures.ProcessPoolExecutor", new=ThreadPoolExecutor
1997         ):
1998             one = (workspace / "one.py").resolve()
1999             one.write_text("print('hello')", encoding="utf-8")
2000             two = (workspace / "two.py").resolve()
2001             two.write_text("print('hello')", encoding="utf-8")
2002             cache = black.Cache.read(mode)
2003             cache.write([one])
2004             invokeBlack([str(workspace)])
2005             assert one.read_text(encoding="utf-8") == "print('hello')"
2006             assert two.read_text(encoding="utf-8") == 'print("hello")\n'
2007             cache = black.Cache.read(mode)
2008             assert not cache.is_changed(one)
2009             assert not cache.is_changed(two)
2010
2011     @pytest.mark.incompatible_with_mypyc
2012     @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
2013     def test_no_cache_when_writeback_diff(self, color: bool) -> None:
2014         mode = DEFAULT_MODE
2015         with cache_dir() as workspace:
2016             src = (workspace / "test.py").resolve()
2017             src.write_text("print('hello')", encoding="utf-8")
2018             with patch.object(black.Cache, "read") as read_cache, patch.object(
2019                 black.Cache, "write"
2020             ) as write_cache:
2021                 cmd = [str(src), "--diff"]
2022                 if color:
2023                     cmd.append("--color")
2024                 invokeBlack(cmd)
2025                 cache_file = get_cache_file(mode)
2026                 assert cache_file.exists() is False
2027                 read_cache.assert_called_once()
2028                 write_cache.assert_not_called()
2029
2030     @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
2031     @event_loop()
2032     def test_output_locking_when_writeback_diff(self, color: bool) -> None:
2033         with cache_dir() as workspace:
2034             for tag in range(0, 4):
2035                 src = (workspace / f"test{tag}.py").resolve()
2036                 src.write_text("print('hello')", encoding="utf-8")
2037             with patch(
2038                 "black.concurrency.Manager", wraps=multiprocessing.Manager
2039             ) as mgr:
2040                 cmd = ["--diff", str(workspace)]
2041                 if color:
2042                     cmd.append("--color")
2043                 invokeBlack(cmd, exit_code=0)
2044                 # this isn't quite doing what we want, but if it _isn't_
2045                 # called then we cannot be using the lock it provides
2046                 mgr.assert_called()
2047
2048     def test_no_cache_when_stdin(self) -> None:
2049         mode = DEFAULT_MODE
2050         with cache_dir():
2051             result = CliRunner().invoke(
2052                 black.main, ["-"], input=BytesIO(b"print('hello')")
2053             )
2054             assert not result.exit_code
2055             cache_file = get_cache_file(mode)
2056             assert not cache_file.exists()
2057
2058     def test_read_cache_no_cachefile(self) -> None:
2059         mode = DEFAULT_MODE
2060         with cache_dir():
2061             assert black.Cache.read(mode).file_data == {}
2062
2063     def test_write_cache_read_cache(self) -> None:
2064         mode = DEFAULT_MODE
2065         with cache_dir() as workspace:
2066             src = (workspace / "test.py").resolve()
2067             src.touch()
2068             write_cache = black.Cache.read(mode)
2069             write_cache.write([src])
2070             read_cache = black.Cache.read(mode)
2071             assert not read_cache.is_changed(src)
2072
2073     @pytest.mark.incompatible_with_mypyc
2074     def test_filter_cached(self) -> None:
2075         with TemporaryDirectory() as workspace:
2076             path = Path(workspace)
2077             uncached = (path / "uncached").resolve()
2078             cached = (path / "cached").resolve()
2079             cached_but_changed = (path / "changed").resolve()
2080             uncached.touch()
2081             cached.touch()
2082             cached_but_changed.touch()
2083             cache = black.Cache.read(DEFAULT_MODE)
2084
2085             orig_func = black.Cache.get_file_data
2086
2087             def wrapped_func(path: Path) -> FileData:
2088                 if path == cached:
2089                     return orig_func(path)
2090                 if path == cached_but_changed:
2091                     return FileData(0.0, 0, "")
2092                 raise AssertionError
2093
2094             with patch.object(black.Cache, "get_file_data", side_effect=wrapped_func):
2095                 cache.write([cached, cached_but_changed])
2096             todo, done = cache.filtered_cached({uncached, cached, cached_but_changed})
2097             assert todo == {uncached, cached_but_changed}
2098             assert done == {cached}
2099
2100     def test_filter_cached_hash(self) -> None:
2101         with TemporaryDirectory() as workspace:
2102             path = Path(workspace)
2103             src = (path / "test.py").resolve()
2104             src.write_text("print('hello')", encoding="utf-8")
2105             st = src.stat()
2106             cache = black.Cache.read(DEFAULT_MODE)
2107             cache.write([src])
2108             cached_file_data = cache.file_data[str(src)]
2109
2110             todo, done = cache.filtered_cached([src])
2111             assert todo == set()
2112             assert done == {src}
2113             assert cached_file_data.st_mtime == st.st_mtime
2114
2115             # Modify st_mtime
2116             cached_file_data = cache.file_data[str(src)] = FileData(
2117                 cached_file_data.st_mtime - 1,
2118                 cached_file_data.st_size,
2119                 cached_file_data.hash,
2120             )
2121             todo, done = cache.filtered_cached([src])
2122             assert todo == set()
2123             assert done == {src}
2124             assert cached_file_data.st_mtime < st.st_mtime
2125             assert cached_file_data.st_size == st.st_size
2126             assert cached_file_data.hash == black.Cache.hash_digest(src)
2127
2128             # Modify contents
2129             src.write_text("print('hello world')", encoding="utf-8")
2130             new_st = src.stat()
2131             todo, done = cache.filtered_cached([src])
2132             assert todo == {src}
2133             assert done == set()
2134             assert cached_file_data.st_mtime < new_st.st_mtime
2135             assert cached_file_data.st_size != new_st.st_size
2136             assert cached_file_data.hash != black.Cache.hash_digest(src)
2137
2138     def test_write_cache_creates_directory_if_needed(self) -> None:
2139         mode = DEFAULT_MODE
2140         with cache_dir(exists=False) as workspace:
2141             assert not workspace.exists()
2142             cache = black.Cache.read(mode)
2143             cache.write([])
2144             assert workspace.exists()
2145
2146     @event_loop()
2147     def test_failed_formatting_does_not_get_cached(self) -> None:
2148         mode = DEFAULT_MODE
2149         with cache_dir() as workspace, patch(
2150             "concurrent.futures.ProcessPoolExecutor", new=ThreadPoolExecutor
2151         ):
2152             failing = (workspace / "failing.py").resolve()
2153             failing.write_text("not actually python", encoding="utf-8")
2154             clean = (workspace / "clean.py").resolve()
2155             clean.write_text('print("hello")\n', encoding="utf-8")
2156             invokeBlack([str(workspace)], exit_code=123)
2157             cache = black.Cache.read(mode)
2158             assert cache.is_changed(failing)
2159             assert not cache.is_changed(clean)
2160
2161     def test_write_cache_write_fail(self) -> None:
2162         mode = DEFAULT_MODE
2163         with cache_dir():
2164             cache = black.Cache.read(mode)
2165             with patch.object(Path, "open") as mock:
2166                 mock.side_effect = OSError
2167                 cache.write([])
2168
2169     def test_read_cache_line_lengths(self) -> None:
2170         mode = DEFAULT_MODE
2171         short_mode = replace(DEFAULT_MODE, line_length=1)
2172         with cache_dir() as workspace:
2173             path = (workspace / "file.py").resolve()
2174             path.touch()
2175             cache = black.Cache.read(mode)
2176             cache.write([path])
2177             one = black.Cache.read(mode)
2178             assert not one.is_changed(path)
2179             two = black.Cache.read(short_mode)
2180             assert two.is_changed(path)
2181
2182
def assert_collected_sources(
    src: Sequence[Union[str, Path]],
    expected: Sequence[Union[str, Path]],
    *,
    root: Optional[Path] = None,
    exclude: Optional[str] = None,
    include: Optional[str] = None,
    extend_exclude: Optional[str] = None,
    force_exclude: Optional[str] = None,
    stdin_filename: Optional[str] = None,
) -> None:
    """Assert that `black.get_sources` collects exactly ``expected`` from ``src``.

    Pattern arguments are regex source strings. ``None`` means "not given":
    ``include`` then falls back to ``DEFAULT_INCLUDE`` and the other patterns
    are passed as ``None``. ``root`` defaults to ``THIS_DIR``. The comparison
    ignores ordering.
    """

    def _compile_optional(pattern: Optional[str]) -> "Optional[re.Pattern[str]]":
        return compile_pattern(pattern) if pattern is not None else None

    sources = tuple(str(Path(item)) for item in src)
    wanted = [Path(item) for item in expected]
    exclude_re = _compile_optional(exclude)
    include_re = compile_pattern(include) if include is not None else DEFAULT_INCLUDE
    extend_exclude_re = _compile_optional(extend_exclude)
    force_exclude_re = _compile_optional(force_exclude)

    collected = black.get_sources(
        root=root or THIS_DIR,
        src=sources,
        quiet=False,
        verbose=False,
        include=include_re,
        exclude=exclude_re,
        extend_exclude=extend_exclude_re,
        force_exclude=force_exclude_re,
        report=black.Report(),
        stdin_filename=stdin_filename,
    )
    assert sorted(collected) == sorted(wanted)
2215
2216
class TestFileCollection:
    """Tests for how Black discovers the files it should format."""

    @staticmethod
    def _gen_python_files(
        path: Any,
        root: Path,
        include: "re.Pattern[str]",
        exclude: "re.Pattern[str]",
        gitignore: PathSpec,
    ) -> List[Path]:
        """Collect `black.gen_python_files` over ``path.iterdir()`` into a list.

        No extend/force-exclude patterns are applied, ``gitignore`` is attached
        to ``path`` itself, and output is neither verbose nor quiet.
        """
        return list(
            black.gen_python_files(
                path.iterdir(),
                root,
                include,
                exclude,
                None,
                None,
                black.Report(),
                {path: gitignore},
                verbose=False,
                quiet=False,
            )
        )

    @staticmethod
    def _assert_gitignore_parse_error(path: Path, gitignore: Path) -> None:
        """Run Black verbosely on ``path`` and expect it to reject ``gitignore``."""
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_include_exclude(self) -> None:
        path = THIS_DIR / "data" / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        assert_collected_sources(
            src,
            expected,
            include=r"\.pyi?$",
            exclude=r"/exclude/|/\.definitely_exclude/",
        )

    def test_gitignore_used_as_default(self) -> None:
        base = DATA_DIR / "include_exclude_tests"
        expected = [
            base / "b/.definitely_exclude/a.py",
            base / "b/.definitely_exclude/a.pyi",
        ]
        src = [base / "b/"]
        assert_collected_sources(src, expected, root=base, extend_exclude=r"/exclude/")

    def test_gitignore_used_on_multiple_sources(self) -> None:
        root = DATA_DIR / "gitignore_used_on_multiple_sources"
        expected = [
            root / "dir1" / "b.py",
            root / "dir2" / "b.py",
        ]
        src = [root / "dir1", root / "dir2"]
        assert_collected_sources(src, expected, root=root)

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_exclude_for_issue_1572(self) -> None:
        # Exclude shouldn't touch files that were explicitly given to Black through the
        # CLI. Exclude is supposed to only apply to the recursive discovery of files.
        # https://github.com/psf/black/issues/1572
        path = DATA_DIR / "include_exclude_tests"
        src = [path / "b/exclude/a.py"]
        expected = [path / "b/exclude/a.py"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    def test_gitignore_exclude(self) -> None:
        path = THIS_DIR / "data" / "include_exclude_tests"
        gitignore = PathSpec.from_lines(
            "gitwildmatch", ["exclude/", ".definitely_exclude"]
        )
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        sources = self._gen_python_files(
            path,
            THIS_DIR.resolve(),
            re.compile(r"\.pyi?$"),
            re.compile(r""),
            gitignore,
        )
        assert sorted(expected) == sorted(sources)

    def test_nested_gitignore(self) -> None:
        path = THIS_DIR / "data" / "nested_gitignore_tests"
        root_gitignore = black.files.get_gitignore(path)
        expected = [
            path / "x.py",
            path / "root/b.py",
            path / "root/c.py",
            path / "root/child/c.py",
        ]
        sources = self._gen_python_files(
            path,
            THIS_DIR.resolve(),
            re.compile(r"\.pyi?$"),
            re.compile(r""),
            root_gitignore,
        )
        assert sorted(expected) == sorted(sources)

    def test_nested_gitignore_directly_in_source_directory(self) -> None:
        # https://github.com/psf/black/issues/2598
        src = DATA_DIR / "nested_gitignore_tests" / "root" / "child"
        expected = [src / "a.py", src / "c.py"]
        assert_collected_sources([src], expected)

    def test_invalid_gitignore(self) -> None:
        path = THIS_DIR / "data" / "invalid_gitignore_tests"
        self._assert_gitignore_parse_error(path, path / ".gitignore")

    def test_invalid_nested_gitignore(self) -> None:
        path = THIS_DIR / "data" / "invalid_nested_gitignore_tests"
        self._assert_gitignore_parse_error(path, path / "a" / ".gitignore")

    def test_gitignore_that_ignores_subfolders(self) -> None:
        # If gitignore with */* is in root
        root = DATA_DIR / "ignore_subfolders_gitignore_tests" / "subdir"
        expected = [root / "b.py"]
        assert_collected_sources([root], expected, root=root)

        # If .gitignore with */* is nested
        root = DATA_DIR / "ignore_subfolders_gitignore_tests"
        expected = [
            root / "a.py",
            root / "subdir" / "b.py",
        ]
        assert_collected_sources([root], expected, root=root)

        # If command is executed from outer dir
        root = DATA_DIR / "ignore_subfolders_gitignore_tests"
        target = root / "subdir"
        expected = [target / "b.py"]
        assert_collected_sources([target], expected, root=root)

    def test_empty_include(self) -> None:
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.pie",
            path / "b/exclude/a.py",
            path / "b/exclude/a.pyi",
            path / "b/dont_exclude/a.pie",
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
            path / "b/.definitely_exclude/a.pie",
            path / "b/.definitely_exclude/a.py",
            path / "b/.definitely_exclude/a.pyi",
            path / ".gitignore",
            path / "pyproject.toml",
        ]
        # Setting exclude explicitly to an empty string to block .gitignore usage.
        assert_collected_sources(src, expected, include="", exclude="")

    def test_extend_exclude(self) -> None:
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.py",
            path / "b/dont_exclude/a.py",
        ]
        assert_collected_sources(
            src, expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude"
        )

    @pytest.mark.incompatible_with_mypyc
    def test_symlinks(self) -> None:
        path = MagicMock()
        root = THIS_DIR.resolve()
        include = re.compile(black.DEFAULT_INCLUDES)
        exclude = re.compile(black.DEFAULT_EXCLUDES)
        gitignore = PathSpec.from_lines("gitwildmatch", [])

        regular = MagicMock()
        outside_root_symlink = MagicMock()
        ignored_symlink = MagicMock()

        path.iterdir.return_value = [regular, outside_root_symlink, ignored_symlink]

        regular.absolute.return_value = root / "regular.py"
        regular.resolve.return_value = root / "regular.py"
        regular.is_dir.return_value = False

        outside_root_symlink.absolute.return_value = root / "symlink.py"
        outside_root_symlink.resolve.return_value = Path("/nowhere")

        ignored_symlink.absolute.return_value = root / ".mypy_cache" / "symlink.py"

        files = self._gen_python_files(path, root, include, exclude, gitignore)
        assert files == [regular]

        path.iterdir.assert_called_once()
        outside_root_symlink.resolve.assert_called_once()
        ignored_symlink.resolve.assert_not_called()

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin(self) -> None:
        src = ["-"]
        expected = ["-"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename(self) -> None:
        src = ["-"]
        stdin_filename = str(THIS_DIR / "data/collections.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
        # Exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        path = DATA_DIR / "include_exclude_tests"
        src = ["-"]
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
        # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        src = ["-"]
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            extend_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
        # Force exclude should exclude the file when passing it through
        # stdin_filename
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        assert_collected_sources(
            src=["-"],
            expected=[],
            force_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )
2506
2507
class TestDeFactoAPI:
    """Guard symbols that outside code relies on even though they are not public.

    Black does not (yet) formally expose an API (see issue #779), but some
    functions are commonly used by external code, and those should keep working.

    """

    def test_format_str(self) -> None:
        source = "print('hello')"
        formatted = 'print("hello")\n'

        # `format_str` together with `Mode` must stay usable.
        assert black.format_str(source, mode=black.Mode()) == formatted

        # A custom line length can be passed through Mode.
        assert black.format_str(source, mode=black.Mode(line_length=42)) == formatted

        # Unparseable input is reported as InvalidInput.
        with pytest.raises(black.InvalidInput):
            black.format_str("syntax error", mode=black.Mode())

    def test_format_file_contents(self) -> None:
        # format_str() is the better entry point, but format_file_contents() is
        # still used in the wild, so it has to keep working as well.
        unformatted, formatted = "x=1", "x = 1\n"
        result = black.format_file_contents(unformatted, fast=True, mode=black.Mode())
        assert result == formatted

        # Already-formatted input is signalled with NothingChanged.
        with pytest.raises(black.NothingChanged):
            black.format_file_contents(formatted, fast=True, mode=black.Mode())
2541
2542
# Source lines of black/__init__.py, used by `tracefunc` to show the signature of
# each traced function. A UnicodeDecodeError is only tolerated on compiled (mypyc)
# builds, where `black.__file__` presumably is not a readable text source; in that
# case fall back to an empty list instead of leaving the name unbound.
try:
    with open(black.__file__, encoding="utf-8") as _bf:
        black_source_lines = _bf.readlines()
except UnicodeDecodeError:
    if not black.COMPILED:
        raise
    black_source_lines = []
2549
2550
def tracefunc(
    frame: types.FrameType, event: str, arg: Any
) -> Callable[[types.FrameType, str, Any], Any]:
    """Show function calls from `black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    For each ``call`` event in ``black/__init__.py`` this prints an indentation
    based on the current stack depth, the frame's line number, and the stripped
    source text of the function's ``def`` line (leading decorator lines are
    skipped). Events from other files, and non-``call`` events, are ignored.

    Always returns itself so that tracing continues.
    """
    if event != "call":
        return tracefunc

    # Filter by file *before* indexing `black_source_lines`. Line numbers from
    # other modules can exceed its length, and an exception raised inside a trace
    # function makes Python unset the tracer. Backslashes are normalized so the
    # check also matches Windows paths.
    filename = frame.f_code.co_filename.replace("\\", "/")
    if "black/__init__.py" not in filename:
        return tracefunc

    # Rough call depth, two spaces per level. The 19 frames subtracted are
    # presumably the surrounding test-harness frames (TODO confirm).
    stack = len(inspect.stack()) - 19
    stack *= 2
    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1  # f_lineno is 1-based; the list is 0-based.
    funcname = ""
    # Skip decorator lines, but never run past the end of the loaded source.
    while 0 <= func_sig_lineno < len(black_source_lines):
        funcname = black_source_lines[func_sig_lineno].strip()
        if not funcname.startswith("@"):
            break
        func_sig_lineno += 1
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc