git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Bump docker/login-action from 2 to 3 (#3891)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import inspect
5 import io
6 import logging
7 import multiprocessing
8 import os
9 import re
10 import sys
11 import types
12 import unittest
13 from concurrent.futures import ThreadPoolExecutor
14 from contextlib import contextmanager, redirect_stderr
15 from dataclasses import replace
16 from io import BytesIO
17 from pathlib import Path
18 from platform import system
19 from tempfile import TemporaryDirectory
20 from typing import (
21     Any,
22     Callable,
23     Dict,
24     Iterator,
25     List,
26     Optional,
27     Sequence,
28     Type,
29     TypeVar,
30     Union,
31 )
32 from unittest.mock import MagicMock, patch
33
34 import click
35 import pytest
36 from click import unstyle
37 from click.testing import CliRunner
38 from pathspec import PathSpec
39
40 import black
41 import black.files
42 from black import Feature, TargetVersion
43 from black import re_compile_maybe_verbose as compile_pattern
44 from black.cache import FileData, get_cache_dir, get_cache_file
45 from black.debug import DebugVisitor
46 from black.output import color_diff, diff
47 from black.report import Report
48
49 # Import other test classes
50 from tests.util import (
51     DATA_DIR,
52     DEFAULT_MODE,
53     DETERMINISTIC_HEADER,
54     PROJECT_ROOT,
55     PY36_VERSIONS,
56     THIS_DIR,
57     BlackBaseTestCase,
58     assert_format,
59     change_directory,
60     dump_to_stderr,
61     ff,
62     fs,
63     get_case_path,
64     read_data,
65     read_data_from_file,
66 )
67
THIS_FILE = Path(__file__)
# Config file passed via --config in CLI tests; judging by its name it is an
# empty pyproject.toml — presumably so no other configuration leaks in.
EMPTY_CONFIG = THIS_DIR / "data" / "empty_pyproject.toml"
# One ``--target-version=<name>`` CLI flag per version in PY36_VERSIONS.
PY36_ARGS = [f"--target-version={version.name.lower()}" for version in PY36_VERSIONS]
# Black's default exclude/include patterns, compiled with
# ``black.re_compile_maybe_verbose``.
DEFAULT_EXCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_EXCLUDES)
DEFAULT_INCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_INCLUDES)
# Generic type variables available for annotating test helpers.
T = TypeVar("T")
R = TypeVar("R")

# Match the time output in a diff, but nothing else
DIFF_TIME = re.compile(r"\t[\d\-:+\. ]+")
78
79
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Redirect ``black.cache.CACHE_DIR`` to a throwaway directory.

    Yields the directory that was patched in. With ``exists=False`` the
    yielded path is a ``new`` child of the temporary directory that has not
    been created yet. The temporary directory and the patch are both undone
    when the ``with`` block exits.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
88
89
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a fresh asyncio event loop for the duration of the ``with`` block.

    A new loop is created from the current policy and made the thread's
    current event loop. On exit the loop is closed and the current event loop
    is reset to ``None``, so later code cannot be handed the closed loop by
    ``asyncio.get_event_loop()``.
    """
    policy = asyncio.get_event_loop_policy()
    loop = policy.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        yield
    finally:
        try:
            loop.close()
        finally:
            # Don't leave a closed loop installed as the current event loop.
            asyncio.set_event_loop(None)
100
101
class FakeContext(click.Context):
    """Minimal stand-in for a click Context, for functions that require one."""

    def __init__(self) -> None:
        # click.Context.__init__ is deliberately skipped; only these three
        # attributes are populated.
        # "root" is a placeholder; most tests never look at it.
        self.obj: Dict[str, Any] = {"root": PROJECT_ROOT}
        self.params: Dict[str, Any] = {}
        self.default_map: Dict[str, Any] = {}
110
111
class FakeParameter(click.Parameter):
    """Minimal stand-in for a click Parameter, for functions that require one."""

    def __init__(self) -> None:
        """Skip click.Parameter's initialiser entirely; no state is set up."""
117
118
class BlackRunner(CliRunner):
    """Make sure STDOUT and STDERR are kept separate when testing Black via its CLI."""

    def __init__(self) -> None:
        # Older click mixes stderr into stdout unless ``mix_stderr=False`` is
        # passed. click 8.2 removed that parameter (passing it raises
        # TypeError) and always captures stderr separately. So only pass it
        # when the installed CliRunner still accepts it.
        if "mix_stderr" in inspect.signature(CliRunner.__init__).parameters:
            super().__init__(mix_stderr=False)
        else:
            super().__init__()
124
125
def invokeBlack(
    args: List[str], exit_code: int = 0, ignore_config: bool = True
) -> None:
    """Run ``black.main`` in-process and assert it exits with ``exit_code``.

    When ``ignore_config`` is true, ``--verbose`` and ``--config`` pointing at
    ``THIS_DIR / "empty.toml"`` are put in front of ``args``. Exceptions from
    Black propagate (``catch_exceptions=False``). If the exit code differs, the
    assertion message shows the arguments, the captured stdout and stderr, and
    any exception click recorded.
    """
    if ignore_config:
        args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
    result = BlackRunner().invoke(black.main, args, catch_exceptions=False)
    out, err = result.stdout_bytes, result.stderr_bytes
    assert out is not None
    assert err is not None
    msg = "\n".join(
        [
            f"Failed with args: {args}",
            f"stdout: {out.decode()!r}",
            f"stderr: {err.decode()!r}",
            f"exception: {result.exception}",
        ]
    )
    assert result.exit_code == exit_code, msg
142
143
144 class BlackTestCase(BlackBaseTestCase):
    # Expose the module-level CLI helper as a method so tests can call
    # ``self.invokeBlack(...)`` without it being bound to ``self``.
    invokeBlack = staticmethod(invokeBlack)
146
147     def test_empty_ff(self) -> None:
148         expected = ""
149         tmp_file = Path(black.dump_to_file())
150         try:
151             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
152             actual = tmp_file.read_text(encoding="utf-8")
153         finally:
154             os.unlink(tmp_file)
155         self.assertFormatEqual(expected, actual)
156
157     @patch("black.dump_to_file", dump_to_stderr)
158     def test_one_empty_line(self) -> None:
159         mode = black.Mode(preview=True)
160         for nl in ["\n", "\r\n"]:
161             source = expected = nl
162             assert_format(source, expected, mode=mode)
163
164     def test_one_empty_line_ff(self) -> None:
165         mode = black.Mode(preview=True)
166         for nl in ["\n", "\r\n"]:
167             expected = nl
168             tmp_file = Path(black.dump_to_file(nl))
169             if system() == "Windows":
170                 # Writing files in text mode automatically uses the system newline,
171                 # but in this case we don't want this for testing reasons. See:
172                 # https://github.com/psf/black/pull/3348
173                 with open(tmp_file, "wb") as f:
174                     f.write(nl.encode("utf-8"))
175             try:
176                 self.assertFalse(
177                     ff(tmp_file, mode=mode, write_back=black.WriteBack.YES)
178                 )
179                 with open(tmp_file, "rb") as f:
180                     actual = f.read().decode("utf-8")
181             finally:
182                 os.unlink(tmp_file)
183             self.assertFormatEqual(expected, actual)
184
185     def test_experimental_string_processing_warns(self) -> None:
186         self.assertWarns(
187             black.mode.Deprecated, black.Mode, experimental_string_processing=True
188         )
189
190     def test_piping(self) -> None:
191         source, expected = read_data_from_file(PROJECT_ROOT / "src/black/__init__.py")
192         result = BlackRunner().invoke(
193             black.main,
194             [
195                 "-",
196                 "--fast",
197                 f"--line-length={black.DEFAULT_LINE_LENGTH}",
198                 f"--config={EMPTY_CONFIG}",
199             ],
200             input=BytesIO(source.encode("utf-8")),
201         )
202         self.assertEqual(result.exit_code, 0)
203         self.assertFormatEqual(expected, result.output)
204         if source != result.output:
205             black.assert_equivalent(source, result.output)
206             black.assert_stable(source, result.output, DEFAULT_MODE)
207
208     def test_piping_diff(self) -> None:
209         diff_header = re.compile(
210             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d"
211             r"\+\d\d:\d\d"
212         )
213         source, _ = read_data("simple_cases", "expression.py")
214         expected, _ = read_data("simple_cases", "expression.diff")
215         args = [
216             "-",
217             "--fast",
218             f"--line-length={black.DEFAULT_LINE_LENGTH}",
219             "--diff",
220             f"--config={EMPTY_CONFIG}",
221         ]
222         result = BlackRunner().invoke(
223             black.main, args, input=BytesIO(source.encode("utf-8"))
224         )
225         self.assertEqual(result.exit_code, 0)
226         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
227         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
228         self.assertEqual(expected, actual)
229
230     def test_piping_diff_with_color(self) -> None:
231         source, _ = read_data("simple_cases", "expression.py")
232         args = [
233             "-",
234             "--fast",
235             f"--line-length={black.DEFAULT_LINE_LENGTH}",
236             "--diff",
237             "--color",
238             f"--config={EMPTY_CONFIG}",
239         ]
240         result = BlackRunner().invoke(
241             black.main, args, input=BytesIO(source.encode("utf-8"))
242         )
243         actual = result.output
244         # Again, the contents are checked in a different test, so only look for colors.
245         self.assertIn("\033[1m", actual)
246         self.assertIn("\033[36m", actual)
247         self.assertIn("\033[32m", actual)
248         self.assertIn("\033[31m", actual)
249         self.assertIn("\033[0m", actual)
250
251     @patch("black.dump_to_file", dump_to_stderr)
252     def _test_wip(self) -> None:
253         source, expected = read_data("miscellaneous", "wip")
254         sys.settrace(tracefunc)
255         mode = replace(
256             DEFAULT_MODE,
257             experimental_string_processing=False,
258             target_versions={black.TargetVersion.PY38},
259         )
260         actual = fs(source, mode=mode)
261         sys.settrace(None)
262         self.assertFormatEqual(expected, actual)
263         black.assert_equivalent(source, actual)
264         black.assert_stable(source, actual, black.FileMode())
265
266     def test_pep_572_version_detection(self) -> None:
267         source, _ = read_data("py_38", "pep_572")
268         root = black.lib2to3_parse(source)
269         features = black.get_features_used(root)
270         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
271         versions = black.detect_target_versions(root)
272         self.assertIn(black.TargetVersion.PY38, versions)
273
274     def test_pep_695_version_detection(self) -> None:
275         for file in ("type_aliases", "type_params"):
276             source, _ = read_data("py_312", file)
277             root = black.lib2to3_parse(source)
278             features = black.get_features_used(root)
279             self.assertIn(black.Feature.TYPE_PARAMS, features)
280             versions = black.detect_target_versions(root)
281             self.assertIn(black.TargetVersion.PY312, versions)
282
283     def test_expression_ff(self) -> None:
284         source, expected = read_data("simple_cases", "expression.py")
285         tmp_file = Path(black.dump_to_file(source))
286         try:
287             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
288             actual = tmp_file.read_text(encoding="utf-8")
289         finally:
290             os.unlink(tmp_file)
291         self.assertFormatEqual(expected, actual)
292         with patch("black.dump_to_file", dump_to_stderr):
293             black.assert_equivalent(source, actual)
294             black.assert_stable(source, actual, DEFAULT_MODE)
295
296     def test_expression_diff(self) -> None:
297         source, _ = read_data("simple_cases", "expression.py")
298         expected, _ = read_data("simple_cases", "expression.diff")
299         tmp_file = Path(black.dump_to_file(source))
300         diff_header = re.compile(
301             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
302             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d\+\d\d:\d\d"
303         )
304         try:
305             result = BlackRunner().invoke(
306                 black.main, ["--diff", str(tmp_file), f"--config={EMPTY_CONFIG}"]
307             )
308             self.assertEqual(result.exit_code, 0)
309         finally:
310             os.unlink(tmp_file)
311         actual = result.output
312         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
313         if expected != actual:
314             dump = black.dump_to_file(actual)
315             msg = (
316                 "Expected diff isn't equal to the actual. If you made changes to"
317                 " expression.py and this is an anticipated difference, overwrite"
318                 f" tests/data/expression.diff with {dump}"
319             )
320             self.assertEqual(expected, actual, msg)
321
322     def test_expression_diff_with_color(self) -> None:
323         source, _ = read_data("simple_cases", "expression.py")
324         expected, _ = read_data("simple_cases", "expression.diff")
325         tmp_file = Path(black.dump_to_file(source))
326         try:
327             result = BlackRunner().invoke(
328                 black.main,
329                 ["--diff", "--color", str(tmp_file), f"--config={EMPTY_CONFIG}"],
330             )
331         finally:
332             os.unlink(tmp_file)
333         actual = result.output
334         # We check the contents of the diff in `test_expression_diff`. All
335         # we need to check here is that color codes exist in the result.
336         self.assertIn("\033[1m", actual)
337         self.assertIn("\033[36m", actual)
338         self.assertIn("\033[32m", actual)
339         self.assertIn("\033[31m", actual)
340         self.assertIn("\033[0m", actual)
341
342     def test_detect_pos_only_arguments(self) -> None:
343         source, _ = read_data("py_38", "pep_570")
344         root = black.lib2to3_parse(source)
345         features = black.get_features_used(root)
346         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
347         versions = black.detect_target_versions(root)
348         self.assertIn(black.TargetVersion.PY38, versions)
349
350     def test_detect_debug_f_strings(self) -> None:
351         root = black.lib2to3_parse("""f"{x=}" """)
352         features = black.get_features_used(root)
353         self.assertIn(black.Feature.DEBUG_F_STRINGS, features)
354         versions = black.detect_target_versions(root)
355         self.assertIn(black.TargetVersion.PY38, versions)
356
357         root = black.lib2to3_parse(
358             """f"{x}"\nf'{"="}'\nf'{(x:=5)}'\nf'{f(a="3=")}'\nf'{x:=10}'\n"""
359         )
360         features = black.get_features_used(root)
361         self.assertNotIn(black.Feature.DEBUG_F_STRINGS, features)
362
363         # We don't yet support feature version detection in nested f-strings
364         root = black.lib2to3_parse(
365             """f"heard a rumour that { f'{1+1=}' } ... seems like it could be true" """
366         )
367         features = black.get_features_used(root)
368         self.assertNotIn(black.Feature.DEBUG_F_STRINGS, features)
369
370     @patch("black.dump_to_file", dump_to_stderr)
371     def test_string_quotes(self) -> None:
372         source, expected = read_data("miscellaneous", "string_quotes")
373         mode = black.Mode(preview=True)
374         assert_format(source, expected, mode)
375         mode = replace(mode, string_normalization=False)
376         not_normalized = fs(source, mode=mode)
377         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
378         black.assert_equivalent(source, not_normalized)
379         black.assert_stable(source, not_normalized, mode=mode)
380
381     def test_skip_source_first_line(self) -> None:
382         source, _ = read_data("miscellaneous", "invalid_header")
383         tmp_file = Path(black.dump_to_file(source))
384         # Full source should fail (invalid syntax at header)
385         self.invokeBlack([str(tmp_file), "--diff", "--check"], exit_code=123)
386         # So, skipping the first line should work
387         result = BlackRunner().invoke(
388             black.main, [str(tmp_file), "-x", f"--config={EMPTY_CONFIG}"]
389         )
390         self.assertEqual(result.exit_code, 0)
391         actual = tmp_file.read_text(encoding="utf-8")
392         self.assertFormatEqual(source, actual)
393
394     def test_skip_source_first_line_when_mixing_newlines(self) -> None:
395         code_mixing_newlines = b"Header will be skipped\r\ni = [1,2,3]\nj = [1,2,3]\n"
396         expected = b"Header will be skipped\r\ni = [1, 2, 3]\nj = [1, 2, 3]\n"
397         with TemporaryDirectory() as workspace:
398             test_file = Path(workspace) / "skip_header.py"
399             test_file.write_bytes(code_mixing_newlines)
400             mode = replace(DEFAULT_MODE, skip_source_first_line=True)
401             ff(test_file, mode=mode, write_back=black.WriteBack.YES)
402             self.assertEqual(test_file.read_bytes(), expected)
403
404     def test_skip_magic_trailing_comma(self) -> None:
405         source, _ = read_data("simple_cases", "expression")
406         expected, _ = read_data(
407             "miscellaneous", "expression_skip_magic_trailing_comma.diff"
408         )
409         tmp_file = Path(black.dump_to_file(source))
410         diff_header = re.compile(
411             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
412             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d\+\d\d:\d\d"
413         )
414         try:
415             result = BlackRunner().invoke(
416                 black.main, ["-C", "--diff", str(tmp_file), f"--config={EMPTY_CONFIG}"]
417             )
418             self.assertEqual(result.exit_code, 0)
419         finally:
420             os.unlink(tmp_file)
421         actual = result.output
422         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
423         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
424         if expected != actual:
425             dump = black.dump_to_file(actual)
426             msg = (
427                 "Expected diff isn't equal to the actual. If you made changes to"
428                 " expression.py and this is an anticipated difference, overwrite"
429                 " tests/data/miscellaneous/expression_skip_magic_trailing_comma.diff"
430                 f" with {dump}"
431             )
432             self.assertEqual(expected, actual, msg)
433
434     @patch("black.dump_to_file", dump_to_stderr)
435     def test_async_as_identifier(self) -> None:
436         source_path = get_case_path("miscellaneous", "async_as_identifier")
437         source, expected = read_data_from_file(source_path)
438         actual = fs(source)
439         self.assertFormatEqual(expected, actual)
440         major, minor = sys.version_info[:2]
441         if major < 3 or (major <= 3 and minor < 7):
442             black.assert_equivalent(source, actual)
443         black.assert_stable(source, actual, DEFAULT_MODE)
444         # ensure black can parse this when the target is 3.6
445         self.invokeBlack([str(source_path), "--target-version", "py36"])
446         # but not on 3.7, because async/await is no longer an identifier
447         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
448
449     @patch("black.dump_to_file", dump_to_stderr)
450     def test_python37(self) -> None:
451         source_path = get_case_path("py_37", "python37")
452         source, expected = read_data_from_file(source_path)
453         actual = fs(source)
454         self.assertFormatEqual(expected, actual)
455         major, minor = sys.version_info[:2]
456         if major > 3 or (major == 3 and minor >= 7):
457             black.assert_equivalent(source, actual)
458         black.assert_stable(source, actual, DEFAULT_MODE)
459         # ensure black can parse this when the target is 3.7
460         self.invokeBlack([str(source_path), "--target-version", "py37"])
461         # but not on 3.6, because we use async as a reserved keyword
462         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
463
464     def test_tab_comment_indentation(self) -> None:
465         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
466         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
467         self.assertFormatEqual(contents_spc, fs(contents_spc))
468         self.assertFormatEqual(contents_spc, fs(contents_tab))
469
470         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
471         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
472         self.assertFormatEqual(contents_spc, fs(contents_spc))
473         self.assertFormatEqual(contents_spc, fs(contents_tab))
474
475         # mixed tabs and spaces (valid Python 2 code)
476         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
477         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
478         self.assertFormatEqual(contents_spc, fs(contents_spc))
479         self.assertFormatEqual(contents_spc, fs(contents_tab))
480
481         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
482         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
483         self.assertFormatEqual(contents_spc, fs(contents_spc))
484         self.assertFormatEqual(contents_spc, fs(contents_tab))
485
    def test_false_positive_symlink_output_issue_3384(self) -> None:
        """Regression test for issue 3384: no spurious "symbolic link" report.

        A source given relative to the cwd (``./child``) must not be reported
        as a symlink pointing outside the root. The only ignored path should be
        ``root/child/b.py``, excluded by a ``.gitignore``.
        """
        # Emulate the behavior when using the CLI (`black ./child  --verbose`), which
        # involves patching some `pathlib.Path` methods. In particular, `is_dir` is
        # patched only on its first call: when checking if "./child" is a directory it
        # should return True. The "./child" folder exists relative to the cwd when
        # running from CLI, but fails when running the tests because cwd is different
        project_root = Path(THIS_DIR / "data" / "nested_gitignore_tests")
        working_directory = project_root / "root"
        target_abspath = working_directory / "child"
        # Listed before any patching, so this is the real directory content that
        # the patched ``iterdir`` will return below.
        target_contents = list(target_abspath.iterdir())

        def mock_n_calls(responses: List[bool]) -> Callable[[], bool]:
            # Build a callable that returns ``responses`` in order (consuming
            # the list), then False for every later call.
            def _mocked_calls() -> bool:
                if responses:
                    return responses.pop(0)
                return False

            return _mocked_calls

        with patch("pathlib.Path.iterdir", return_value=target_contents), patch(
            "pathlib.Path.cwd", return_value=working_directory
        ), patch("pathlib.Path.is_dir", side_effect=mock_n_calls([True])):
            # Note that the root folder (project_root) isn't the folder
            # named "root" (aka working_directory)
            report = MagicMock(verbose=True)
            black.get_sources(
                root=project_root,
                src=("./child",),
                quiet=False,
                verbose=True,
                include=DEFAULT_INCLUDE,
                exclude=None,
                report=report,
                extend_exclude=None,
                force_exclude=None,
                stdin_filename=None,
            )
        # Each mock_calls entry unpacks as (name, args, kwargs); args[1] is the
        # reason message passed to ``report.path_ignored``.
        assert not any(
            mock_args[1].startswith("is a symbolic link that points outside")
            for _, mock_args, _ in report.path_ignored.mock_calls
        ), "A symbolic link was reported."
        report.path_ignored.assert_called_once_with(
            Path("root", "child", "b.py"), "matches a .gitignore file content"
        )
530
    def test_report_verbose(self) -> None:
        """A verbose Report logs every event and keeps accurate running totals.

        ``black.output._out`` / ``_err`` are replaced with collectors so each
        emitted line can be checked. After every event the test verifies the
        line counts, the last message, the summary string and the return code.
        """
        report = Report(verbose=True)
        out_lines = []
        err_lines = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, a run that reformatted something returns 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures go to stderr and force return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are logged in verbose mode but leave the totals alone.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff mode switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2"
                " files would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2"
                " files would fail to reformat.",
            )
632
    def test_report_quiet(self) -> None:
        """A quiet Report prints only errors yet keeps accurate running totals.

        ``black.output._out`` / ``_err`` are replaced with collectors. Stdout
        must stay empty throughout, while failures still reach stderr and the
        summary and return code track every event.
        """
        report = Report(quiet=True)
        out_lines = []
        err_lines = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, a run that reformatted something returns 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Errors are still written to stderr even in quiet mode.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths produce no output and do not change the totals.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff mode switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2"
                " files would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2"
                " files would fail to reformat.",
            )
726
727     def test_report_normal(self) -> None:
728         report = black.Report()
729         out_lines = []
730         err_lines = []
731
732         def out(msg: str, **kwargs: Any) -> None:
733             out_lines.append(msg)
734
735         def err(msg: str, **kwargs: Any) -> None:
736             err_lines.append(msg)
737
738         with patch("black.output._out", out), patch("black.output._err", err):
739             report.done(Path("f1"), black.Changed.NO)
740             self.assertEqual(len(out_lines), 0)
741             self.assertEqual(len(err_lines), 0)
742             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
743             self.assertEqual(report.return_code, 0)
744             report.done(Path("f2"), black.Changed.YES)
745             self.assertEqual(len(out_lines), 1)
746             self.assertEqual(len(err_lines), 0)
747             self.assertEqual(out_lines[-1], "reformatted f2")
748             self.assertEqual(
749                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
750             )
751             report.done(Path("f3"), black.Changed.CACHED)
752             self.assertEqual(len(out_lines), 1)
753             self.assertEqual(len(err_lines), 0)
754             self.assertEqual(out_lines[-1], "reformatted f2")
755             self.assertEqual(
756                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
757             )
758             self.assertEqual(report.return_code, 0)
759             report.check = True
760             self.assertEqual(report.return_code, 1)
761             report.check = False
762             report.failed(Path("e1"), "boom")
763             self.assertEqual(len(out_lines), 1)
764             self.assertEqual(len(err_lines), 1)
765             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
766             self.assertEqual(
767                 unstyle(str(report)),
768                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
769                 " reformat.",
770             )
771             self.assertEqual(report.return_code, 123)
772             report.done(Path("f3"), black.Changed.YES)
773             self.assertEqual(len(out_lines), 2)
774             self.assertEqual(len(err_lines), 1)
775             self.assertEqual(out_lines[-1], "reformatted f3")
776             self.assertEqual(
777                 unstyle(str(report)),
778                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
779                 " reformat.",
780             )
781             self.assertEqual(report.return_code, 123)
782             report.failed(Path("e2"), "boom")
783             self.assertEqual(len(out_lines), 2)
784             self.assertEqual(len(err_lines), 2)
785             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
786             self.assertEqual(
787                 unstyle(str(report)),
788                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
789                 " reformat.",
790             )
791             self.assertEqual(report.return_code, 123)
792             report.path_ignored(Path("wat"), "no match")
793             self.assertEqual(len(out_lines), 2)
794             self.assertEqual(len(err_lines), 2)
795             self.assertEqual(
796                 unstyle(str(report)),
797                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
798                 " reformat.",
799             )
800             self.assertEqual(report.return_code, 123)
801             report.done(Path("f4"), black.Changed.NO)
802             self.assertEqual(len(out_lines), 2)
803             self.assertEqual(len(err_lines), 2)
804             self.assertEqual(
805                 unstyle(str(report)),
806                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
807                 " reformat.",
808             )
809             self.assertEqual(report.return_code, 123)
810             report.check = True
811             self.assertEqual(
812                 unstyle(str(report)),
813                 "2 files would be reformatted, 3 files would be left unchanged, 2"
814                 " files would fail to reformat.",
815             )
816             report.check = False
817             report.diff = True
818             self.assertEqual(
819                 unstyle(str(report)),
820                 "2 files would be reformatted, 3 files would be left unchanged, 2"
821                 " files would fail to reformat.",
822             )
823
824     def test_lib2to3_parse(self) -> None:
825         with self.assertRaises(black.InvalidInput):
826             black.lib2to3_parse("invalid syntax")
827
828         straddling = "x + y"
829         black.lib2to3_parse(straddling)
830         black.lib2to3_parse(straddling, {TargetVersion.PY36})
831
832         py2_only = "print x"
833         with self.assertRaises(black.InvalidInput):
834             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
835
836         py3_only = "exec(x, end=y)"
837         black.lib2to3_parse(py3_only)
838         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
839
840     def test_get_features_used_decorator(self) -> None:
841         # Test the feature detection of new decorator syntax
842         # since this makes some test cases of test_get_features_used()
843         # fails if it fails, this is tested first so that a useful case
844         # is identified
845         simples, relaxed = read_data("miscellaneous", "decorators")
846         # skip explanation comments at the top of the file
847         for simple_test in simples.split("##")[1:]:
848             node = black.lib2to3_parse(simple_test)
849             decorator = str(node.children[0].children[0]).strip()
850             self.assertNotIn(
851                 Feature.RELAXED_DECORATORS,
852                 black.get_features_used(node),
853                 msg=(
854                     f"decorator '{decorator}' follows python<=3.8 syntax"
855                     "but is detected as 3.9+"
856                     # f"The full node is\n{node!r}"
857                 ),
858             )
859         # skip the '# output' comment at the top of the output part
860         for relaxed_test in relaxed.split("##")[1:]:
861             node = black.lib2to3_parse(relaxed_test)
862             decorator = str(node.children[0].children[0]).strip()
863             self.assertIn(
864                 Feature.RELAXED_DECORATORS,
865                 black.get_features_used(node),
866                 msg=(
867                     f"decorator '{decorator}' uses python3.9+ syntax"
868                     "but is detected as python<=3.8"
869                     # f"The full node is\n{node!r}"
870                 ),
871             )
872
873     def test_get_features_used(self) -> None:
874         node = black.lib2to3_parse("def f(*, arg): ...\n")
875         self.assertEqual(black.get_features_used(node), set())
876         node = black.lib2to3_parse("def f(*, arg,): ...\n")
877         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
878         node = black.lib2to3_parse("f(*arg,)\n")
879         self.assertEqual(
880             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
881         )
882         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
883         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
884         node = black.lib2to3_parse("123_456\n")
885         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
886         node = black.lib2to3_parse("123456\n")
887         self.assertEqual(black.get_features_used(node), set())
888         source, expected = read_data("simple_cases", "function")
889         node = black.lib2to3_parse(source)
890         expected_features = {
891             Feature.TRAILING_COMMA_IN_CALL,
892             Feature.TRAILING_COMMA_IN_DEF,
893             Feature.F_STRINGS,
894         }
895         self.assertEqual(black.get_features_used(node), expected_features)
896         node = black.lib2to3_parse(expected)
897         self.assertEqual(black.get_features_used(node), expected_features)
898         source, expected = read_data("simple_cases", "expression")
899         node = black.lib2to3_parse(source)
900         self.assertEqual(black.get_features_used(node), set())
901         node = black.lib2to3_parse(expected)
902         self.assertEqual(black.get_features_used(node), set())
903         node = black.lib2to3_parse("lambda a, /, b: ...")
904         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
905         node = black.lib2to3_parse("def fn(a, /, b): ...")
906         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
907         node = black.lib2to3_parse("def fn(): yield a, b")
908         self.assertEqual(black.get_features_used(node), set())
909         node = black.lib2to3_parse("def fn(): return a, b")
910         self.assertEqual(black.get_features_used(node), set())
911         node = black.lib2to3_parse("def fn(): yield *b, c")
912         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
913         node = black.lib2to3_parse("def fn(): return a, *b, c")
914         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
915         node = black.lib2to3_parse("x = a, *b, c")
916         self.assertEqual(black.get_features_used(node), set())
917         node = black.lib2to3_parse("x: Any = regular")
918         self.assertEqual(black.get_features_used(node), set())
919         node = black.lib2to3_parse("x: Any = (regular, regular)")
920         self.assertEqual(black.get_features_used(node), set())
921         node = black.lib2to3_parse("x: Any = Complex(Type(1))[something]")
922         self.assertEqual(black.get_features_used(node), set())
923         node = black.lib2to3_parse("x: Tuple[int, ...] = a, b, c")
924         self.assertEqual(
925             black.get_features_used(node), {Feature.ANN_ASSIGN_EXTENDED_RHS}
926         )
927         node = black.lib2to3_parse("try: pass\nexcept Something: pass")
928         self.assertEqual(black.get_features_used(node), set())
929         node = black.lib2to3_parse("try: pass\nexcept (*Something,): pass")
930         self.assertEqual(black.get_features_used(node), set())
931         node = black.lib2to3_parse("try: pass\nexcept *Group: pass")
932         self.assertEqual(black.get_features_used(node), {Feature.EXCEPT_STAR})
933         node = black.lib2to3_parse("a[*b]")
934         self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS})
935         node = black.lib2to3_parse("a[x, *y(), z] = t")
936         self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS})
937         node = black.lib2to3_parse("def fn(*args: *T): pass")
938         self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS})
939
940     def test_get_features_used_for_future_flags(self) -> None:
941         for src, features in [
942             ("from __future__ import annotations", {Feature.FUTURE_ANNOTATIONS}),
943             (
944                 "from __future__ import (other, annotations)",
945                 {Feature.FUTURE_ANNOTATIONS},
946             ),
947             ("a = 1 + 2\nfrom something import annotations", set()),
948             ("from __future__ import x, y", set()),
949         ]:
950             with self.subTest(src=src, features=features):
951                 node = black.lib2to3_parse(src)
952                 future_imports = black.get_future_imports(node)
953                 self.assertEqual(
954                     black.get_features_used(node, future_imports=future_imports),
955                     features,
956                 )
957
958     def test_get_future_imports(self) -> None:
959         node = black.lib2to3_parse("\n")
960         self.assertEqual(set(), black.get_future_imports(node))
961         node = black.lib2to3_parse("from __future__ import black\n")
962         self.assertEqual({"black"}, black.get_future_imports(node))
963         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
964         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
965         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
966         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
967         node = black.lib2to3_parse(
968             "from __future__ import multiple\nfrom __future__ import imports\n"
969         )
970         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
971         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
972         self.assertEqual({"black"}, black.get_future_imports(node))
973         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
974         self.assertEqual({"black"}, black.get_future_imports(node))
975         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
976         self.assertEqual(set(), black.get_future_imports(node))
977         node = black.lib2to3_parse("from some.module import black\n")
978         self.assertEqual(set(), black.get_future_imports(node))
979         node = black.lib2to3_parse(
980             "from __future__ import unicode_literals as _unicode_literals"
981         )
982         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
983         node = black.lib2to3_parse(
984             "from __future__ import unicode_literals as _lol, print"
985         )
986         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
987
988     @pytest.mark.incompatible_with_mypyc
989     def test_debug_visitor(self) -> None:
990         source, _ = read_data("miscellaneous", "debug_visitor")
991         expected, _ = read_data("miscellaneous", "debug_visitor.out")
992         out_lines = []
993         err_lines = []
994
995         def out(msg: str, **kwargs: Any) -> None:
996             out_lines.append(msg)
997
998         def err(msg: str, **kwargs: Any) -> None:
999             err_lines.append(msg)
1000
1001         with patch("black.debug.out", out):
1002             DebugVisitor.show(source)
1003         actual = "\n".join(out_lines) + "\n"
1004         log_name = ""
1005         if expected != actual:
1006             log_name = black.dump_to_file(*out_lines)
1007         self.assertEqual(
1008             expected,
1009             actual,
1010             f"AST print out is different. Actual version dumped to {log_name}",
1011         )
1012
1013     def test_format_file_contents(self) -> None:
1014         mode = DEFAULT_MODE
1015         empty = ""
1016         with self.assertRaises(black.NothingChanged):
1017             black.format_file_contents(empty, mode=mode, fast=False)
1018         just_nl = "\n"
1019         with self.assertRaises(black.NothingChanged):
1020             black.format_file_contents(just_nl, mode=mode, fast=False)
1021         same = "j = [1, 2, 3]\n"
1022         with self.assertRaises(black.NothingChanged):
1023             black.format_file_contents(same, mode=mode, fast=False)
1024         different = "j = [1,2,3]"
1025         expected = same
1026         actual = black.format_file_contents(different, mode=mode, fast=False)
1027         self.assertEqual(expected, actual)
1028         invalid = "return if you can"
1029         with self.assertRaises(black.InvalidInput) as e:
1030             black.format_file_contents(invalid, mode=mode, fast=False)
1031         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1032
1033         mode = black.Mode(preview=True)
1034         just_crlf = "\r\n"
1035         with self.assertRaises(black.NothingChanged):
1036             black.format_file_contents(just_crlf, mode=mode, fast=False)
1037         just_whitespace_nl = "\n\t\n \n\t \n \t\n\n"
1038         actual = black.format_file_contents(just_whitespace_nl, mode=mode, fast=False)
1039         self.assertEqual("\n", actual)
1040         just_whitespace_crlf = "\r\n\t\r\n \r\n\t \r\n \t\r\n\r\n"
1041         actual = black.format_file_contents(just_whitespace_crlf, mode=mode, fast=False)
1042         self.assertEqual("\r\n", actual)
1043
1044     def test_endmarker(self) -> None:
1045         n = black.lib2to3_parse("\n")
1046         self.assertEqual(n.type, black.syms.file_input)
1047         self.assertEqual(len(n.children), 1)
1048         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1049
1050     @pytest.mark.incompatible_with_mypyc
1051     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1052     def test_assertFormatEqual(self) -> None:
1053         out_lines = []
1054         err_lines = []
1055
1056         def out(msg: str, **kwargs: Any) -> None:
1057             out_lines.append(msg)
1058
1059         def err(msg: str, **kwargs: Any) -> None:
1060             err_lines.append(msg)
1061
1062         with patch("black.output._out", out), patch("black.output._err", err):
1063             with self.assertRaises(AssertionError):
1064                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1065
1066         out_str = "".join(out_lines)
1067         self.assertIn("Expected tree:", out_str)
1068         self.assertIn("Actual tree:", out_str)
1069         self.assertEqual("".join(err_lines), "")
1070
1071     @event_loop()
1072     @patch("concurrent.futures.ProcessPoolExecutor", MagicMock(side_effect=OSError))
1073     def test_works_in_mono_process_only_environment(self) -> None:
1074         with cache_dir() as workspace:
1075             for f in [
1076                 (workspace / "one.py").resolve(),
1077                 (workspace / "two.py").resolve(),
1078             ]:
1079                 f.write_text('print("hello")\n', encoding="utf-8")
1080             self.invokeBlack([str(workspace)])
1081
1082     @event_loop()
1083     def test_check_diff_use_together(self) -> None:
1084         with cache_dir():
1085             # Files which will be reformatted.
1086             src1 = get_case_path("miscellaneous", "string_quotes")
1087             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1088             # Files which will not be reformatted.
1089             src2 = get_case_path("simple_cases", "composition")
1090             self.invokeBlack([str(src2), "--diff", "--check"])
1091             # Multi file command.
1092             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1093
1094     def test_no_src_fails(self) -> None:
1095         with cache_dir():
1096             self.invokeBlack([], exit_code=1)
1097
1098     def test_src_and_code_fails(self) -> None:
1099         with cache_dir():
1100             self.invokeBlack([".", "-c", "0"], exit_code=1)
1101
1102     def test_broken_symlink(self) -> None:
1103         with cache_dir() as workspace:
1104             symlink = workspace / "broken_link.py"
1105             try:
1106                 symlink.symlink_to("nonexistent.py")
1107             except (OSError, NotImplementedError) as e:
1108                 self.skipTest(f"Can't create symlinks: {e}")
1109             self.invokeBlack([str(workspace.resolve())])
1110
1111     def test_single_file_force_pyi(self) -> None:
1112         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1113         contents, expected = read_data("miscellaneous", "force_pyi")
1114         with cache_dir() as workspace:
1115             path = (workspace / "file.py").resolve()
1116             path.write_text(contents, encoding="utf-8")
1117             self.invokeBlack([str(path), "--pyi"])
1118             actual = path.read_text(encoding="utf-8")
1119             # verify cache with --pyi is separate
1120             pyi_cache = black.Cache.read(pyi_mode)
1121             assert not pyi_cache.is_changed(path)
1122             normal_cache = black.Cache.read(DEFAULT_MODE)
1123             assert normal_cache.is_changed(path)
1124         self.assertFormatEqual(expected, actual)
1125         black.assert_equivalent(contents, actual)
1126         black.assert_stable(contents, actual, pyi_mode)
1127
1128     @event_loop()
1129     def test_multi_file_force_pyi(self) -> None:
1130         reg_mode = DEFAULT_MODE
1131         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1132         contents, expected = read_data("miscellaneous", "force_pyi")
1133         with cache_dir() as workspace:
1134             paths = [
1135                 (workspace / "file1.py").resolve(),
1136                 (workspace / "file2.py").resolve(),
1137             ]
1138             for path in paths:
1139                 path.write_text(contents, encoding="utf-8")
1140             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1141             for path in paths:
1142                 actual = path.read_text(encoding="utf-8")
1143                 self.assertEqual(actual, expected)
1144             # verify cache with --pyi is separate
1145             pyi_cache = black.Cache.read(pyi_mode)
1146             normal_cache = black.Cache.read(reg_mode)
1147             for path in paths:
1148                 assert not pyi_cache.is_changed(path)
1149                 assert normal_cache.is_changed(path)
1150
1151     def test_pipe_force_pyi(self) -> None:
1152         source, expected = read_data("miscellaneous", "force_pyi")
1153         result = CliRunner().invoke(
1154             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf-8"))
1155         )
1156         self.assertEqual(result.exit_code, 0)
1157         actual = result.output
1158         self.assertFormatEqual(actual, expected)
1159
1160     def test_single_file_force_py36(self) -> None:
1161         reg_mode = DEFAULT_MODE
1162         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1163         source, expected = read_data("miscellaneous", "force_py36")
1164         with cache_dir() as workspace:
1165             path = (workspace / "file.py").resolve()
1166             path.write_text(source, encoding="utf-8")
1167             self.invokeBlack([str(path), *PY36_ARGS])
1168             actual = path.read_text(encoding="utf-8")
1169             # verify cache with --target-version is separate
1170             py36_cache = black.Cache.read(py36_mode)
1171             assert not py36_cache.is_changed(path)
1172             normal_cache = black.Cache.read(reg_mode)
1173             assert normal_cache.is_changed(path)
1174         self.assertEqual(actual, expected)
1175
1176     @event_loop()
1177     def test_multi_file_force_py36(self) -> None:
1178         reg_mode = DEFAULT_MODE
1179         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1180         source, expected = read_data("miscellaneous", "force_py36")
1181         with cache_dir() as workspace:
1182             paths = [
1183                 (workspace / "file1.py").resolve(),
1184                 (workspace / "file2.py").resolve(),
1185             ]
1186             for path in paths:
1187                 path.write_text(source, encoding="utf-8")
1188             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1189             for path in paths:
1190                 actual = path.read_text(encoding="utf-8")
1191                 self.assertEqual(actual, expected)
1192             # verify cache with --target-version is separate
1193             pyi_cache = black.Cache.read(py36_mode)
1194             normal_cache = black.Cache.read(reg_mode)
1195             for path in paths:
1196                 assert not pyi_cache.is_changed(path)
1197                 assert normal_cache.is_changed(path)
1198
1199     def test_pipe_force_py36(self) -> None:
1200         source, expected = read_data("miscellaneous", "force_py36")
1201         result = CliRunner().invoke(
1202             black.main,
1203             ["-", "-q", "--target-version=py36"],
1204             input=BytesIO(source.encode("utf-8")),
1205         )
1206         self.assertEqual(result.exit_code, 0)
1207         actual = result.output
1208         self.assertFormatEqual(actual, expected)
1209
1210     @pytest.mark.incompatible_with_mypyc
1211     def test_reformat_one_with_stdin(self) -> None:
1212         with patch(
1213             "black.format_stdin_to_stdout",
1214             return_value=lambda *args, **kwargs: black.Changed.YES,
1215         ) as fsts:
1216             report = MagicMock()
1217             path = Path("-")
1218             black.reformat_one(
1219                 path,
1220                 fast=True,
1221                 write_back=black.WriteBack.YES,
1222                 mode=DEFAULT_MODE,
1223                 report=report,
1224             )
1225             fsts.assert_called_once()
1226             report.done.assert_called_with(path, black.Changed.YES)
1227
1228     @pytest.mark.incompatible_with_mypyc
1229     def test_reformat_one_with_stdin_filename(self) -> None:
1230         with patch(
1231             "black.format_stdin_to_stdout",
1232             return_value=lambda *args, **kwargs: black.Changed.YES,
1233         ) as fsts:
1234             report = MagicMock()
1235             p = "foo.py"
1236             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1237             expected = Path(p)
1238             black.reformat_one(
1239                 path,
1240                 fast=True,
1241                 write_back=black.WriteBack.YES,
1242                 mode=DEFAULT_MODE,
1243                 report=report,
1244             )
1245             fsts.assert_called_once_with(
1246                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1247             )
1248             # __BLACK_STDIN_FILENAME__ should have been stripped
1249             report.done.assert_called_with(expected, black.Changed.YES)
1250
1251     @pytest.mark.incompatible_with_mypyc
1252     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1253         with patch(
1254             "black.format_stdin_to_stdout",
1255             return_value=lambda *args, **kwargs: black.Changed.YES,
1256         ) as fsts:
1257             report = MagicMock()
1258             p = "foo.pyi"
1259             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1260             expected = Path(p)
1261             black.reformat_one(
1262                 path,
1263                 fast=True,
1264                 write_back=black.WriteBack.YES,
1265                 mode=DEFAULT_MODE,
1266                 report=report,
1267             )
1268             fsts.assert_called_once_with(
1269                 fast=True,
1270                 write_back=black.WriteBack.YES,
1271                 mode=replace(DEFAULT_MODE, is_pyi=True),
1272             )
1273             # __BLACK_STDIN_FILENAME__ should have been stripped
1274             report.done.assert_called_with(expected, black.Changed.YES)
1275
1276     @pytest.mark.incompatible_with_mypyc
1277     def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
1278         with patch(
1279             "black.format_stdin_to_stdout",
1280             return_value=lambda *args, **kwargs: black.Changed.YES,
1281         ) as fsts:
1282             report = MagicMock()
1283             p = "foo.ipynb"
1284             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1285             expected = Path(p)
1286             black.reformat_one(
1287                 path,
1288                 fast=True,
1289                 write_back=black.WriteBack.YES,
1290                 mode=DEFAULT_MODE,
1291                 report=report,
1292             )
1293             fsts.assert_called_once_with(
1294                 fast=True,
1295                 write_back=black.WriteBack.YES,
1296                 mode=replace(DEFAULT_MODE, is_ipynb=True),
1297             )
1298             # __BLACK_STDIN_FILENAME__ should have been stripped
1299             report.done.assert_called_with(expected, black.Changed.YES)
1300
1301     @pytest.mark.incompatible_with_mypyc
1302     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1303         with patch(
1304             "black.format_stdin_to_stdout",
1305             return_value=lambda *args, **kwargs: black.Changed.YES,
1306         ) as fsts:
1307             report = MagicMock()
1308             # Even with an existing file, since we are forcing stdin, black
1309             # should output to stdout and not modify the file inplace
1310             p = THIS_DIR / "data" / "simple_cases" / "collections.py"
1311             # Make sure is_file actually returns True
1312             self.assertTrue(p.is_file())
1313             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1314             expected = Path(p)
1315             black.reformat_one(
1316                 path,
1317                 fast=True,
1318                 write_back=black.WriteBack.YES,
1319                 mode=DEFAULT_MODE,
1320                 report=report,
1321             )
1322             fsts.assert_called_once()
1323             # __BLACK_STDIN_FILENAME__ should have been stripped
1324             report.done.assert_called_with(expected, black.Changed.YES)
1325
1326     def test_reformat_one_with_stdin_empty(self) -> None:
1327         cases = [
1328             ("", ""),
1329             ("\n", "\n"),
1330             ("\r\n", "\r\n"),
1331             (" \t", ""),
1332             (" \t\n\t ", "\n"),
1333             (" \t\r\n\t ", "\r\n"),
1334         ]
1335
1336         def _new_wrapper(
1337             output: io.StringIO, io_TextIOWrapper: Type[io.TextIOWrapper]
1338         ) -> Callable[[Any, Any], io.TextIOWrapper]:
1339             def get_output(*args: Any, **kwargs: Any) -> io.TextIOWrapper:
1340                 if args == (sys.stdout.buffer,):
1341                     # It's `format_stdin_to_stdout()` calling `io.TextIOWrapper()`,
1342                     # return our mock object.
1343                     return output
1344                 # It's something else (i.e. `decode_bytes()`) calling
1345                 # `io.TextIOWrapper()`, pass through to the original implementation.
1346                 # See discussion in https://github.com/psf/black/pull/2489
1347                 return io_TextIOWrapper(*args, **kwargs)
1348
1349             return get_output
1350
1351         mode = black.Mode(preview=True)
1352         for content, expected in cases:
1353             output = io.StringIO()
1354             io_TextIOWrapper = io.TextIOWrapper
1355
1356             with patch("io.TextIOWrapper", _new_wrapper(output, io_TextIOWrapper)):
1357                 try:
1358                     black.format_stdin_to_stdout(
1359                         fast=True,
1360                         content=content,
1361                         write_back=black.WriteBack.YES,
1362                         mode=mode,
1363                     )
1364                 except io.UnsupportedOperation:
1365                     pass  # StringIO does not support detach
1366                 assert output.getvalue() == expected
1367
1368         # An empty string is the only test case for `preview=False`
1369         output = io.StringIO()
1370         io_TextIOWrapper = io.TextIOWrapper
1371         with patch("io.TextIOWrapper", _new_wrapper(output, io_TextIOWrapper)):
1372             try:
1373                 black.format_stdin_to_stdout(
1374                     fast=True,
1375                     content="",
1376                     write_back=black.WriteBack.YES,
1377                     mode=DEFAULT_MODE,
1378                 )
1379             except io.UnsupportedOperation:
1380                 pass  # StringIO does not support detach
1381             assert output.getvalue() == ""
1382
1383     def test_invalid_cli_regex(self) -> None:
1384         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1385             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1386
1387     def test_required_version_matches_version(self) -> None:
1388         self.invokeBlack(
1389             ["--required-version", black.__version__, "-c", "0"],
1390             exit_code=0,
1391             ignore_config=True,
1392         )
1393
1394     def test_required_version_matches_partial_version(self) -> None:
1395         self.invokeBlack(
1396             ["--required-version", black.__version__.split(".")[0], "-c", "0"],
1397             exit_code=0,
1398             ignore_config=True,
1399         )
1400
1401     def test_required_version_does_not_match_on_minor_version(self) -> None:
1402         self.invokeBlack(
1403             ["--required-version", black.__version__.split(".")[0] + ".999", "-c", "0"],
1404             exit_code=1,
1405             ignore_config=True,
1406         )
1407
1408     def test_required_version_does_not_match_version(self) -> None:
1409         result = BlackRunner().invoke(
1410             black.main,
1411             ["--required-version", "20.99b", "-c", "0"],
1412         )
1413         self.assertEqual(result.exit_code, 1)
1414         self.assertIn("required version", result.stderr)
1415
1416     def test_preserves_line_endings(self) -> None:
1417         with TemporaryDirectory() as workspace:
1418             test_file = Path(workspace) / "test.py"
1419             for nl in ["\n", "\r\n"]:
1420                 contents = nl.join(["def f(  ):", "    pass"])
1421                 test_file.write_bytes(contents.encode())
1422                 ff(test_file, write_back=black.WriteBack.YES)
1423                 updated_contents: bytes = test_file.read_bytes()
1424                 self.assertIn(nl.encode(), updated_contents)
1425                 if nl == "\n":
1426                     self.assertNotIn(b"\r\n", updated_contents)
1427
1428     def test_preserves_line_endings_via_stdin(self) -> None:
1429         for nl in ["\n", "\r\n"]:
1430             contents = nl.join(["def f(  ):", "    pass"])
1431             runner = BlackRunner()
1432             result = runner.invoke(
1433                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf-8"))
1434             )
1435             self.assertEqual(result.exit_code, 0)
1436             output = result.stdout_bytes
1437             self.assertIn(nl.encode("utf-8"), output)
1438             if nl == "\n":
1439                 self.assertNotIn(b"\r\n", output)
1440
1441     def test_normalize_line_endings(self) -> None:
1442         with TemporaryDirectory() as workspace:
1443             test_file = Path(workspace) / "test.py"
1444             for data, expected in (
1445                 (b"c\r\nc\n ", b"c\r\nc\r\n"),
1446                 (b"l\nl\r\n ", b"l\nl\n"),
1447             ):
1448                 test_file.write_bytes(data)
1449                 ff(test_file, write_back=black.WriteBack.YES)
1450                 self.assertEqual(test_file.read_bytes(), expected)
1451
1452     def test_assert_equivalent_different_asts(self) -> None:
1453         with self.assertRaises(AssertionError):
1454             black.assert_equivalent("{}", "None")
1455
1456     def test_root_logger_not_used_directly(self) -> None:
1457         def fail(*args: Any, **kwargs: Any) -> None:
1458             self.fail("Record created with root logger")
1459
1460         with patch.multiple(
1461             logging.root,
1462             debug=fail,
1463             info=fail,
1464             warning=fail,
1465             error=fail,
1466             critical=fail,
1467             log=fail,
1468         ):
1469             ff(THIS_DIR / "util.py")
1470
1471     def test_invalid_config_return_code(self) -> None:
1472         tmp_file = Path(black.dump_to_file())
1473         try:
1474             tmp_config = Path(black.dump_to_file())
1475             tmp_config.unlink()
1476             args = ["--config", str(tmp_config), str(tmp_file)]
1477             self.invokeBlack(args, exit_code=2, ignore_config=False)
1478         finally:
1479             tmp_file.unlink()
1480
1481     def test_parse_pyproject_toml(self) -> None:
1482         test_toml_file = THIS_DIR / "test.toml"
1483         config = black.parse_pyproject_toml(str(test_toml_file))
1484         self.assertEqual(config["verbose"], 1)
1485         self.assertEqual(config["check"], "no")
1486         self.assertEqual(config["diff"], "y")
1487         self.assertEqual(config["color"], True)
1488         self.assertEqual(config["line_length"], 79)
1489         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1490         self.assertEqual(config["python_cell_magics"], ["custom1", "custom2"])
1491         self.assertEqual(config["exclude"], r"\.pyi?$")
1492         self.assertEqual(config["include"], r"\.py?$")
1493
1494     def test_parse_pyproject_toml_project_metadata(self) -> None:
1495         for test_toml, expected in [
1496             ("only_black_pyproject.toml", ["py310"]),
1497             ("only_metadata_pyproject.toml", ["py37", "py38", "py39", "py310"]),
1498             ("neither_pyproject.toml", None),
1499             ("both_pyproject.toml", ["py310"]),
1500         ]:
1501             test_toml_file = THIS_DIR / "data" / "project_metadata" / test_toml
1502             config = black.parse_pyproject_toml(str(test_toml_file))
1503             self.assertEqual(config.get("target_version"), expected)
1504
1505     def test_infer_target_version(self) -> None:
1506         for version, expected in [
1507             ("3.6", [TargetVersion.PY36]),
1508             ("3.11.0rc1", [TargetVersion.PY311]),
1509             (">=3.10", [TargetVersion.PY310, TargetVersion.PY311, TargetVersion.PY312]),
1510             (
1511                 ">=3.10.6",
1512                 [TargetVersion.PY310, TargetVersion.PY311, TargetVersion.PY312],
1513             ),
1514             ("<3.6", [TargetVersion.PY33, TargetVersion.PY34, TargetVersion.PY35]),
1515             (">3.7,<3.10", [TargetVersion.PY38, TargetVersion.PY39]),
1516             (
1517                 ">3.7,!=3.8,!=3.9",
1518                 [TargetVersion.PY310, TargetVersion.PY311, TargetVersion.PY312],
1519             ),
1520             (
1521                 "> 3.9.4, != 3.10.3",
1522                 [
1523                     TargetVersion.PY39,
1524                     TargetVersion.PY310,
1525                     TargetVersion.PY311,
1526                     TargetVersion.PY312,
1527                 ],
1528             ),
1529             (
1530                 "!=3.3,!=3.4",
1531                 [
1532                     TargetVersion.PY35,
1533                     TargetVersion.PY36,
1534                     TargetVersion.PY37,
1535                     TargetVersion.PY38,
1536                     TargetVersion.PY39,
1537                     TargetVersion.PY310,
1538                     TargetVersion.PY311,
1539                     TargetVersion.PY312,
1540                 ],
1541             ),
1542             (
1543                 "==3.*",
1544                 [
1545                     TargetVersion.PY33,
1546                     TargetVersion.PY34,
1547                     TargetVersion.PY35,
1548                     TargetVersion.PY36,
1549                     TargetVersion.PY37,
1550                     TargetVersion.PY38,
1551                     TargetVersion.PY39,
1552                     TargetVersion.PY310,
1553                     TargetVersion.PY311,
1554                     TargetVersion.PY312,
1555                 ],
1556             ),
1557             ("==3.8.*", [TargetVersion.PY38]),
1558             (None, None),
1559             ("", None),
1560             ("invalid", None),
1561             ("==invalid", None),
1562             (">3.9,!=invalid", None),
1563             ("3", None),
1564             ("3.2", None),
1565             ("2.7.18", None),
1566             ("==2.7", None),
1567             (">3.10,<3.11", None),
1568         ]:
1569             test_toml = {"project": {"requires-python": version}}
1570             result = black.files.infer_target_version(test_toml)
1571             self.assertEqual(result, expected)
1572
1573     def test_read_pyproject_toml(self) -> None:
1574         test_toml_file = THIS_DIR / "test.toml"
1575         fake_ctx = FakeContext()
1576         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1577         config = fake_ctx.default_map
1578         self.assertEqual(config["verbose"], "1")
1579         self.assertEqual(config["check"], "no")
1580         self.assertEqual(config["diff"], "y")
1581         self.assertEqual(config["color"], "True")
1582         self.assertEqual(config["line_length"], "79")
1583         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1584         self.assertEqual(config["exclude"], r"\.pyi?$")
1585         self.assertEqual(config["include"], r"\.py?$")
1586
1587     def test_read_pyproject_toml_from_stdin(self) -> None:
1588         with TemporaryDirectory() as workspace:
1589             root = Path(workspace)
1590
1591             src_dir = root / "src"
1592             src_dir.mkdir()
1593
1594             src_pyproject = src_dir / "pyproject.toml"
1595             src_pyproject.touch()
1596
1597             test_toml_content = (THIS_DIR / "test.toml").read_text(encoding="utf-8")
1598             src_pyproject.write_text(test_toml_content, encoding="utf-8")
1599
1600             src_python = src_dir / "foo.py"
1601             src_python.touch()
1602
1603             fake_ctx = FakeContext()
1604             fake_ctx.params["src"] = ("-",)
1605             fake_ctx.params["stdin_filename"] = str(src_python)
1606
1607             with change_directory(root):
1608                 black.read_pyproject_toml(fake_ctx, FakeParameter(), None)
1609
1610             config = fake_ctx.default_map
1611             self.assertEqual(config["verbose"], "1")
1612             self.assertEqual(config["check"], "no")
1613             self.assertEqual(config["diff"], "y")
1614             self.assertEqual(config["color"], "True")
1615             self.assertEqual(config["line_length"], "79")
1616             self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1617             self.assertEqual(config["exclude"], r"\.pyi?$")
1618             self.assertEqual(config["include"], r"\.py?$")
1619
1620     @pytest.mark.incompatible_with_mypyc
1621     def test_find_project_root(self) -> None:
1622         with TemporaryDirectory() as workspace:
1623             root = Path(workspace)
1624             test_dir = root / "test"
1625             test_dir.mkdir()
1626
1627             src_dir = root / "src"
1628             src_dir.mkdir()
1629
1630             root_pyproject = root / "pyproject.toml"
1631             root_pyproject.touch()
1632             src_pyproject = src_dir / "pyproject.toml"
1633             src_pyproject.touch()
1634             src_python = src_dir / "foo.py"
1635             src_python.touch()
1636
1637             self.assertEqual(
1638                 black.find_project_root((src_dir, test_dir)),
1639                 (root.resolve(), "pyproject.toml"),
1640             )
1641             self.assertEqual(
1642                 black.find_project_root((src_dir,)),
1643                 (src_dir.resolve(), "pyproject.toml"),
1644             )
1645             self.assertEqual(
1646                 black.find_project_root((src_python,)),
1647                 (src_dir.resolve(), "pyproject.toml"),
1648             )
1649
1650             with change_directory(test_dir):
1651                 self.assertEqual(
1652                     black.find_project_root(("-",), stdin_filename="../src/a.py"),
1653                     (src_dir.resolve(), "pyproject.toml"),
1654                 )
1655
1656     @patch(
1657         "black.files.find_user_pyproject_toml",
1658     )
1659     def test_find_pyproject_toml(self, find_user_pyproject_toml: MagicMock) -> None:
1660         find_user_pyproject_toml.side_effect = RuntimeError()
1661
1662         with redirect_stderr(io.StringIO()) as stderr:
1663             result = black.files.find_pyproject_toml(
1664                 path_search_start=(str(Path.cwd().root),)
1665             )
1666
1667         assert result is None
1668         err = stderr.getvalue()
1669         assert "Ignoring user configuration" in err
1670
1671     @patch(
1672         "black.files.find_user_pyproject_toml",
1673         black.files.find_user_pyproject_toml.__wrapped__,
1674     )
1675     def test_find_user_pyproject_toml_linux(self) -> None:
1676         if system() == "Windows":
1677             return
1678
1679         # Test if XDG_CONFIG_HOME is checked
1680         with TemporaryDirectory() as workspace:
1681             tmp_user_config = Path(workspace) / "black"
1682             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1683                 self.assertEqual(
1684                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
1685                 )
1686
1687         # Test fallback for XDG_CONFIG_HOME
1688         with patch.dict("os.environ"):
1689             os.environ.pop("XDG_CONFIG_HOME", None)
1690             fallback_user_config = Path("~/.config").expanduser() / "black"
1691             self.assertEqual(
1692                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
1693             )
1694
1695     def test_find_user_pyproject_toml_windows(self) -> None:
1696         if system() != "Windows":
1697             return
1698
1699         user_config_path = Path.home() / ".black"
1700         self.assertEqual(
1701             black.files.find_user_pyproject_toml(), user_config_path.resolve()
1702         )
1703
1704     def test_bpo_33660_workaround(self) -> None:
1705         if system() == "Windows":
1706             return
1707
1708         # https://bugs.python.org/issue33660
1709         root = Path("/")
1710         with change_directory(root):
1711             path = Path("workspace") / "project"
1712             report = black.Report(verbose=True)
1713             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
1714             self.assertEqual(normalized_path, "workspace/project")
1715
1716     def test_normalize_path_ignore_windows_junctions_outside_of_root(self) -> None:
1717         if system() != "Windows":
1718             return
1719
1720         with TemporaryDirectory() as workspace:
1721             root = Path(workspace)
1722             junction_dir = root / "junction"
1723             junction_target_outside_of_root = root / ".."
1724             os.system(f"mklink /J {junction_dir} {junction_target_outside_of_root}")
1725
1726             report = black.Report(verbose=True)
1727             normalized_path = black.normalize_path_maybe_ignore(
1728                 junction_dir, root, report
1729             )
1730             # Manually delete for Python < 3.8
1731             os.system(f"rmdir {junction_dir}")
1732
1733             self.assertEqual(normalized_path, None)
1734
1735     def test_newline_comment_interaction(self) -> None:
1736         source = "class A:\\\r\n# type: ignore\n pass\n"
1737         output = black.format_str(source, mode=DEFAULT_MODE)
1738         black.assert_stable(source, output, mode=DEFAULT_MODE)
1739
1740     def test_bpo_2142_workaround(self) -> None:
1741         # https://bugs.python.org/issue2142
1742
1743         source, _ = read_data("miscellaneous", "missing_final_newline")
1744         # read_data adds a trailing newline
1745         source = source.rstrip()
1746         expected, _ = read_data("miscellaneous", "missing_final_newline.diff")
1747         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
1748         diff_header = re.compile(
1749             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
1750             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d\+\d\d:\d\d"
1751         )
1752         try:
1753             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
1754             self.assertEqual(result.exit_code, 0)
1755         finally:
1756             os.unlink(tmp_file)
1757         actual = result.output
1758         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
1759         self.assertEqual(actual, expected)
1760
1761     @staticmethod
1762     def compare_results(
1763         result: click.testing.Result, expected_value: str, expected_exit_code: int
1764     ) -> None:
1765         """Helper method to test the value and exit code of a click Result."""
1766         assert (
1767             result.output == expected_value
1768         ), "The output did not match the expected value."
1769         assert result.exit_code == expected_exit_code, "The exit code is incorrect."
1770
1771     def test_code_option(self) -> None:
1772         """Test the code option with no changes."""
1773         code = 'print("Hello world")\n'
1774         args = ["--code", code]
1775         result = CliRunner().invoke(black.main, args)
1776
1777         self.compare_results(result, code, 0)
1778
1779     def test_code_option_changed(self) -> None:
1780         """Test the code option when changes are required."""
1781         code = "print('hello world')"
1782         formatted = black.format_str(code, mode=DEFAULT_MODE)
1783
1784         args = ["--code", code]
1785         result = CliRunner().invoke(black.main, args)
1786
1787         self.compare_results(result, formatted, 0)
1788
1789     def test_code_option_check(self) -> None:
1790         """Test the code option when check is passed."""
1791         args = ["--check", "--code", 'print("Hello world")\n']
1792         result = CliRunner().invoke(black.main, args)
1793         self.compare_results(result, "", 0)
1794
1795     def test_code_option_check_changed(self) -> None:
1796         """Test the code option when changes are required, and check is passed."""
1797         args = ["--check", "--code", "print('hello world')"]
1798         result = CliRunner().invoke(black.main, args)
1799         self.compare_results(result, "", 1)
1800
1801     def test_code_option_diff(self) -> None:
1802         """Test the code option when diff is passed."""
1803         code = "print('hello world')"
1804         formatted = black.format_str(code, mode=DEFAULT_MODE)
1805         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1806
1807         args = ["--diff", "--code", code]
1808         result = CliRunner().invoke(black.main, args)
1809
1810         # Remove time from diff
1811         output = DIFF_TIME.sub("", result.output)
1812
1813         assert output == result_diff, "The output did not match the expected value."
1814         assert result.exit_code == 0, "The exit code is incorrect."
1815
1816     def test_code_option_color_diff(self) -> None:
1817         """Test the code option when color and diff are passed."""
1818         code = "print('hello world')"
1819         formatted = black.format_str(code, mode=DEFAULT_MODE)
1820
1821         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1822         result_diff = color_diff(result_diff)
1823
1824         args = ["--diff", "--color", "--code", code]
1825         result = CliRunner().invoke(black.main, args)
1826
1827         # Remove time from diff
1828         output = DIFF_TIME.sub("", result.output)
1829
1830         assert output == result_diff, "The output did not match the expected value."
1831         assert result.exit_code == 0, "The exit code is incorrect."
1832
1833     @pytest.mark.incompatible_with_mypyc
1834     def test_code_option_safe(self) -> None:
1835         """Test that the code option throws an error when the sanity checks fail."""
1836         # Patch black.assert_equivalent to ensure the sanity checks fail
1837         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1838             code = 'print("Hello world")'
1839             error_msg = f"{code}\nerror: cannot format <string>: \n"
1840
1841             args = ["--safe", "--code", code]
1842             result = CliRunner().invoke(black.main, args)
1843
1844             self.compare_results(result, error_msg, 123)
1845
1846     def test_code_option_fast(self) -> None:
1847         """Test that the code option ignores errors when the sanity checks fail."""
1848         # Patch black.assert_equivalent to ensure the sanity checks fail
1849         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1850             code = 'print("Hello world")'
1851             formatted = black.format_str(code, mode=DEFAULT_MODE)
1852
1853             args = ["--fast", "--code", code]
1854             result = CliRunner().invoke(black.main, args)
1855
1856             self.compare_results(result, formatted, 0)
1857
1858     @pytest.mark.incompatible_with_mypyc
1859     def test_code_option_config(self) -> None:
1860         """
1861         Test that the code option finds the pyproject.toml in the current directory.
1862         """
1863         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1864             args = ["--code", "print"]
1865             # This is the only directory known to contain a pyproject.toml
1866             with change_directory(PROJECT_ROOT):
1867                 CliRunner().invoke(black.main, args)
1868                 pyproject_path = Path(Path.cwd(), "pyproject.toml").resolve()
1869
1870             assert (
1871                 len(parse.mock_calls) >= 1
1872             ), "Expected config parse to be called with the current directory."
1873
1874             _, call_args, _ = parse.mock_calls[0]
1875             assert (
1876                 call_args[0].lower() == str(pyproject_path).lower()
1877             ), "Incorrect config loaded."
1878
1879     @pytest.mark.incompatible_with_mypyc
1880     def test_code_option_parent_config(self) -> None:
1881         """
1882         Test that the code option finds the pyproject.toml in the parent directory.
1883         """
1884         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1885             with change_directory(THIS_DIR):
1886                 args = ["--code", "print"]
1887                 CliRunner().invoke(black.main, args)
1888
1889                 pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
1890                 assert (
1891                     len(parse.mock_calls) >= 1
1892                 ), "Expected config parse to be called with the current directory."
1893
1894                 _, call_args, _ = parse.mock_calls[0]
1895                 assert (
1896                     call_args[0].lower() == str(pyproject_path).lower()
1897                 ), "Incorrect config loaded."
1898
1899     def test_for_handled_unexpected_eof_error(self) -> None:
1900         """
1901         Test that an unexpected EOF SyntaxError is nicely presented.
1902         """
1903         with pytest.raises(black.parsing.InvalidInput) as exc_info:
1904             black.lib2to3_parse("print(", {})
1905
1906         exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")
1907
1908     def test_equivalency_ast_parse_failure_includes_error(self) -> None:
1909         with pytest.raises(AssertionError) as err:
1910             black.assert_equivalent("a«»a  = 1", "a«»a  = 1")
1911
1912         err.match("--safe")
1913         # Unfortunately the SyntaxError message has changed in newer versions so we
1914         # can't match it directly.
1915         err.match("invalid character")
1916         err.match(r"\(<unknown>, line 1\)")
1917
1918
1919 class TestCaching:
1920     def test_get_cache_dir(
1921         self,
1922         tmp_path: Path,
1923         monkeypatch: pytest.MonkeyPatch,
1924     ) -> None:
1925         # Create multiple cache directories
1926         workspace1 = tmp_path / "ws1"
1927         workspace1.mkdir()
1928         workspace2 = tmp_path / "ws2"
1929         workspace2.mkdir()
1930
1931         # Force user_cache_dir to use the temporary directory for easier assertions
1932         patch_user_cache_dir = patch(
1933             target="black.cache.user_cache_dir",
1934             autospec=True,
1935             return_value=str(workspace1),
1936         )
1937
1938         # If BLACK_CACHE_DIR is not set, use user_cache_dir
1939         monkeypatch.delenv("BLACK_CACHE_DIR", raising=False)
1940         with patch_user_cache_dir:
1941             assert get_cache_dir() == workspace1
1942
1943         # If it is set, use the path provided in the env var.
1944         monkeypatch.setenv("BLACK_CACHE_DIR", str(workspace2))
1945         assert get_cache_dir() == workspace2
1946
1947     def test_cache_broken_file(self) -> None:
1948         mode = DEFAULT_MODE
1949         with cache_dir() as workspace:
1950             cache_file = get_cache_file(mode)
1951             cache_file.write_text("this is not a pickle", encoding="utf-8")
1952             assert black.Cache.read(mode).file_data == {}
1953             src = (workspace / "test.py").resolve()
1954             src.write_text("print('hello')", encoding="utf-8")
1955             invokeBlack([str(src)])
1956             cache = black.Cache.read(mode)
1957             assert not cache.is_changed(src)
1958
1959     def test_cache_single_file_already_cached(self) -> None:
1960         mode = DEFAULT_MODE
1961         with cache_dir() as workspace:
1962             src = (workspace / "test.py").resolve()
1963             src.write_text("print('hello')", encoding="utf-8")
1964             cache = black.Cache.read(mode)
1965             cache.write([src])
1966             invokeBlack([str(src)])
1967             assert src.read_text(encoding="utf-8") == "print('hello')"
1968
1969     @event_loop()
1970     def test_cache_multiple_files(self) -> None:
1971         mode = DEFAULT_MODE
1972         with cache_dir() as workspace, patch(
1973             "concurrent.futures.ProcessPoolExecutor", new=ThreadPoolExecutor
1974         ):
1975             one = (workspace / "one.py").resolve()
1976             one.write_text("print('hello')", encoding="utf-8")
1977             two = (workspace / "two.py").resolve()
1978             two.write_text("print('hello')", encoding="utf-8")
1979             cache = black.Cache.read(mode)
1980             cache.write([one])
1981             invokeBlack([str(workspace)])
1982             assert one.read_text(encoding="utf-8") == "print('hello')"
1983             assert two.read_text(encoding="utf-8") == 'print("hello")\n'
1984             cache = black.Cache.read(mode)
1985             assert not cache.is_changed(one)
1986             assert not cache.is_changed(two)
1987
1988     @pytest.mark.incompatible_with_mypyc
1989     @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
1990     def test_no_cache_when_writeback_diff(self, color: bool) -> None:
1991         mode = DEFAULT_MODE
1992         with cache_dir() as workspace:
1993             src = (workspace / "test.py").resolve()
1994             src.write_text("print('hello')", encoding="utf-8")
1995             with patch.object(black.Cache, "read") as read_cache, patch.object(
1996                 black.Cache, "write"
1997             ) as write_cache:
1998                 cmd = [str(src), "--diff"]
1999                 if color:
2000                     cmd.append("--color")
2001                 invokeBlack(cmd)
2002                 cache_file = get_cache_file(mode)
2003                 assert cache_file.exists() is False
2004                 read_cache.assert_called_once()
2005                 write_cache.assert_not_called()
2006
2007     @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
2008     @event_loop()
2009     def test_output_locking_when_writeback_diff(self, color: bool) -> None:
2010         with cache_dir() as workspace:
2011             for tag in range(0, 4):
2012                 src = (workspace / f"test{tag}.py").resolve()
2013                 src.write_text("print('hello')", encoding="utf-8")
2014             with patch(
2015                 "black.concurrency.Manager", wraps=multiprocessing.Manager
2016             ) as mgr:
2017                 cmd = ["--diff", str(workspace)]
2018                 if color:
2019                     cmd.append("--color")
2020                 invokeBlack(cmd, exit_code=0)
2021                 # this isn't quite doing what we want, but if it _isn't_
2022                 # called then we cannot be using the lock it provides
2023                 mgr.assert_called()
2024
2025     def test_no_cache_when_stdin(self) -> None:
2026         mode = DEFAULT_MODE
2027         with cache_dir():
2028             result = CliRunner().invoke(
2029                 black.main, ["-"], input=BytesIO(b"print('hello')")
2030             )
2031             assert not result.exit_code
2032             cache_file = get_cache_file(mode)
2033             assert not cache_file.exists()
2034
2035     def test_read_cache_no_cachefile(self) -> None:
2036         mode = DEFAULT_MODE
2037         with cache_dir():
2038             assert black.Cache.read(mode).file_data == {}
2039
2040     def test_write_cache_read_cache(self) -> None:
2041         mode = DEFAULT_MODE
2042         with cache_dir() as workspace:
2043             src = (workspace / "test.py").resolve()
2044             src.touch()
2045             write_cache = black.Cache.read(mode)
2046             write_cache.write([src])
2047             read_cache = black.Cache.read(mode)
2048             assert not read_cache.is_changed(src)
2049
2050     @pytest.mark.incompatible_with_mypyc
2051     def test_filter_cached(self) -> None:
2052         with TemporaryDirectory() as workspace:
2053             path = Path(workspace)
2054             uncached = (path / "uncached").resolve()
2055             cached = (path / "cached").resolve()
2056             cached_but_changed = (path / "changed").resolve()
2057             uncached.touch()
2058             cached.touch()
2059             cached_but_changed.touch()
2060             cache = black.Cache.read(DEFAULT_MODE)
2061
2062             orig_func = black.Cache.get_file_data
2063
2064             def wrapped_func(path: Path) -> FileData:
2065                 if path == cached:
2066                     return orig_func(path)
2067                 if path == cached_but_changed:
2068                     return FileData(0.0, 0, "")
2069                 raise AssertionError
2070
2071             with patch.object(black.Cache, "get_file_data", side_effect=wrapped_func):
2072                 cache.write([cached, cached_but_changed])
2073             todo, done = cache.filtered_cached({uncached, cached, cached_but_changed})
2074             assert todo == {uncached, cached_but_changed}
2075             assert done == {cached}
2076
2077     def test_filter_cached_hash(self) -> None:
2078         with TemporaryDirectory() as workspace:
2079             path = Path(workspace)
2080             src = (path / "test.py").resolve()
2081             src.write_text("print('hello')", encoding="utf-8")
2082             st = src.stat()
2083             cache = black.Cache.read(DEFAULT_MODE)
2084             cache.write([src])
2085             cached_file_data = cache.file_data[str(src)]
2086
2087             todo, done = cache.filtered_cached([src])
2088             assert todo == set()
2089             assert done == {src}
2090             assert cached_file_data.st_mtime == st.st_mtime
2091
2092             # Modify st_mtime
2093             cached_file_data = cache.file_data[str(src)] = FileData(
2094                 cached_file_data.st_mtime - 1,
2095                 cached_file_data.st_size,
2096                 cached_file_data.hash,
2097             )
2098             todo, done = cache.filtered_cached([src])
2099             assert todo == set()
2100             assert done == {src}
2101             assert cached_file_data.st_mtime < st.st_mtime
2102             assert cached_file_data.st_size == st.st_size
2103             assert cached_file_data.hash == black.Cache.hash_digest(src)
2104
2105             # Modify contents
2106             src.write_text("print('hello world')", encoding="utf-8")
2107             new_st = src.stat()
2108             todo, done = cache.filtered_cached([src])
2109             assert todo == {src}
2110             assert done == set()
2111             assert cached_file_data.st_mtime < new_st.st_mtime
2112             assert cached_file_data.st_size != new_st.st_size
2113             assert cached_file_data.hash != black.Cache.hash_digest(src)
2114
2115     def test_write_cache_creates_directory_if_needed(self) -> None:
2116         mode = DEFAULT_MODE
2117         with cache_dir(exists=False) as workspace:
2118             assert not workspace.exists()
2119             cache = black.Cache.read(mode)
2120             cache.write([])
2121             assert workspace.exists()
2122
2123     @event_loop()
2124     def test_failed_formatting_does_not_get_cached(self) -> None:
2125         mode = DEFAULT_MODE
2126         with cache_dir() as workspace, patch(
2127             "concurrent.futures.ProcessPoolExecutor", new=ThreadPoolExecutor
2128         ):
2129             failing = (workspace / "failing.py").resolve()
2130             failing.write_text("not actually python", encoding="utf-8")
2131             clean = (workspace / "clean.py").resolve()
2132             clean.write_text('print("hello")\n', encoding="utf-8")
2133             invokeBlack([str(workspace)], exit_code=123)
2134             cache = black.Cache.read(mode)
2135             assert cache.is_changed(failing)
2136             assert not cache.is_changed(clean)
2137
2138     def test_write_cache_write_fail(self) -> None:
2139         mode = DEFAULT_MODE
2140         with cache_dir():
2141             cache = black.Cache.read(mode)
2142             with patch.object(Path, "open") as mock:
2143                 mock.side_effect = OSError
2144                 cache.write([])
2145
2146     def test_read_cache_line_lengths(self) -> None:
2147         mode = DEFAULT_MODE
2148         short_mode = replace(DEFAULT_MODE, line_length=1)
2149         with cache_dir() as workspace:
2150             path = (workspace / "file.py").resolve()
2151             path.touch()
2152             cache = black.Cache.read(mode)
2153             cache.write([path])
2154             one = black.Cache.read(mode)
2155             assert not one.is_changed(path)
2156             two = black.Cache.read(short_mode)
2157             assert two.is_changed(path)
2158
2159
def assert_collected_sources(
    src: Sequence[Union[str, Path]],
    expected: Sequence[Union[str, Path]],
    *,
    root: Optional[Path] = None,
    exclude: Optional[str] = None,
    include: Optional[str] = None,
    extend_exclude: Optional[str] = None,
    force_exclude: Optional[str] = None,
    stdin_filename: Optional[str] = None,
) -> None:
    """Assert that ``black.get_sources`` collects exactly *expected* from *src*.

    Pattern arguments are regex strings compiled with ``compile_pattern``.
    ``None`` leaves *exclude*, *extend_exclude* and *force_exclude* unset.
    For *include*, ``None`` falls back to ``DEFAULT_INCLUDE``.
    *root* defaults to ``THIS_DIR``. Ordering is ignored in the comparison.
    """

    def maybe_compile(pattern: Optional[str]) -> "Optional[re.Pattern[str]]":
        if pattern is None:
            return None
        return compile_pattern(pattern)

    sources = tuple(str(Path(item)) for item in src)
    wanted = [Path(item) for item in expected]
    exclude_re = maybe_compile(exclude)
    include_re = DEFAULT_INCLUDE if include is None else compile_pattern(include)
    extend_exclude_re = maybe_compile(extend_exclude)
    force_exclude_re = maybe_compile(force_exclude)

    collected = black.get_sources(
        root=root if root else THIS_DIR,
        src=sources,
        quiet=False,
        verbose=False,
        include=include_re,
        exclude=exclude_re,
        extend_exclude=extend_exclude_re,
        force_exclude=force_exclude_re,
        report=black.Report(),
        stdin_filename=stdin_filename,
    )
    assert sorted(collected) == sorted(wanted)
2192
2193
class TestFileCollection:
    """Source discovery through ``black.get_sources`` / ``gen_python_files``.

    Covers the following:

    * include, exclude, extend-exclude and force-exclude regexes;
    * ``.gitignore`` handling: top-level, nested and unparsable;
    * symlinks;
    * the ``-`` stdin pseudo-path, with and without a stdin filename.
    """

    @staticmethod
    def _collect(
        paths: Iterator[Path],
        root: Path,
        include: "re.Pattern[str]",
        exclude: "re.Pattern[str]",
        gitignore_dict: Dict[Path, PathSpec],
    ) -> List[Path]:
        """List what ``black.gen_python_files`` yields for *paths*.

        No extend-exclude or force-exclude pattern is passed. A fresh
        ``black.Report`` is used for the run.
        """
        return list(
            black.gen_python_files(
                paths,
                root,
                include,
                exclude,
                None,
                None,
                black.Report(),
                gitignore_dict,
                verbose=False,
                quiet=False,
            )
        )

    @staticmethod
    def _stdin_placeholder(filename: str) -> str:
        """Return the pseudo-path expected from get_sources() for stdin input."""
        return f"__BLACK_STDIN_FILENAME__{filename}"

    @staticmethod
    def _assert_unparsable_gitignore(path: Path, gitignore: Path) -> None:
        """Run black on *path* and expect it to report *gitignore* as unparsable."""
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_include_exclude(self) -> None:
        base = THIS_DIR / "data" / "include_exclude_tests"
        assert_collected_sources(
            [base],
            [base / "b/dont_exclude/a.py", base / "b/dont_exclude/a.pyi"],
            include=r"\.pyi?$",
            exclude=r"/exclude/|/\.definitely_exclude/",
        )

    def test_gitignore_used_as_default(self) -> None:
        base = Path(DATA_DIR / "include_exclude_tests")
        kept = [
            base / "b/.definitely_exclude/a.py",
            base / "b/.definitely_exclude/a.pyi",
        ]
        assert_collected_sources(
            [base / "b/"], kept, root=base, extend_exclude=r"/exclude/"
        )

    def test_gitignore_used_on_multiple_sources(self) -> None:
        root = Path(DATA_DIR / "gitignore_used_on_multiple_sources")
        source_dirs = [root / "dir1", root / "dir2"]
        assert_collected_sources(
            source_dirs, [d / "b.py" for d in source_dirs], root=root
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_exclude_for_issue_1572(self) -> None:
        # --exclude must not drop files named explicitly on the command line.
        # It only applies while recursively discovering files.
        # https://github.com/psf/black/issues/1572
        explicit = DATA_DIR / "include_exclude_tests" / "b/exclude/a.py"
        assert_collected_sources(
            [explicit], [explicit], include="", exclude=r"/exclude/|a\.py"
        )

    def test_gitignore_exclude(self) -> None:
        path = THIS_DIR / "data" / "include_exclude_tests"
        gitignore = PathSpec.from_lines(
            "gitwildmatch", ["exclude/", ".definitely_exclude"]
        )
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        sources: List[Path] = self._collect(
            path.iterdir(),
            THIS_DIR.resolve(),
            re.compile(r"\.pyi?$"),
            re.compile(r""),
            {path: gitignore},
        )
        assert sorted(expected) == sorted(sources)

    def test_nested_gitignore(self) -> None:
        path = Path(THIS_DIR / "data" / "nested_gitignore_tests")
        root_gitignore = black.files.get_gitignore(path)
        expected: List[Path] = [
            path / "x.py",
            path / "root/b.py",
            path / "root/c.py",
            path / "root/child/c.py",
        ]
        sources = self._collect(
            path.iterdir(),
            THIS_DIR.resolve(),
            re.compile(r"\.pyi?$"),
            re.compile(r""),
            {path: root_gitignore},
        )
        assert sorted(expected) == sorted(sources)

    def test_nested_gitignore_directly_in_source_directory(self) -> None:
        # https://github.com/psf/black/issues/2598
        child = Path(DATA_DIR / "nested_gitignore_tests" / "root" / "child")
        assert_collected_sources([child], [child / "a.py", child / "c.py"])

    def test_invalid_gitignore(self) -> None:
        path = THIS_DIR / "data" / "invalid_gitignore_tests"
        self._assert_unparsable_gitignore(path, path / ".gitignore")

    def test_invalid_nested_gitignore(self) -> None:
        path = THIS_DIR / "data" / "invalid_nested_gitignore_tests"
        self._assert_unparsable_gitignore(path, path / "a" / ".gitignore")

    def test_gitignore_that_ignores_subfolders(self) -> None:
        base = Path(DATA_DIR / "ignore_subfolders_gitignore_tests")
        subdir = base / "subdir"

        # A .gitignore containing */* sits at the root.
        assert_collected_sources([subdir], [subdir / "b.py"], root=subdir)

        # A .gitignore containing */* is nested below the root.
        assert_collected_sources(
            [base], [base / "a.py", base / "subdir" / "b.py"], root=base
        )

        # Black is run from the outer directory on the subdirectory only.
        assert_collected_sources([subdir], [subdir / "b.py"], root=base)

    def test_empty_include(self) -> None:
        base = DATA_DIR / "include_exclude_tests"
        every_file = [
            base / "b/exclude/a.pie",
            base / "b/exclude/a.py",
            base / "b/exclude/a.pyi",
            base / "b/dont_exclude/a.pie",
            base / "b/dont_exclude/a.py",
            base / "b/dont_exclude/a.pyi",
            base / "b/.definitely_exclude/a.pie",
            base / "b/.definitely_exclude/a.py",
            base / "b/.definitely_exclude/a.pyi",
            base / ".gitignore",
            base / "pyproject.toml",
        ]
        # An explicit empty exclude keeps .gitignore from being applied.
        assert_collected_sources([base], every_file, include="", exclude="")

    def test_extend_exclude(self) -> None:
        base = DATA_DIR / "include_exclude_tests"
        assert_collected_sources(
            [base],
            [base / "b/exclude/a.py", base / "b/dont_exclude/a.py"],
            exclude=r"\.pyi$",
            extend_exclude=r"\.definitely_exclude",
        )

    @pytest.mark.incompatible_with_mypyc
    def test_symlinks(self) -> None:
        root = THIS_DIR.resolve()
        include = re.compile(black.DEFAULT_INCLUDES)
        exclude = re.compile(black.DEFAULT_EXCLUDES)
        gitignore = PathSpec.from_lines("gitwildmatch", [])

        # A plain file inside the root: expected in the output.
        regular = MagicMock()
        regular.absolute.return_value = root / "regular.py"
        regular.resolve.return_value = root / "regular.py"
        regular.is_dir.return_value = False

        # A symlink resolving outside the root: expected to be dropped.
        outside_root_symlink = MagicMock()
        outside_root_symlink.absolute.return_value = root / "symlink.py"
        outside_root_symlink.resolve.return_value = Path("/nowhere")

        # A symlink under .mypy_cache: expected to be skipped without resolving.
        ignored_symlink = MagicMock()
        ignored_symlink.absolute.return_value = root / ".mypy_cache" / "symlink.py"

        path = MagicMock()
        path.iterdir.return_value = [regular, outside_root_symlink, ignored_symlink]

        files = self._collect(path.iterdir(), root, include, exclude, {path: gitignore})
        assert files == [regular]

        path.iterdir.assert_called_once()
        outside_root_symlink.resolve.assert_called_once()
        ignored_symlink.resolve.assert_not_called()

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin(self) -> None:
        assert_collected_sources(
            ["-"], ["-"], include="", exclude=r"/exclude/|a\.py"
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename(self) -> None:
        stdin_filename = str(THIS_DIR / "data/collections.py")
        assert_collected_sources(
            src=["-"],
            expected=[self._stdin_placeholder(stdin_filename)],
            exclude=r"/exclude/a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
        # --exclude must not drop the stdin filename. It stands in for a file
        # passed explicitly, so the rule from test_exclude_for_issue_1572 holds.
        stdin_filename = str(DATA_DIR / "include_exclude_tests" / "b/exclude/a.py")
        assert_collected_sources(
            src=["-"],
            expected=[self._stdin_placeholder(stdin_filename)],
            exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
        # --extend-exclude must not drop the stdin filename either, since it
        # stands in for a file passed explicitly.
        stdin_filename = str(
            THIS_DIR / "data" / "include_exclude_tests" / "b/exclude/a.py"
        )
        assert_collected_sources(
            src=["-"],
            expected=[self._stdin_placeholder(stdin_filename)],
            extend_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
        # Unlike --exclude, --force-exclude does drop a matching stdin filename.
        stdin_filename = str(
            THIS_DIR / "data" / "include_exclude_tests" / "b/exclude/a.py"
        )
        assert_collected_sources(
            src=["-"],
            expected=[],
            force_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )
2483
2484
class TestDeFactoAPI:
    """Guard symbols that external code commonly relies on.

    Black does not (yet) formally expose an API (see issue #779). Still, a few
    functions are widely used by third parties and should keep working.
    """

    def test_format_str(self) -> None:
        source = "print('hello')"
        formatted = 'print("hello")\n'

        # format_str() with a default Mode.
        assert black.format_str(source, mode=black.Mode()) == formatted

        # A Mode with a custom line length is accepted too.
        assert black.format_str(source, mode=black.Mode(line_length=42)) == formatted

        # Unparsable input is reported as InvalidInput.
        with pytest.raises(black.InvalidInput):
            black.format_str("syntax error", mode=black.Mode())

    def test_format_file_contents(self) -> None:
        # format_str() is the better entry point, but format_file_contents()
        # is used in the wild, so keep it covered.
        reformatted = black.format_file_contents("x=1", fast=True, mode=black.Mode())
        assert reformatted == "x = 1\n"

        # Input that is already formatted raises NothingChanged.
        with pytest.raises(black.NothingChanged):
            black.format_file_contents("x = 1\n", fast=True, mode=black.Mode())
2518
2519
# Load the text of black's main module so tracefunc() can print signatures.
# A UnicodeDecodeError is tolerated only when black is compiled; presumably
# __file__ then points at a binary extension. In that case black_source_lines
# is left undefined.
try:
    with open(black.__file__, "r", encoding="utf-8") as _bf:
        black_source_lines = list(_bf)
except UnicodeDecodeError:
    if black.COMPILED:
        pass
    else:
        raise
2526
2527
def tracefunc(
    frame: types.FrameType, event: str, arg: Any
) -> Callable[[types.FrameType, str, Any], Any]:
    """Show function calls from `black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging. For each
    "call" event whose code lives in black/__init__.py, it prints the call's
    line number and the stripped source line of the function signature. The
    output is indented according to the current stack depth. All other events
    and files are ignored. The function always returns itself, so it remains
    the active trace function.
    """
    if event != "call":
        return tracefunc

    # Check the file before touching black_source_lines. That list holds the
    # text of black/__init__.py, so indexing it with a line number from some
    # other file reads an unrelated line or raises IndexError. Backslashes are
    # normalized so Windows paths match as well.
    filename = frame.f_code.co_filename.replace("\\", "/")
    if "black/__init__.py" not in filename:
        return tracefunc

    # Indentation grows with stack depth. NOTE(review): the fixed offset of 19
    # presumably discounts the test-runner frames below the code under test --
    # confirm. context=0 skips loading source lines; the frame count is the same.
    stack = (len(inspect.stack(0)) - 19) * 2

    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1
    if not 0 <= func_sig_lineno < len(black_source_lines):
        # Source on disk does not line up with the running code; skip.
        return tracefunc
    funcname = black_source_lines[func_sig_lineno].strip()
    # The reported line can be a decorator; walk forward to the signature.
    while funcname.startswith("@") and func_sig_lineno + 1 < len(black_source_lines):
        func_sig_lineno += 1
        funcname = black_source_lines[func_sig_lineno].strip()
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc