git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

96e6f1e6945d8fa1489b244f03fe3436a1522631
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import inspect
5 import io
6 import logging
7 import multiprocessing
8 import os
9 import re
10 import sys
11 import types
12 import unittest
13 from concurrent.futures import ThreadPoolExecutor
14 from contextlib import contextmanager, redirect_stderr
15 from dataclasses import replace
16 from io import BytesIO
17 from pathlib import Path
18 from platform import system
19 from tempfile import TemporaryDirectory
20 from typing import (
21     Any,
22     Callable,
23     Dict,
24     Iterator,
25     List,
26     Optional,
27     Sequence,
28     TypeVar,
29     Union,
30 )
31 from unittest.mock import MagicMock, patch
32
33 import click
34 import pytest
35 from click import unstyle
36 from click.testing import CliRunner
37 from pathspec import PathSpec
38
39 import black
40 import black.files
41 from black import Feature, TargetVersion
42 from black import re_compile_maybe_verbose as compile_pattern
43 from black.cache import get_cache_dir, get_cache_file
44 from black.debug import DebugVisitor
45 from black.output import color_diff, diff
46 from black.report import Report
47
48 # Import other test classes
49 from tests.util import (
50     DATA_DIR,
51     DEFAULT_MODE,
52     DETERMINISTIC_HEADER,
53     PROJECT_ROOT,
54     PY36_VERSIONS,
55     THIS_DIR,
56     BlackBaseTestCase,
57     assert_format,
58     change_directory,
59     dump_to_stderr,
60     ff,
61     fs,
62     get_case_path,
63     read_data,
64     read_data_from_file,
65 )
66
# Module-level fixtures shared by the tests below.
THIS_FILE = Path(__file__)
EMPTY_CONFIG = THIS_DIR.joinpath("data", "empty_pyproject.toml")
# One "--target-version=..." CLI flag per Python 3.6+ target version.
PY36_ARGS = ["--target-version=" + version.name.lower() for version in PY36_VERSIONS]
DEFAULT_EXCLUDE = compile_pattern(black.const.DEFAULT_EXCLUDES)
DEFAULT_INCLUDE = compile_pattern(black.const.DEFAULT_INCLUDES)
T = TypeVar("T")
R = TypeVar("R")

# Match the time output in a diff, but nothing else
DIFF_TIME = re.compile(r"\t[\d\-:+\. ]+")
77
78
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Redirect ``black.cache.CACHE_DIR`` to a temporary location for the block.

    Yields the patched cache path. When ``exists`` is False, the yielded path is
    a ``new`` subdirectory of the temporary directory that has not been created.
    The temporary directory is removed when the block exits.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
87
88
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a fresh asyncio event loop as the current loop for the block.

    On exit the loop is closed and then uninstalled (the current loop is reset
    to ``None``), so code running after the block can never pick up the closed
    loop through ``asyncio.get_event_loop()``.
    """
    policy = asyncio.get_event_loop_policy()
    loop = policy.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        yield

    finally:
        try:
            loop.close()
        finally:
            # Uninstall the now-closed loop, mirroring what asyncio.run() does
            # when it finishes; leaving it installed hands later callers a dead
            # loop.
            asyncio.set_event_loop(None)
99
100
class FakeContext(click.Context):
    """A fake click Context for when calling functions that need it."""

    def __init__(self) -> None:
        # click.Context.__init__ is not called; only these two attributes exist.
        self.default_map: Dict[str, Any] = dict()
        # Dummy root, since most of the tests don't care about it
        self.obj: Dict[str, Any] = dict(root=PROJECT_ROOT)
108
109
class FakeParameter(click.Parameter):
    """A fake click Parameter for when calling functions that need it."""

    def __init__(self) -> None:
        """Create a bare instance without running ``click.Parameter.__init__``."""
115
116
class BlackRunner(CliRunner):
    """Make sure STDOUT and STDERR are kept separate when testing Black via its CLI."""

    def __init__(self) -> None:
        # mix_stderr=False asks click to capture stderr separately from stdout.
        # NOTE(review): click 8.2 removed the mix_stderr argument; confirm the
        # pinned click version still accepts it.
        super(BlackRunner, self).__init__(mix_stderr=False)
122
123
def invokeBlack(
    args: List[str], exit_code: int = 0, ignore_config: bool = True
) -> None:
    """Run ``black.main`` in-process and assert it exits with ``exit_code``.

    When ``ignore_config`` is true (the default), ``--verbose`` is added and
    ``THIS_DIR / "empty.toml"`` is passed as the config file, presumably so no
    project configuration influences the run. Exceptions are not caught by the
    runner. On an unexpected exit code the assertion message shows the
    arguments, captured stdout/stderr and the recorded exception.
    """
    if ignore_config:
        args = ["--verbose", "--config", str(THIS_DIR / "empty.toml")] + list(args)
    result = BlackRunner().invoke(black.main, args, catch_exceptions=False)
    out_bytes, err_bytes = result.stdout_bytes, result.stderr_bytes
    assert out_bytes is not None
    assert err_bytes is not None
    details = "\n".join(
        [
            f"Failed with args: {args}",
            f"stdout: {out_bytes.decode()!r}",
            f"stderr: {err_bytes.decode()!r}",
            f"exception: {result.exception}",
        ]
    )
    assert result.exit_code == exit_code, details
140
141
142 class BlackTestCase(BlackBaseTestCase):
143     invokeBlack = staticmethod(invokeBlack)
144
145     def test_empty_ff(self) -> None:
146         expected = ""
147         tmp_file = Path(black.dump_to_file())
148         try:
149             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
150             with open(tmp_file, encoding="utf8") as f:
151                 actual = f.read()
152         finally:
153             os.unlink(tmp_file)
154         self.assertFormatEqual(expected, actual)
155
156     def test_experimental_string_processing_warns(self) -> None:
157         self.assertWarns(
158             black.mode.Deprecated, black.Mode, experimental_string_processing=True
159         )
160
161     def test_piping(self) -> None:
162         source, expected = read_data_from_file(PROJECT_ROOT / "src/black/__init__.py")
163         result = BlackRunner().invoke(
164             black.main,
165             [
166                 "-",
167                 "--fast",
168                 f"--line-length={black.DEFAULT_LINE_LENGTH}",
169                 f"--config={EMPTY_CONFIG}",
170             ],
171             input=BytesIO(source.encode("utf8")),
172         )
173         self.assertEqual(result.exit_code, 0)
174         self.assertFormatEqual(expected, result.output)
175         if source != result.output:
176             black.assert_equivalent(source, result.output)
177             black.assert_stable(source, result.output, DEFAULT_MODE)
178
179     def test_piping_diff(self) -> None:
180         diff_header = re.compile(
181             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
182             r"\+\d\d\d\d"
183         )
184         source, _ = read_data("simple_cases", "expression.py")
185         expected, _ = read_data("simple_cases", "expression.diff")
186         args = [
187             "-",
188             "--fast",
189             f"--line-length={black.DEFAULT_LINE_LENGTH}",
190             "--diff",
191             f"--config={EMPTY_CONFIG}",
192         ]
193         result = BlackRunner().invoke(
194             black.main, args, input=BytesIO(source.encode("utf8"))
195         )
196         self.assertEqual(result.exit_code, 0)
197         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
198         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
199         self.assertEqual(expected, actual)
200
201     def test_piping_diff_with_color(self) -> None:
202         source, _ = read_data("simple_cases", "expression.py")
203         args = [
204             "-",
205             "--fast",
206             f"--line-length={black.DEFAULT_LINE_LENGTH}",
207             "--diff",
208             "--color",
209             f"--config={EMPTY_CONFIG}",
210         ]
211         result = BlackRunner().invoke(
212             black.main, args, input=BytesIO(source.encode("utf8"))
213         )
214         actual = result.output
215         # Again, the contents are checked in a different test, so only look for colors.
216         self.assertIn("\033[1m", actual)
217         self.assertIn("\033[36m", actual)
218         self.assertIn("\033[32m", actual)
219         self.assertIn("\033[31m", actual)
220         self.assertIn("\033[0m", actual)
221
222     @patch("black.dump_to_file", dump_to_stderr)
223     def _test_wip(self) -> None:
224         source, expected = read_data("miscellaneous", "wip")
225         sys.settrace(tracefunc)
226         mode = replace(
227             DEFAULT_MODE,
228             experimental_string_processing=False,
229             target_versions={black.TargetVersion.PY38},
230         )
231         actual = fs(source, mode=mode)
232         sys.settrace(None)
233         self.assertFormatEqual(expected, actual)
234         black.assert_equivalent(source, actual)
235         black.assert_stable(source, actual, black.FileMode())
236
237     def test_pep_572_version_detection(self) -> None:
238         source, _ = read_data("py_38", "pep_572")
239         root = black.lib2to3_parse(source)
240         features = black.get_features_used(root)
241         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
242         versions = black.detect_target_versions(root)
243         self.assertIn(black.TargetVersion.PY38, versions)
244
245     def test_expression_ff(self) -> None:
246         source, expected = read_data("simple_cases", "expression.py")
247         tmp_file = Path(black.dump_to_file(source))
248         try:
249             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
250             with open(tmp_file, encoding="utf8") as f:
251                 actual = f.read()
252         finally:
253             os.unlink(tmp_file)
254         self.assertFormatEqual(expected, actual)
255         with patch("black.dump_to_file", dump_to_stderr):
256             black.assert_equivalent(source, actual)
257             black.assert_stable(source, actual, DEFAULT_MODE)
258
259     def test_expression_diff(self) -> None:
260         source, _ = read_data("simple_cases", "expression.py")
261         expected, _ = read_data("simple_cases", "expression.diff")
262         tmp_file = Path(black.dump_to_file(source))
263         diff_header = re.compile(
264             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
265             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
266         )
267         try:
268             result = BlackRunner().invoke(
269                 black.main, ["--diff", str(tmp_file), f"--config={EMPTY_CONFIG}"]
270             )
271             self.assertEqual(result.exit_code, 0)
272         finally:
273             os.unlink(tmp_file)
274         actual = result.output
275         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
276         if expected != actual:
277             dump = black.dump_to_file(actual)
278             msg = (
279                 "Expected diff isn't equal to the actual. If you made changes to"
280                 " expression.py and this is an anticipated difference, overwrite"
281                 f" tests/data/expression.diff with {dump}"
282             )
283             self.assertEqual(expected, actual, msg)
284
285     def test_expression_diff_with_color(self) -> None:
286         source, _ = read_data("simple_cases", "expression.py")
287         expected, _ = read_data("simple_cases", "expression.diff")
288         tmp_file = Path(black.dump_to_file(source))
289         try:
290             result = BlackRunner().invoke(
291                 black.main,
292                 ["--diff", "--color", str(tmp_file), f"--config={EMPTY_CONFIG}"],
293             )
294         finally:
295             os.unlink(tmp_file)
296         actual = result.output
297         # We check the contents of the diff in `test_expression_diff`. All
298         # we need to check here is that color codes exist in the result.
299         self.assertIn("\033[1m", actual)
300         self.assertIn("\033[36m", actual)
301         self.assertIn("\033[32m", actual)
302         self.assertIn("\033[31m", actual)
303         self.assertIn("\033[0m", actual)
304
305     def test_detect_pos_only_arguments(self) -> None:
306         source, _ = read_data("py_38", "pep_570")
307         root = black.lib2to3_parse(source)
308         features = black.get_features_used(root)
309         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
310         versions = black.detect_target_versions(root)
311         self.assertIn(black.TargetVersion.PY38, versions)
312
313     def test_detect_debug_f_strings(self) -> None:
314         root = black.lib2to3_parse("""f"{x=}" """)
315         features = black.get_features_used(root)
316         self.assertIn(black.Feature.DEBUG_F_STRINGS, features)
317         versions = black.detect_target_versions(root)
318         self.assertIn(black.TargetVersion.PY38, versions)
319
320         root = black.lib2to3_parse(
321             """f"{x}"\nf'{"="}'\nf'{(x:=5)}'\nf'{f(a="3=")}'\nf'{x:=10}'\n"""
322         )
323         features = black.get_features_used(root)
324         self.assertNotIn(black.Feature.DEBUG_F_STRINGS, features)
325
326         # We don't yet support feature version detection in nested f-strings
327         root = black.lib2to3_parse(
328             """f"heard a rumour that { f'{1+1=}' } ... seems like it could be true" """
329         )
330         features = black.get_features_used(root)
331         self.assertNotIn(black.Feature.DEBUG_F_STRINGS, features)
332
333     @patch("black.dump_to_file", dump_to_stderr)
334     def test_string_quotes(self) -> None:
335         source, expected = read_data("miscellaneous", "string_quotes")
336         mode = black.Mode(preview=True)
337         assert_format(source, expected, mode)
338         mode = replace(mode, string_normalization=False)
339         not_normalized = fs(source, mode=mode)
340         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
341         black.assert_equivalent(source, not_normalized)
342         black.assert_stable(source, not_normalized, mode=mode)
343
344     def test_skip_magic_trailing_comma(self) -> None:
345         source, _ = read_data("simple_cases", "expression")
346         expected, _ = read_data(
347             "miscellaneous", "expression_skip_magic_trailing_comma.diff"
348         )
349         tmp_file = Path(black.dump_to_file(source))
350         diff_header = re.compile(
351             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
352             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
353         )
354         try:
355             result = BlackRunner().invoke(
356                 black.main, ["-C", "--diff", str(tmp_file), f"--config={EMPTY_CONFIG}"]
357             )
358             self.assertEqual(result.exit_code, 0)
359         finally:
360             os.unlink(tmp_file)
361         actual = result.output
362         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
363         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
364         if expected != actual:
365             dump = black.dump_to_file(actual)
366             msg = (
367                 "Expected diff isn't equal to the actual. If you made changes to"
368                 " expression.py and this is an anticipated difference, overwrite"
369                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
370             )
371             self.assertEqual(expected, actual, msg)
372
373     @patch("black.dump_to_file", dump_to_stderr)
374     def test_async_as_identifier(self) -> None:
375         source_path = get_case_path("miscellaneous", "async_as_identifier")
376         source, expected = read_data_from_file(source_path)
377         actual = fs(source)
378         self.assertFormatEqual(expected, actual)
379         major, minor = sys.version_info[:2]
380         if major < 3 or (major <= 3 and minor < 7):
381             black.assert_equivalent(source, actual)
382         black.assert_stable(source, actual, DEFAULT_MODE)
383         # ensure black can parse this when the target is 3.6
384         self.invokeBlack([str(source_path), "--target-version", "py36"])
385         # but not on 3.7, because async/await is no longer an identifier
386         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
387
388     @patch("black.dump_to_file", dump_to_stderr)
389     def test_python37(self) -> None:
390         source_path = get_case_path("py_37", "python37")
391         source, expected = read_data_from_file(source_path)
392         actual = fs(source)
393         self.assertFormatEqual(expected, actual)
394         major, minor = sys.version_info[:2]
395         if major > 3 or (major == 3 and minor >= 7):
396             black.assert_equivalent(source, actual)
397         black.assert_stable(source, actual, DEFAULT_MODE)
398         # ensure black can parse this when the target is 3.7
399         self.invokeBlack([str(source_path), "--target-version", "py37"])
400         # but not on 3.6, because we use async as a reserved keyword
401         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
402
403     def test_tab_comment_indentation(self) -> None:
404         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
405         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
406         self.assertFormatEqual(contents_spc, fs(contents_spc))
407         self.assertFormatEqual(contents_spc, fs(contents_tab))
408
409         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
410         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
411         self.assertFormatEqual(contents_spc, fs(contents_spc))
412         self.assertFormatEqual(contents_spc, fs(contents_tab))
413
414         # mixed tabs and spaces (valid Python 2 code)
415         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
416         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
417         self.assertFormatEqual(contents_spc, fs(contents_spc))
418         self.assertFormatEqual(contents_spc, fs(contents_tab))
419
420         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
421         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
422         self.assertFormatEqual(contents_spc, fs(contents_spc))
423         self.assertFormatEqual(contents_spc, fs(contents_tab))
424
425     def test_report_verbose(self) -> None:
426         report = Report(verbose=True)
427         out_lines = []
428         err_lines = []
429
430         def out(msg: str, **kwargs: Any) -> None:
431             out_lines.append(msg)
432
433         def err(msg: str, **kwargs: Any) -> None:
434             err_lines.append(msg)
435
436         with patch("black.output._out", out), patch("black.output._err", err):
437             report.done(Path("f1"), black.Changed.NO)
438             self.assertEqual(len(out_lines), 1)
439             self.assertEqual(len(err_lines), 0)
440             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
441             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
442             self.assertEqual(report.return_code, 0)
443             report.done(Path("f2"), black.Changed.YES)
444             self.assertEqual(len(out_lines), 2)
445             self.assertEqual(len(err_lines), 0)
446             self.assertEqual(out_lines[-1], "reformatted f2")
447             self.assertEqual(
448                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
449             )
450             report.done(Path("f3"), black.Changed.CACHED)
451             self.assertEqual(len(out_lines), 3)
452             self.assertEqual(len(err_lines), 0)
453             self.assertEqual(
454                 out_lines[-1], "f3 wasn't modified on disk since last run."
455             )
456             self.assertEqual(
457                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
458             )
459             self.assertEqual(report.return_code, 0)
460             report.check = True
461             self.assertEqual(report.return_code, 1)
462             report.check = False
463             report.failed(Path("e1"), "boom")
464             self.assertEqual(len(out_lines), 3)
465             self.assertEqual(len(err_lines), 1)
466             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
467             self.assertEqual(
468                 unstyle(str(report)),
469                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
470                 " reformat.",
471             )
472             self.assertEqual(report.return_code, 123)
473             report.done(Path("f3"), black.Changed.YES)
474             self.assertEqual(len(out_lines), 4)
475             self.assertEqual(len(err_lines), 1)
476             self.assertEqual(out_lines[-1], "reformatted f3")
477             self.assertEqual(
478                 unstyle(str(report)),
479                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
480                 " reformat.",
481             )
482             self.assertEqual(report.return_code, 123)
483             report.failed(Path("e2"), "boom")
484             self.assertEqual(len(out_lines), 4)
485             self.assertEqual(len(err_lines), 2)
486             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
487             self.assertEqual(
488                 unstyle(str(report)),
489                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
490                 " reformat.",
491             )
492             self.assertEqual(report.return_code, 123)
493             report.path_ignored(Path("wat"), "no match")
494             self.assertEqual(len(out_lines), 5)
495             self.assertEqual(len(err_lines), 2)
496             self.assertEqual(out_lines[-1], "wat ignored: no match")
497             self.assertEqual(
498                 unstyle(str(report)),
499                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
500                 " reformat.",
501             )
502             self.assertEqual(report.return_code, 123)
503             report.done(Path("f4"), black.Changed.NO)
504             self.assertEqual(len(out_lines), 6)
505             self.assertEqual(len(err_lines), 2)
506             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
507             self.assertEqual(
508                 unstyle(str(report)),
509                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
510                 " reformat.",
511             )
512             self.assertEqual(report.return_code, 123)
513             report.check = True
514             self.assertEqual(
515                 unstyle(str(report)),
516                 "2 files would be reformatted, 3 files would be left unchanged, 2"
517                 " files would fail to reformat.",
518             )
519             report.check = False
520             report.diff = True
521             self.assertEqual(
522                 unstyle(str(report)),
523                 "2 files would be reformatted, 3 files would be left unchanged, 2"
524                 " files would fail to reformat.",
525             )
526
527     def test_report_quiet(self) -> None:
528         report = Report(quiet=True)
529         out_lines = []
530         err_lines = []
531
532         def out(msg: str, **kwargs: Any) -> None:
533             out_lines.append(msg)
534
535         def err(msg: str, **kwargs: Any) -> None:
536             err_lines.append(msg)
537
538         with patch("black.output._out", out), patch("black.output._err", err):
539             report.done(Path("f1"), black.Changed.NO)
540             self.assertEqual(len(out_lines), 0)
541             self.assertEqual(len(err_lines), 0)
542             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
543             self.assertEqual(report.return_code, 0)
544             report.done(Path("f2"), black.Changed.YES)
545             self.assertEqual(len(out_lines), 0)
546             self.assertEqual(len(err_lines), 0)
547             self.assertEqual(
548                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
549             )
550             report.done(Path("f3"), black.Changed.CACHED)
551             self.assertEqual(len(out_lines), 0)
552             self.assertEqual(len(err_lines), 0)
553             self.assertEqual(
554                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
555             )
556             self.assertEqual(report.return_code, 0)
557             report.check = True
558             self.assertEqual(report.return_code, 1)
559             report.check = False
560             report.failed(Path("e1"), "boom")
561             self.assertEqual(len(out_lines), 0)
562             self.assertEqual(len(err_lines), 1)
563             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
564             self.assertEqual(
565                 unstyle(str(report)),
566                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
567                 " reformat.",
568             )
569             self.assertEqual(report.return_code, 123)
570             report.done(Path("f3"), black.Changed.YES)
571             self.assertEqual(len(out_lines), 0)
572             self.assertEqual(len(err_lines), 1)
573             self.assertEqual(
574                 unstyle(str(report)),
575                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
576                 " reformat.",
577             )
578             self.assertEqual(report.return_code, 123)
579             report.failed(Path("e2"), "boom")
580             self.assertEqual(len(out_lines), 0)
581             self.assertEqual(len(err_lines), 2)
582             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
583             self.assertEqual(
584                 unstyle(str(report)),
585                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
586                 " reformat.",
587             )
588             self.assertEqual(report.return_code, 123)
589             report.path_ignored(Path("wat"), "no match")
590             self.assertEqual(len(out_lines), 0)
591             self.assertEqual(len(err_lines), 2)
592             self.assertEqual(
593                 unstyle(str(report)),
594                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
595                 " reformat.",
596             )
597             self.assertEqual(report.return_code, 123)
598             report.done(Path("f4"), black.Changed.NO)
599             self.assertEqual(len(out_lines), 0)
600             self.assertEqual(len(err_lines), 2)
601             self.assertEqual(
602                 unstyle(str(report)),
603                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
604                 " reformat.",
605             )
606             self.assertEqual(report.return_code, 123)
607             report.check = True
608             self.assertEqual(
609                 unstyle(str(report)),
610                 "2 files would be reformatted, 3 files would be left unchanged, 2"
611                 " files would fail to reformat.",
612             )
613             report.check = False
614             report.diff = True
615             self.assertEqual(
616                 unstyle(str(report)),
617                 "2 files would be reformatted, 3 files would be left unchanged, 2"
618                 " files would fail to reformat.",
619             )
620
621     def test_report_normal(self) -> None:
622         report = black.Report()
623         out_lines = []
624         err_lines = []
625
626         def out(msg: str, **kwargs: Any) -> None:
627             out_lines.append(msg)
628
629         def err(msg: str, **kwargs: Any) -> None:
630             err_lines.append(msg)
631
632         with patch("black.output._out", out), patch("black.output._err", err):
633             report.done(Path("f1"), black.Changed.NO)
634             self.assertEqual(len(out_lines), 0)
635             self.assertEqual(len(err_lines), 0)
636             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
637             self.assertEqual(report.return_code, 0)
638             report.done(Path("f2"), black.Changed.YES)
639             self.assertEqual(len(out_lines), 1)
640             self.assertEqual(len(err_lines), 0)
641             self.assertEqual(out_lines[-1], "reformatted f2")
642             self.assertEqual(
643                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
644             )
645             report.done(Path("f3"), black.Changed.CACHED)
646             self.assertEqual(len(out_lines), 1)
647             self.assertEqual(len(err_lines), 0)
648             self.assertEqual(out_lines[-1], "reformatted f2")
649             self.assertEqual(
650                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
651             )
652             self.assertEqual(report.return_code, 0)
653             report.check = True
654             self.assertEqual(report.return_code, 1)
655             report.check = False
656             report.failed(Path("e1"), "boom")
657             self.assertEqual(len(out_lines), 1)
658             self.assertEqual(len(err_lines), 1)
659             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
660             self.assertEqual(
661                 unstyle(str(report)),
662                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
663                 " reformat.",
664             )
665             self.assertEqual(report.return_code, 123)
666             report.done(Path("f3"), black.Changed.YES)
667             self.assertEqual(len(out_lines), 2)
668             self.assertEqual(len(err_lines), 1)
669             self.assertEqual(out_lines[-1], "reformatted f3")
670             self.assertEqual(
671                 unstyle(str(report)),
672                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
673                 " reformat.",
674             )
675             self.assertEqual(report.return_code, 123)
676             report.failed(Path("e2"), "boom")
677             self.assertEqual(len(out_lines), 2)
678             self.assertEqual(len(err_lines), 2)
679             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
680             self.assertEqual(
681                 unstyle(str(report)),
682                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
683                 " reformat.",
684             )
685             self.assertEqual(report.return_code, 123)
686             report.path_ignored(Path("wat"), "no match")
687             self.assertEqual(len(out_lines), 2)
688             self.assertEqual(len(err_lines), 2)
689             self.assertEqual(
690                 unstyle(str(report)),
691                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
692                 " reformat.",
693             )
694             self.assertEqual(report.return_code, 123)
695             report.done(Path("f4"), black.Changed.NO)
696             self.assertEqual(len(out_lines), 2)
697             self.assertEqual(len(err_lines), 2)
698             self.assertEqual(
699                 unstyle(str(report)),
700                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
701                 " reformat.",
702             )
703             self.assertEqual(report.return_code, 123)
704             report.check = True
705             self.assertEqual(
706                 unstyle(str(report)),
707                 "2 files would be reformatted, 3 files would be left unchanged, 2"
708                 " files would fail to reformat.",
709             )
710             report.check = False
711             report.diff = True
712             self.assertEqual(
713                 unstyle(str(report)),
714                 "2 files would be reformatted, 3 files would be left unchanged, 2"
715                 " files would fail to reformat.",
716             )
717
718     def test_lib2to3_parse(self) -> None:
719         with self.assertRaises(black.InvalidInput):
720             black.lib2to3_parse("invalid syntax")
721
722         straddling = "x + y"
723         black.lib2to3_parse(straddling)
724         black.lib2to3_parse(straddling, {TargetVersion.PY36})
725
726         py2_only = "print x"
727         with self.assertRaises(black.InvalidInput):
728             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
729
730         py3_only = "exec(x, end=y)"
731         black.lib2to3_parse(py3_only)
732         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
733
734     def test_get_features_used_decorator(self) -> None:
735         # Test the feature detection of new decorator syntax
736         # since this makes some test cases of test_get_features_used()
737         # fails if it fails, this is tested first so that a useful case
738         # is identified
739         simples, relaxed = read_data("miscellaneous", "decorators")
740         # skip explanation comments at the top of the file
741         for simple_test in simples.split("##")[1:]:
742             node = black.lib2to3_parse(simple_test)
743             decorator = str(node.children[0].children[0]).strip()
744             self.assertNotIn(
745                 Feature.RELAXED_DECORATORS,
746                 black.get_features_used(node),
747                 msg=(
748                     f"decorator '{decorator}' follows python<=3.8 syntax"
749                     "but is detected as 3.9+"
750                     # f"The full node is\n{node!r}"
751                 ),
752             )
753         # skip the '# output' comment at the top of the output part
754         for relaxed_test in relaxed.split("##")[1:]:
755             node = black.lib2to3_parse(relaxed_test)
756             decorator = str(node.children[0].children[0]).strip()
757             self.assertIn(
758                 Feature.RELAXED_DECORATORS,
759                 black.get_features_used(node),
760                 msg=(
761                     f"decorator '{decorator}' uses python3.9+ syntax"
762                     "but is detected as python<=3.8"
763                     # f"The full node is\n{node!r}"
764                 ),
765             )
766
767     def test_get_features_used(self) -> None:
768         node = black.lib2to3_parse("def f(*, arg): ...\n")
769         self.assertEqual(black.get_features_used(node), set())
770         node = black.lib2to3_parse("def f(*, arg,): ...\n")
771         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
772         node = black.lib2to3_parse("f(*arg,)\n")
773         self.assertEqual(
774             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
775         )
776         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
777         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
778         node = black.lib2to3_parse("123_456\n")
779         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
780         node = black.lib2to3_parse("123456\n")
781         self.assertEqual(black.get_features_used(node), set())
782         source, expected = read_data("simple_cases", "function")
783         node = black.lib2to3_parse(source)
784         expected_features = {
785             Feature.TRAILING_COMMA_IN_CALL,
786             Feature.TRAILING_COMMA_IN_DEF,
787             Feature.F_STRINGS,
788         }
789         self.assertEqual(black.get_features_used(node), expected_features)
790         node = black.lib2to3_parse(expected)
791         self.assertEqual(black.get_features_used(node), expected_features)
792         source, expected = read_data("simple_cases", "expression")
793         node = black.lib2to3_parse(source)
794         self.assertEqual(black.get_features_used(node), set())
795         node = black.lib2to3_parse(expected)
796         self.assertEqual(black.get_features_used(node), set())
797         node = black.lib2to3_parse("lambda a, /, b: ...")
798         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
799         node = black.lib2to3_parse("def fn(a, /, b): ...")
800         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
801         node = black.lib2to3_parse("def fn(): yield a, b")
802         self.assertEqual(black.get_features_used(node), set())
803         node = black.lib2to3_parse("def fn(): return a, b")
804         self.assertEqual(black.get_features_used(node), set())
805         node = black.lib2to3_parse("def fn(): yield *b, c")
806         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
807         node = black.lib2to3_parse("def fn(): return a, *b, c")
808         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
809         node = black.lib2to3_parse("x = a, *b, c")
810         self.assertEqual(black.get_features_used(node), set())
811         node = black.lib2to3_parse("x: Any = regular")
812         self.assertEqual(black.get_features_used(node), set())
813         node = black.lib2to3_parse("x: Any = (regular, regular)")
814         self.assertEqual(black.get_features_used(node), set())
815         node = black.lib2to3_parse("x: Any = Complex(Type(1))[something]")
816         self.assertEqual(black.get_features_used(node), set())
817         node = black.lib2to3_parse("x: Tuple[int, ...] = a, b, c")
818         self.assertEqual(
819             black.get_features_used(node), {Feature.ANN_ASSIGN_EXTENDED_RHS}
820         )
821         node = black.lib2to3_parse("try: pass\nexcept Something: pass")
822         self.assertEqual(black.get_features_used(node), set())
823         node = black.lib2to3_parse("try: pass\nexcept (*Something,): pass")
824         self.assertEqual(black.get_features_used(node), set())
825         node = black.lib2to3_parse("try: pass\nexcept *Group: pass")
826         self.assertEqual(black.get_features_used(node), {Feature.EXCEPT_STAR})
827         node = black.lib2to3_parse("a[*b]")
828         self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS})
829         node = black.lib2to3_parse("a[x, *y(), z] = t")
830         self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS})
831         node = black.lib2to3_parse("def fn(*args: *T): pass")
832         self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS})
833
834     def test_get_features_used_for_future_flags(self) -> None:
835         for src, features in [
836             ("from __future__ import annotations", {Feature.FUTURE_ANNOTATIONS}),
837             (
838                 "from __future__ import (other, annotations)",
839                 {Feature.FUTURE_ANNOTATIONS},
840             ),
841             ("a = 1 + 2\nfrom something import annotations", set()),
842             ("from __future__ import x, y", set()),
843         ]:
844             with self.subTest(src=src, features=features):
845                 node = black.lib2to3_parse(src)
846                 future_imports = black.get_future_imports(node)
847                 self.assertEqual(
848                     black.get_features_used(node, future_imports=future_imports),
849                     features,
850                 )
851
852     def test_get_future_imports(self) -> None:
853         node = black.lib2to3_parse("\n")
854         self.assertEqual(set(), black.get_future_imports(node))
855         node = black.lib2to3_parse("from __future__ import black\n")
856         self.assertEqual({"black"}, black.get_future_imports(node))
857         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
858         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
859         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
860         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
861         node = black.lib2to3_parse(
862             "from __future__ import multiple\nfrom __future__ import imports\n"
863         )
864         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
865         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
866         self.assertEqual({"black"}, black.get_future_imports(node))
867         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
868         self.assertEqual({"black"}, black.get_future_imports(node))
869         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
870         self.assertEqual(set(), black.get_future_imports(node))
871         node = black.lib2to3_parse("from some.module import black\n")
872         self.assertEqual(set(), black.get_future_imports(node))
873         node = black.lib2to3_parse(
874             "from __future__ import unicode_literals as _unicode_literals"
875         )
876         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
877         node = black.lib2to3_parse(
878             "from __future__ import unicode_literals as _lol, print"
879         )
880         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
881
882     @pytest.mark.incompatible_with_mypyc
883     def test_debug_visitor(self) -> None:
884         source, _ = read_data("miscellaneous", "debug_visitor")
885         expected, _ = read_data("miscellaneous", "debug_visitor.out")
886         out_lines = []
887         err_lines = []
888
889         def out(msg: str, **kwargs: Any) -> None:
890             out_lines.append(msg)
891
892         def err(msg: str, **kwargs: Any) -> None:
893             err_lines.append(msg)
894
895         with patch("black.debug.out", out):
896             DebugVisitor.show(source)
897         actual = "\n".join(out_lines) + "\n"
898         log_name = ""
899         if expected != actual:
900             log_name = black.dump_to_file(*out_lines)
901         self.assertEqual(
902             expected,
903             actual,
904             f"AST print out is different. Actual version dumped to {log_name}",
905         )
906
907     def test_format_file_contents(self) -> None:
908         empty = ""
909         mode = DEFAULT_MODE
910         with self.assertRaises(black.NothingChanged):
911             black.format_file_contents(empty, mode=mode, fast=False)
912         just_nl = "\n"
913         with self.assertRaises(black.NothingChanged):
914             black.format_file_contents(just_nl, mode=mode, fast=False)
915         same = "j = [1, 2, 3]\n"
916         with self.assertRaises(black.NothingChanged):
917             black.format_file_contents(same, mode=mode, fast=False)
918         different = "j = [1,2,3]"
919         expected = same
920         actual = black.format_file_contents(different, mode=mode, fast=False)
921         self.assertEqual(expected, actual)
922         invalid = "return if you can"
923         with self.assertRaises(black.InvalidInput) as e:
924             black.format_file_contents(invalid, mode=mode, fast=False)
925         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
926
927     def test_endmarker(self) -> None:
928         n = black.lib2to3_parse("\n")
929         self.assertEqual(n.type, black.syms.file_input)
930         self.assertEqual(len(n.children), 1)
931         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
932
933     @pytest.mark.incompatible_with_mypyc
934     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
935     def test_assertFormatEqual(self) -> None:
936         out_lines = []
937         err_lines = []
938
939         def out(msg: str, **kwargs: Any) -> None:
940             out_lines.append(msg)
941
942         def err(msg: str, **kwargs: Any) -> None:
943             err_lines.append(msg)
944
945         with patch("black.output._out", out), patch("black.output._err", err):
946             with self.assertRaises(AssertionError):
947                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
948
949         out_str = "".join(out_lines)
950         self.assertIn("Expected tree:", out_str)
951         self.assertIn("Actual tree:", out_str)
952         self.assertEqual("".join(err_lines), "")
953
954     @event_loop()
955     @patch("concurrent.futures.ProcessPoolExecutor", MagicMock(side_effect=OSError))
956     def test_works_in_mono_process_only_environment(self) -> None:
957         with cache_dir() as workspace:
958             for f in [
959                 (workspace / "one.py").resolve(),
960                 (workspace / "two.py").resolve(),
961             ]:
962                 f.write_text('print("hello")\n')
963             self.invokeBlack([str(workspace)])
964
965     @event_loop()
966     def test_check_diff_use_together(self) -> None:
967         with cache_dir():
968             # Files which will be reformatted.
969             src1 = get_case_path("miscellaneous", "string_quotes")
970             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
971             # Files which will not be reformatted.
972             src2 = get_case_path("simple_cases", "composition")
973             self.invokeBlack([str(src2), "--diff", "--check"])
974             # Multi file command.
975             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
976
977     def test_no_src_fails(self) -> None:
978         with cache_dir():
979             self.invokeBlack([], exit_code=1)
980
981     def test_src_and_code_fails(self) -> None:
982         with cache_dir():
983             self.invokeBlack([".", "-c", "0"], exit_code=1)
984
985     def test_broken_symlink(self) -> None:
986         with cache_dir() as workspace:
987             symlink = workspace / "broken_link.py"
988             try:
989                 symlink.symlink_to("nonexistent.py")
990             except (OSError, NotImplementedError) as e:
991                 self.skipTest(f"Can't create symlinks: {e}")
992             self.invokeBlack([str(workspace.resolve())])
993
994     def test_single_file_force_pyi(self) -> None:
995         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
996         contents, expected = read_data("miscellaneous", "force_pyi")
997         with cache_dir() as workspace:
998             path = (workspace / "file.py").resolve()
999             with open(path, "w") as fh:
1000                 fh.write(contents)
1001             self.invokeBlack([str(path), "--pyi"])
1002             with open(path, "r") as fh:
1003                 actual = fh.read()
1004             # verify cache with --pyi is separate
1005             pyi_cache = black.read_cache(pyi_mode)
1006             self.assertIn(str(path), pyi_cache)
1007             normal_cache = black.read_cache(DEFAULT_MODE)
1008             self.assertNotIn(str(path), normal_cache)
1009         self.assertFormatEqual(expected, actual)
1010         black.assert_equivalent(contents, actual)
1011         black.assert_stable(contents, actual, pyi_mode)
1012
1013     @event_loop()
1014     def test_multi_file_force_pyi(self) -> None:
1015         reg_mode = DEFAULT_MODE
1016         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1017         contents, expected = read_data("miscellaneous", "force_pyi")
1018         with cache_dir() as workspace:
1019             paths = [
1020                 (workspace / "file1.py").resolve(),
1021                 (workspace / "file2.py").resolve(),
1022             ]
1023             for path in paths:
1024                 with open(path, "w") as fh:
1025                     fh.write(contents)
1026             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1027             for path in paths:
1028                 with open(path, "r") as fh:
1029                     actual = fh.read()
1030                 self.assertEqual(actual, expected)
1031             # verify cache with --pyi is separate
1032             pyi_cache = black.read_cache(pyi_mode)
1033             normal_cache = black.read_cache(reg_mode)
1034             for path in paths:
1035                 self.assertIn(str(path), pyi_cache)
1036                 self.assertNotIn(str(path), normal_cache)
1037
1038     def test_pipe_force_pyi(self) -> None:
1039         source, expected = read_data("miscellaneous", "force_pyi")
1040         result = CliRunner().invoke(
1041             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1042         )
1043         self.assertEqual(result.exit_code, 0)
1044         actual = result.output
1045         self.assertFormatEqual(actual, expected)
1046
1047     def test_single_file_force_py36(self) -> None:
1048         reg_mode = DEFAULT_MODE
1049         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1050         source, expected = read_data("miscellaneous", "force_py36")
1051         with cache_dir() as workspace:
1052             path = (workspace / "file.py").resolve()
1053             with open(path, "w") as fh:
1054                 fh.write(source)
1055             self.invokeBlack([str(path), *PY36_ARGS])
1056             with open(path, "r") as fh:
1057                 actual = fh.read()
1058             # verify cache with --target-version is separate
1059             py36_cache = black.read_cache(py36_mode)
1060             self.assertIn(str(path), py36_cache)
1061             normal_cache = black.read_cache(reg_mode)
1062             self.assertNotIn(str(path), normal_cache)
1063         self.assertEqual(actual, expected)
1064
1065     @event_loop()
1066     def test_multi_file_force_py36(self) -> None:
1067         reg_mode = DEFAULT_MODE
1068         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1069         source, expected = read_data("miscellaneous", "force_py36")
1070         with cache_dir() as workspace:
1071             paths = [
1072                 (workspace / "file1.py").resolve(),
1073                 (workspace / "file2.py").resolve(),
1074             ]
1075             for path in paths:
1076                 with open(path, "w") as fh:
1077                     fh.write(source)
1078             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1079             for path in paths:
1080                 with open(path, "r") as fh:
1081                     actual = fh.read()
1082                 self.assertEqual(actual, expected)
1083             # verify cache with --target-version is separate
1084             pyi_cache = black.read_cache(py36_mode)
1085             normal_cache = black.read_cache(reg_mode)
1086             for path in paths:
1087                 self.assertIn(str(path), pyi_cache)
1088                 self.assertNotIn(str(path), normal_cache)
1089
1090     def test_pipe_force_py36(self) -> None:
1091         source, expected = read_data("miscellaneous", "force_py36")
1092         result = CliRunner().invoke(
1093             black.main,
1094             ["-", "-q", "--target-version=py36"],
1095             input=BytesIO(source.encode("utf8")),
1096         )
1097         self.assertEqual(result.exit_code, 0)
1098         actual = result.output
1099         self.assertFormatEqual(actual, expected)
1100
1101     @pytest.mark.incompatible_with_mypyc
1102     def test_reformat_one_with_stdin(self) -> None:
1103         with patch(
1104             "black.format_stdin_to_stdout",
1105             return_value=lambda *args, **kwargs: black.Changed.YES,
1106         ) as fsts:
1107             report = MagicMock()
1108             path = Path("-")
1109             black.reformat_one(
1110                 path,
1111                 fast=True,
1112                 write_back=black.WriteBack.YES,
1113                 mode=DEFAULT_MODE,
1114                 report=report,
1115             )
1116             fsts.assert_called_once()
1117             report.done.assert_called_with(path, black.Changed.YES)
1118
1119     @pytest.mark.incompatible_with_mypyc
1120     def test_reformat_one_with_stdin_filename(self) -> None:
1121         with patch(
1122             "black.format_stdin_to_stdout",
1123             return_value=lambda *args, **kwargs: black.Changed.YES,
1124         ) as fsts:
1125             report = MagicMock()
1126             p = "foo.py"
1127             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1128             expected = Path(p)
1129             black.reformat_one(
1130                 path,
1131                 fast=True,
1132                 write_back=black.WriteBack.YES,
1133                 mode=DEFAULT_MODE,
1134                 report=report,
1135             )
1136             fsts.assert_called_once_with(
1137                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1138             )
1139             # __BLACK_STDIN_FILENAME__ should have been stripped
1140             report.done.assert_called_with(expected, black.Changed.YES)
1141
1142     @pytest.mark.incompatible_with_mypyc
1143     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1144         with patch(
1145             "black.format_stdin_to_stdout",
1146             return_value=lambda *args, **kwargs: black.Changed.YES,
1147         ) as fsts:
1148             report = MagicMock()
1149             p = "foo.pyi"
1150             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1151             expected = Path(p)
1152             black.reformat_one(
1153                 path,
1154                 fast=True,
1155                 write_back=black.WriteBack.YES,
1156                 mode=DEFAULT_MODE,
1157                 report=report,
1158             )
1159             fsts.assert_called_once_with(
1160                 fast=True,
1161                 write_back=black.WriteBack.YES,
1162                 mode=replace(DEFAULT_MODE, is_pyi=True),
1163             )
1164             # __BLACK_STDIN_FILENAME__ should have been stripped
1165             report.done.assert_called_with(expected, black.Changed.YES)
1166
1167     @pytest.mark.incompatible_with_mypyc
1168     def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
1169         with patch(
1170             "black.format_stdin_to_stdout",
1171             return_value=lambda *args, **kwargs: black.Changed.YES,
1172         ) as fsts:
1173             report = MagicMock()
1174             p = "foo.ipynb"
1175             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1176             expected = Path(p)
1177             black.reformat_one(
1178                 path,
1179                 fast=True,
1180                 write_back=black.WriteBack.YES,
1181                 mode=DEFAULT_MODE,
1182                 report=report,
1183             )
1184             fsts.assert_called_once_with(
1185                 fast=True,
1186                 write_back=black.WriteBack.YES,
1187                 mode=replace(DEFAULT_MODE, is_ipynb=True),
1188             )
1189             # __BLACK_STDIN_FILENAME__ should have been stripped
1190             report.done.assert_called_with(expected, black.Changed.YES)
1191
1192     @pytest.mark.incompatible_with_mypyc
1193     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1194         with patch(
1195             "black.format_stdin_to_stdout",
1196             return_value=lambda *args, **kwargs: black.Changed.YES,
1197         ) as fsts:
1198             report = MagicMock()
1199             # Even with an existing file, since we are forcing stdin, black
1200             # should output to stdout and not modify the file inplace
1201             p = THIS_DIR / "data" / "simple_cases" / "collections.py"
1202             # Make sure is_file actually returns True
1203             self.assertTrue(p.is_file())
1204             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1205             expected = Path(p)
1206             black.reformat_one(
1207                 path,
1208                 fast=True,
1209                 write_back=black.WriteBack.YES,
1210                 mode=DEFAULT_MODE,
1211                 report=report,
1212             )
1213             fsts.assert_called_once()
1214             # __BLACK_STDIN_FILENAME__ should have been stripped
1215             report.done.assert_called_with(expected, black.Changed.YES)
1216
1217     def test_reformat_one_with_stdin_empty(self) -> None:
1218         output = io.StringIO()
1219         with patch("io.TextIOWrapper", lambda *args, **kwargs: output):
1220             try:
1221                 black.format_stdin_to_stdout(
1222                     fast=True,
1223                     content="",
1224                     write_back=black.WriteBack.YES,
1225                     mode=DEFAULT_MODE,
1226                 )
1227             except io.UnsupportedOperation:
1228                 pass  # StringIO does not support detach
1229             assert output.getvalue() == ""
1230
1231     def test_invalid_cli_regex(self) -> None:
1232         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1233             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1234
1235     def test_required_version_matches_version(self) -> None:
1236         self.invokeBlack(
1237             ["--required-version", black.__version__, "-c", "0"],
1238             exit_code=0,
1239             ignore_config=True,
1240         )
1241
1242     def test_required_version_matches_partial_version(self) -> None:
1243         self.invokeBlack(
1244             ["--required-version", black.__version__.split(".")[0], "-c", "0"],
1245             exit_code=0,
1246             ignore_config=True,
1247         )
1248
1249     def test_required_version_does_not_match_on_minor_version(self) -> None:
1250         self.invokeBlack(
1251             ["--required-version", black.__version__.split(".")[0] + ".999", "-c", "0"],
1252             exit_code=1,
1253             ignore_config=True,
1254         )
1255
1256     def test_required_version_does_not_match_version(self) -> None:
1257         result = BlackRunner().invoke(
1258             black.main,
1259             ["--required-version", "20.99b", "-c", "0"],
1260         )
1261         self.assertEqual(result.exit_code, 1)
1262         self.assertIn("required version", result.stderr)
1263
1264     def test_preserves_line_endings(self) -> None:
1265         with TemporaryDirectory() as workspace:
1266             test_file = Path(workspace) / "test.py"
1267             for nl in ["\n", "\r\n"]:
1268                 contents = nl.join(["def f(  ):", "    pass"])
1269                 test_file.write_bytes(contents.encode())
1270                 ff(test_file, write_back=black.WriteBack.YES)
1271                 updated_contents: bytes = test_file.read_bytes()
1272                 self.assertIn(nl.encode(), updated_contents)
1273                 if nl == "\n":
1274                     self.assertNotIn(b"\r\n", updated_contents)
1275
1276     def test_preserves_line_endings_via_stdin(self) -> None:
1277         for nl in ["\n", "\r\n"]:
1278             contents = nl.join(["def f(  ):", "    pass"])
1279             runner = BlackRunner()
1280             result = runner.invoke(
1281                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1282             )
1283             self.assertEqual(result.exit_code, 0)
1284             output = result.stdout_bytes
1285             self.assertIn(nl.encode("utf8"), output)
1286             if nl == "\n":
1287                 self.assertNotIn(b"\r\n", output)
1288
1289     def test_normalize_line_endings(self) -> None:
1290         with TemporaryDirectory() as workspace:
1291             test_file = Path(workspace) / "test.py"
1292             for data, expected in (
1293                 (b"c\r\nc\n ", b"c\r\nc\r\n"),
1294                 (b"l\nl\r\n ", b"l\nl\n"),
1295             ):
1296                 test_file.write_bytes(data)
1297                 ff(test_file, write_back=black.WriteBack.YES)
1298                 self.assertEqual(test_file.read_bytes(), expected)
1299
1300     def test_assert_equivalent_different_asts(self) -> None:
1301         with self.assertRaises(AssertionError):
1302             black.assert_equivalent("{}", "None")
1303
1304     def test_shhh_click(self) -> None:
1305         try:
1306             from click import _unicodefun  # type: ignore
1307         except ImportError:
1308             self.skipTest("Incompatible Click version")
1309
1310         if not hasattr(_unicodefun, "_verify_python_env"):
1311             self.skipTest("Incompatible Click version")
1312
1313         # First, let's see if Click is crashing with a preferred ASCII charset.
1314         with patch("locale.getpreferredencoding") as gpe:
1315             gpe.return_value = "ASCII"
1316             with self.assertRaises(RuntimeError):
1317                 _unicodefun._verify_python_env()
1318         # Now, let's silence Click...
1319         black.patch_click()
1320         # ...and confirm it's silent.
1321         with patch("locale.getpreferredencoding") as gpe:
1322             gpe.return_value = "ASCII"
1323             try:
1324                 _unicodefun._verify_python_env()
1325             except RuntimeError as re:
1326                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1327
1328     def test_root_logger_not_used_directly(self) -> None:
1329         def fail(*args: Any, **kwargs: Any) -> None:
1330             self.fail("Record created with root logger")
1331
1332         with patch.multiple(
1333             logging.root,
1334             debug=fail,
1335             info=fail,
1336             warning=fail,
1337             error=fail,
1338             critical=fail,
1339             log=fail,
1340         ):
1341             ff(THIS_DIR / "util.py")
1342
1343     def test_invalid_config_return_code(self) -> None:
1344         tmp_file = Path(black.dump_to_file())
1345         try:
1346             tmp_config = Path(black.dump_to_file())
1347             tmp_config.unlink()
1348             args = ["--config", str(tmp_config), str(tmp_file)]
1349             self.invokeBlack(args, exit_code=2, ignore_config=False)
1350         finally:
1351             tmp_file.unlink()
1352
1353     def test_parse_pyproject_toml(self) -> None:
1354         test_toml_file = THIS_DIR / "test.toml"
1355         config = black.parse_pyproject_toml(str(test_toml_file))
1356         self.assertEqual(config["verbose"], 1)
1357         self.assertEqual(config["check"], "no")
1358         self.assertEqual(config["diff"], "y")
1359         self.assertEqual(config["color"], True)
1360         self.assertEqual(config["line_length"], 79)
1361         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1362         self.assertEqual(config["python_cell_magics"], ["custom1", "custom2"])
1363         self.assertEqual(config["exclude"], r"\.pyi?$")
1364         self.assertEqual(config["include"], r"\.py?$")
1365
1366     def test_read_pyproject_toml(self) -> None:
1367         test_toml_file = THIS_DIR / "test.toml"
1368         fake_ctx = FakeContext()
1369         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1370         config = fake_ctx.default_map
1371         self.assertEqual(config["verbose"], "1")
1372         self.assertEqual(config["check"], "no")
1373         self.assertEqual(config["diff"], "y")
1374         self.assertEqual(config["color"], "True")
1375         self.assertEqual(config["line_length"], "79")
1376         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1377         self.assertEqual(config["exclude"], r"\.pyi?$")
1378         self.assertEqual(config["include"], r"\.py?$")
1379
1380     @pytest.mark.incompatible_with_mypyc
1381     def test_find_project_root(self) -> None:
1382         with TemporaryDirectory() as workspace:
1383             root = Path(workspace)
1384             test_dir = root / "test"
1385             test_dir.mkdir()
1386
1387             src_dir = root / "src"
1388             src_dir.mkdir()
1389
1390             root_pyproject = root / "pyproject.toml"
1391             root_pyproject.touch()
1392             src_pyproject = src_dir / "pyproject.toml"
1393             src_pyproject.touch()
1394             src_python = src_dir / "foo.py"
1395             src_python.touch()
1396
1397             self.assertEqual(
1398                 black.find_project_root((src_dir, test_dir)),
1399                 (root.resolve(), "pyproject.toml"),
1400             )
1401             self.assertEqual(
1402                 black.find_project_root((src_dir,)),
1403                 (src_dir.resolve(), "pyproject.toml"),
1404             )
1405             self.assertEqual(
1406                 black.find_project_root((src_python,)),
1407                 (src_dir.resolve(), "pyproject.toml"),
1408             )
1409
1410             with change_directory(test_dir):
1411                 self.assertEqual(
1412                     black.find_project_root(("-",), stdin_filename="../src/a.py"),
1413                     (src_dir.resolve(), "pyproject.toml"),
1414                 )
1415
1416     @patch(
1417         "black.files.find_user_pyproject_toml",
1418     )
1419     def test_find_pyproject_toml(self, find_user_pyproject_toml: MagicMock) -> None:
1420         find_user_pyproject_toml.side_effect = RuntimeError()
1421
1422         with redirect_stderr(io.StringIO()) as stderr:
1423             result = black.files.find_pyproject_toml(
1424                 path_search_start=(str(Path.cwd().root),)
1425             )
1426
1427         assert result is None
1428         err = stderr.getvalue()
1429         assert "Ignoring user configuration" in err
1430
1431     @patch(
1432         "black.files.find_user_pyproject_toml",
1433         black.files.find_user_pyproject_toml.__wrapped__,
1434     )
1435     def test_find_user_pyproject_toml_linux(self) -> None:
1436         if system() == "Windows":
1437             return
1438
1439         # Test if XDG_CONFIG_HOME is checked
1440         with TemporaryDirectory() as workspace:
1441             tmp_user_config = Path(workspace) / "black"
1442             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1443                 self.assertEqual(
1444                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
1445                 )
1446
1447         # Test fallback for XDG_CONFIG_HOME
1448         with patch.dict("os.environ"):
1449             os.environ.pop("XDG_CONFIG_HOME", None)
1450             fallback_user_config = Path("~/.config").expanduser() / "black"
1451             self.assertEqual(
1452                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
1453             )
1454
1455     def test_find_user_pyproject_toml_windows(self) -> None:
1456         if system() != "Windows":
1457             return
1458
1459         user_config_path = Path.home() / ".black"
1460         self.assertEqual(
1461             black.files.find_user_pyproject_toml(), user_config_path.resolve()
1462         )
1463
1464     def test_bpo_33660_workaround(self) -> None:
1465         if system() == "Windows":
1466             return
1467
1468         # https://bugs.python.org/issue33660
1469         root = Path("/")
1470         with change_directory(root):
1471             path = Path("workspace") / "project"
1472             report = black.Report(verbose=True)
1473             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
1474             self.assertEqual(normalized_path, "workspace/project")
1475
1476     def test_normalize_path_ignore_windows_junctions_outside_of_root(self) -> None:
1477         if system() != "Windows":
1478             return
1479
1480         with TemporaryDirectory() as workspace:
1481             root = Path(workspace)
1482             junction_dir = root / "junction"
1483             junction_target_outside_of_root = root / ".."
1484             os.system(f"mklink /J {junction_dir} {junction_target_outside_of_root}")
1485
1486             report = black.Report(verbose=True)
1487             normalized_path = black.normalize_path_maybe_ignore(
1488                 junction_dir, root, report
1489             )
1490             # Manually delete for Python < 3.8
1491             os.system(f"rmdir {junction_dir}")
1492
1493             self.assertEqual(normalized_path, None)
1494
1495     def test_newline_comment_interaction(self) -> None:
1496         source = "class A:\\\r\n# type: ignore\n pass\n"
1497         output = black.format_str(source, mode=DEFAULT_MODE)
1498         black.assert_stable(source, output, mode=DEFAULT_MODE)
1499
1500     def test_bpo_2142_workaround(self) -> None:
1501         # https://bugs.python.org/issue2142
1502
1503         source, _ = read_data("miscellaneous", "missing_final_newline")
1504         # read_data adds a trailing newline
1505         source = source.rstrip()
1506         expected, _ = read_data("miscellaneous", "missing_final_newline.diff")
1507         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
1508         diff_header = re.compile(
1509             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
1510             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
1511         )
1512         try:
1513             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
1514             self.assertEqual(result.exit_code, 0)
1515         finally:
1516             os.unlink(tmp_file)
1517         actual = result.output
1518         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
1519         self.assertEqual(actual, expected)
1520
1521     @staticmethod
1522     def compare_results(
1523         result: click.testing.Result, expected_value: str, expected_exit_code: int
1524     ) -> None:
1525         """Helper method to test the value and exit code of a click Result."""
1526         assert (
1527             result.output == expected_value
1528         ), "The output did not match the expected value."
1529         assert result.exit_code == expected_exit_code, "The exit code is incorrect."
1530
1531     def test_code_option(self) -> None:
1532         """Test the code option with no changes."""
1533         code = 'print("Hello world")\n'
1534         args = ["--code", code]
1535         result = CliRunner().invoke(black.main, args)
1536
1537         self.compare_results(result, code, 0)
1538
1539     def test_code_option_changed(self) -> None:
1540         """Test the code option when changes are required."""
1541         code = "print('hello world')"
1542         formatted = black.format_str(code, mode=DEFAULT_MODE)
1543
1544         args = ["--code", code]
1545         result = CliRunner().invoke(black.main, args)
1546
1547         self.compare_results(result, formatted, 0)
1548
1549     def test_code_option_check(self) -> None:
1550         """Test the code option when check is passed."""
1551         args = ["--check", "--code", 'print("Hello world")\n']
1552         result = CliRunner().invoke(black.main, args)
1553         self.compare_results(result, "", 0)
1554
1555     def test_code_option_check_changed(self) -> None:
1556         """Test the code option when changes are required, and check is passed."""
1557         args = ["--check", "--code", "print('hello world')"]
1558         result = CliRunner().invoke(black.main, args)
1559         self.compare_results(result, "", 1)
1560
1561     def test_code_option_diff(self) -> None:
1562         """Test the code option when diff is passed."""
1563         code = "print('hello world')"
1564         formatted = black.format_str(code, mode=DEFAULT_MODE)
1565         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1566
1567         args = ["--diff", "--code", code]
1568         result = CliRunner().invoke(black.main, args)
1569
1570         # Remove time from diff
1571         output = DIFF_TIME.sub("", result.output)
1572
1573         assert output == result_diff, "The output did not match the expected value."
1574         assert result.exit_code == 0, "The exit code is incorrect."
1575
1576     def test_code_option_color_diff(self) -> None:
1577         """Test the code option when color and diff are passed."""
1578         code = "print('hello world')"
1579         formatted = black.format_str(code, mode=DEFAULT_MODE)
1580
1581         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1582         result_diff = color_diff(result_diff)
1583
1584         args = ["--diff", "--color", "--code", code]
1585         result = CliRunner().invoke(black.main, args)
1586
1587         # Remove time from diff
1588         output = DIFF_TIME.sub("", result.output)
1589
1590         assert output == result_diff, "The output did not match the expected value."
1591         assert result.exit_code == 0, "The exit code is incorrect."
1592
1593     @pytest.mark.incompatible_with_mypyc
1594     def test_code_option_safe(self) -> None:
1595         """Test that the code option throws an error when the sanity checks fail."""
1596         # Patch black.assert_equivalent to ensure the sanity checks fail
1597         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1598             code = 'print("Hello world")'
1599             error_msg = f"{code}\nerror: cannot format <string>: \n"
1600
1601             args = ["--safe", "--code", code]
1602             result = CliRunner().invoke(black.main, args)
1603
1604             self.compare_results(result, error_msg, 123)
1605
1606     def test_code_option_fast(self) -> None:
1607         """Test that the code option ignores errors when the sanity checks fail."""
1608         # Patch black.assert_equivalent to ensure the sanity checks fail
1609         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1610             code = 'print("Hello world")'
1611             formatted = black.format_str(code, mode=DEFAULT_MODE)
1612
1613             args = ["--fast", "--code", code]
1614             result = CliRunner().invoke(black.main, args)
1615
1616             self.compare_results(result, formatted, 0)
1617
1618     @pytest.mark.incompatible_with_mypyc
1619     def test_code_option_config(self) -> None:
1620         """
1621         Test that the code option finds the pyproject.toml in the current directory.
1622         """
1623         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1624             args = ["--code", "print"]
1625             # This is the only directory known to contain a pyproject.toml
1626             with change_directory(PROJECT_ROOT):
1627                 CliRunner().invoke(black.main, args)
1628                 pyproject_path = Path(Path.cwd(), "pyproject.toml").resolve()
1629
1630             assert (
1631                 len(parse.mock_calls) >= 1
1632             ), "Expected config parse to be called with the current directory."
1633
1634             _, call_args, _ = parse.mock_calls[0]
1635             assert (
1636                 call_args[0].lower() == str(pyproject_path).lower()
1637             ), "Incorrect config loaded."
1638
1639     @pytest.mark.incompatible_with_mypyc
1640     def test_code_option_parent_config(self) -> None:
1641         """
1642         Test that the code option finds the pyproject.toml in the parent directory.
1643         """
1644         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1645             with change_directory(THIS_DIR):
1646                 args = ["--code", "print"]
1647                 CliRunner().invoke(black.main, args)
1648
1649                 pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
1650                 assert (
1651                     len(parse.mock_calls) >= 1
1652                 ), "Expected config parse to be called with the current directory."
1653
1654                 _, call_args, _ = parse.mock_calls[0]
1655                 assert (
1656                     call_args[0].lower() == str(pyproject_path).lower()
1657                 ), "Incorrect config loaded."
1658
1659     def test_for_handled_unexpected_eof_error(self) -> None:
1660         """
1661         Test that an unexpected EOF SyntaxError is nicely presented.
1662         """
1663         with pytest.raises(black.parsing.InvalidInput) as exc_info:
1664             black.lib2to3_parse("print(", {})
1665
1666         exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")
1667
1668     def test_equivalency_ast_parse_failure_includes_error(self) -> None:
1669         with pytest.raises(AssertionError) as err:
1670             black.assert_equivalent("a«»a  = 1", "a«»a  = 1")
1671
1672         err.match("--safe")
1673         # Unfortunately the SyntaxError message has changed in newer versions so we
1674         # can't match it directly.
1675         err.match("invalid character")
1676         err.match(r"\(<unknown>, line 1\)")
1677
1678
class TestCaching:
    """Tests for Black's on-disk formatting cache (``black.cache`` helpers)."""

    def test_get_cache_dir(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """get_cache_dir uses BLACK_CACHE_DIR when set, else user_cache_dir."""
        # Create multiple cache directories
        workspace1 = tmp_path / "ws1"
        workspace1.mkdir()
        workspace2 = tmp_path / "ws2"
        workspace2.mkdir()

        # Force user_cache_dir to use the temporary directory for easier assertions
        patch_user_cache_dir = patch(
            target="black.cache.user_cache_dir",
            autospec=True,
            return_value=str(workspace1),
        )

        # If BLACK_CACHE_DIR is not set, use user_cache_dir
        monkeypatch.delenv("BLACK_CACHE_DIR", raising=False)
        with patch_user_cache_dir:
            assert get_cache_dir() == workspace1

        # If it is set, use the path provided in the env var.
        monkeypatch.setenv("BLACK_CACHE_DIR", str(workspace2))
        assert get_cache_dir() == workspace2

    def test_cache_broken_file(self) -> None:
        """A corrupt cache file reads as empty and is repopulated by a run."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            cache_file = get_cache_file(mode)
            # Deliberately invalid contents: read_cache must fall back to {}.
            cache_file.write_text("this is not a pickle")
            assert black.read_cache(mode) == {}
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            invokeBlack([str(src)])
            cache = black.read_cache(mode)
            assert str(src) in cache

    def test_cache_single_file_already_cached(self) -> None:
        """A file already recorded in the cache is left unformatted."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            black.write_cache({}, [src], mode)
            invokeBlack([str(src)])
            # Single quotes survive, proving Black skipped the cached file.
            assert src.read_text() == "print('hello')"

    @event_loop()
    def test_cache_multiple_files(self) -> None:
        """Only uncached files are formatted; afterwards both are cached."""
        mode = DEFAULT_MODE
        # Threads replace the process pool, presumably so workers run in-process.
        with cache_dir() as workspace, patch(
            "concurrent.futures.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            one = (workspace / "one.py").resolve()
            with one.open("w") as fobj:
                fobj.write("print('hello')")
            two = (workspace / "two.py").resolve()
            with two.open("w") as fobj:
                fobj.write("print('hello')")
            black.write_cache({}, [one], mode)
            invokeBlack([str(workspace)])
            # `one` was pre-cached, so it is untouched; `two` gets reformatted.
            with one.open("r") as fobj:
                assert fobj.read() == "print('hello')"
            with two.open("r") as fobj:
                assert fobj.read() == 'print("hello")\n'
            cache = black.read_cache(mode)
            assert str(one) in cache
            assert str(two) in cache

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    def test_no_cache_when_writeback_diff(self, color: bool) -> None:
        """In --diff mode (with or without --color) the cache is never touched."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            with src.open("w") as fobj:
                fobj.write("print('hello')")
            with patch("black.read_cache") as read_cache, patch(
                "black.write_cache"
            ) as write_cache:
                cmd = [str(src), "--diff"]
                if color:
                    cmd.append("--color")
                invokeBlack(cmd)
                cache_file = get_cache_file(mode)
                assert cache_file.exists() is False
                write_cache.assert_not_called()
                read_cache.assert_not_called()

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    @event_loop()
    def test_output_locking_when_writeback_diff(self, color: bool) -> None:
        """Diffing several files concurrently goes through multiprocessing.Manager."""
        with cache_dir() as workspace:
            for tag in range(0, 4):
                src = (workspace / f"test{tag}.py").resolve()
                with src.open("w") as fobj:
                    fobj.write("print('hello')")
            # `wraps` keeps the real Manager behavior while recording the call.
            with patch(
                "black.concurrency.Manager", wraps=multiprocessing.Manager
            ) as mgr:
                cmd = ["--diff", str(workspace)]
                if color:
                    cmd.append("--color")
                invokeBlack(cmd, exit_code=0)
                # this isn't quite doing what we want, but if it _isn't_
                # called then we cannot be using the lock it provides
                mgr.assert_called()

    def test_no_cache_when_stdin(self) -> None:
        """Formatting from stdin ("-") does not create a cache file."""
        mode = DEFAULT_MODE
        with cache_dir():
            result = CliRunner().invoke(
                black.main, ["-"], input=BytesIO(b"print('hello')")
            )
            assert not result.exit_code
            cache_file = get_cache_file(mode)
            assert not cache_file.exists()

    def test_read_cache_no_cachefile(self) -> None:
        """read_cache returns {} when no cache file exists yet."""
        mode = DEFAULT_MODE
        with cache_dir():
            assert black.read_cache(mode) == {}

    def test_write_cache_read_cache(self) -> None:
        """A written entry reads back equal to get_cache_info for the file."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.touch()
            black.write_cache({}, [src], mode)
            cache = black.read_cache(mode)
            assert str(src) in cache
            assert cache[str(src)] == black.get_cache_info(src)

    def test_filter_cached(self) -> None:
        """filter_cached splits paths into (todo, done) against cached info."""
        with TemporaryDirectory() as workspace:
            path = Path(workspace)
            uncached = (path / "uncached").resolve()
            cached = (path / "cached").resolve()
            cached_but_changed = (path / "changed").resolve()
            uncached.touch()
            cached.touch()
            cached_but_changed.touch()
            # (0.0, 0) is stale info that should not match the real file
            # (presumably an (mtime, size) pair -- confirm in black.cache).
            cache = {
                str(cached): black.get_cache_info(cached),
                str(cached_but_changed): (0.0, 0),
            }
            todo, done = black.cache.filter_cached(
                cache, {uncached, cached, cached_but_changed}
            )
            assert todo == {uncached, cached_but_changed}
            assert done == {cached}

    def test_write_cache_creates_directory_if_needed(self) -> None:
        """write_cache creates the cache directory when it does not exist."""
        mode = DEFAULT_MODE
        with cache_dir(exists=False) as workspace:
            assert not workspace.exists()
            black.write_cache({}, [], mode)
            assert workspace.exists()

    @event_loop()
    def test_failed_formatting_does_not_get_cached(self) -> None:
        """Files that fail to format (exit 123) are kept out of the cache."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace, patch(
            "concurrent.futures.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            failing = (workspace / "failing.py").resolve()
            with failing.open("w") as fobj:
                fobj.write("not actually python")
            clean = (workspace / "clean.py").resolve()
            with clean.open("w") as fobj:
                fobj.write('print("hello")\n')
            invokeBlack([str(workspace)], exit_code=123)
            cache = black.read_cache(mode)
            assert str(failing) not in cache
            assert str(clean) in cache

    def test_write_cache_write_fail(self) -> None:
        """An OSError while opening the cache file must not propagate."""
        mode = DEFAULT_MODE
        with cache_dir(), patch.object(Path, "open") as mock:
            mock.side_effect = OSError
            # No assertion: the test passes only if this call does not raise.
            black.write_cache({}, [], mode)

    def test_read_cache_line_lengths(self) -> None:
        """Caches are per mode: a different line_length sees no entry."""
        mode = DEFAULT_MODE
        short_mode = replace(DEFAULT_MODE, line_length=1)
        with cache_dir() as workspace:
            path = (workspace / "file.py").resolve()
            path.touch()
            black.write_cache({}, [path], mode)
            one = black.read_cache(mode)
            assert str(path) in one
            two = black.read_cache(short_mode)
            assert str(path) not in two
1873
1874
def assert_collected_sources(
    src: Sequence[Union[str, Path]],
    expected: Sequence[Union[str, Path]],
    *,
    ctx: Optional[FakeContext] = None,
    exclude: Optional[str] = None,
    include: Optional[str] = None,
    extend_exclude: Optional[str] = None,
    force_exclude: Optional[str] = None,
    stdin_filename: Optional[str] = None,
) -> None:
    """Assert that ``black.get_sources`` collects exactly ``expected`` from ``src``.

    Each regex argument left as None is passed to get_sources as None, except
    ``include``, which falls back to DEFAULT_INCLUDE. Order is ignored.
    """
    include_re = DEFAULT_INCLUDE if include is None else compile_pattern(include)
    exclude_re, extend_exclude_re, force_exclude_re = (
        None if pattern is None else compile_pattern(pattern)
        for pattern in (exclude, extend_exclude, force_exclude)
    )
    wanted = sorted(Path(item) for item in expected)

    collected = black.get_sources(
        ctx=ctx or FakeContext(),
        src=tuple(str(Path(item)) for item in src),
        quiet=False,
        verbose=False,
        include=include_re,
        exclude=exclude_re,
        extend_exclude=extend_exclude_re,
        force_exclude=force_exclude_re,
        report=black.Report(),
        stdin_filename=stdin_filename,
    )
    assert sorted(collected) == wanted
1907
1908
1909 class TestFileCollection:
1910     def test_include_exclude(self) -> None:
1911         path = THIS_DIR / "data" / "include_exclude_tests"
1912         src = [path]
1913         expected = [
1914             Path(path / "b/dont_exclude/a.py"),
1915             Path(path / "b/dont_exclude/a.pyi"),
1916         ]
1917         assert_collected_sources(
1918             src,
1919             expected,
1920             include=r"\.pyi?$",
1921             exclude=r"/exclude/|/\.definitely_exclude/",
1922         )
1923
1924     def test_gitignore_used_as_default(self) -> None:
1925         base = Path(DATA_DIR / "include_exclude_tests")
1926         expected = [
1927             base / "b/.definitely_exclude/a.py",
1928             base / "b/.definitely_exclude/a.pyi",
1929         ]
1930         src = [base / "b/"]
1931         ctx = FakeContext()
1932         ctx.obj["root"] = base
1933         assert_collected_sources(src, expected, ctx=ctx, extend_exclude=r"/exclude/")
1934
1935     @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
1936     def test_exclude_for_issue_1572(self) -> None:
1937         # Exclude shouldn't touch files that were explicitly given to Black through the
1938         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
1939         # https://github.com/psf/black/issues/1572
1940         path = DATA_DIR / "include_exclude_tests"
1941         src = [path / "b/exclude/a.py"]
1942         expected = [path / "b/exclude/a.py"]
1943         assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")
1944
1945     def test_gitignore_exclude(self) -> None:
1946         path = THIS_DIR / "data" / "include_exclude_tests"
1947         include = re.compile(r"\.pyi?$")
1948         exclude = re.compile(r"")
1949         report = black.Report()
1950         gitignore = PathSpec.from_lines(
1951             "gitwildmatch", ["exclude/", ".definitely_exclude"]
1952         )
1953         sources: List[Path] = []
1954         expected = [
1955             Path(path / "b/dont_exclude/a.py"),
1956             Path(path / "b/dont_exclude/a.pyi"),
1957         ]
1958         this_abs = THIS_DIR.resolve()
1959         sources.extend(
1960             black.gen_python_files(
1961                 path.iterdir(),
1962                 this_abs,
1963                 include,
1964                 exclude,
1965                 None,
1966                 None,
1967                 report,
1968                 gitignore,
1969                 verbose=False,
1970                 quiet=False,
1971             )
1972         )
1973         assert sorted(expected) == sorted(sources)
1974
1975     def test_nested_gitignore(self) -> None:
1976         path = Path(THIS_DIR / "data" / "nested_gitignore_tests")
1977         include = re.compile(r"\.pyi?$")
1978         exclude = re.compile(r"")
1979         root_gitignore = black.files.get_gitignore(path)
1980         report = black.Report()
1981         expected: List[Path] = [
1982             Path(path / "x.py"),
1983             Path(path / "root/b.py"),
1984             Path(path / "root/c.py"),
1985             Path(path / "root/child/c.py"),
1986         ]
1987         this_abs = THIS_DIR.resolve()
1988         sources = list(
1989             black.gen_python_files(
1990                 path.iterdir(),
1991                 this_abs,
1992                 include,
1993                 exclude,
1994                 None,
1995                 None,
1996                 report,
1997                 root_gitignore,
1998                 verbose=False,
1999                 quiet=False,
2000             )
2001         )
2002         assert sorted(expected) == sorted(sources)
2003
2004     def test_nested_gitignore_directly_in_source_directory(self) -> None:
2005         # https://github.com/psf/black/issues/2598
2006         path = Path(DATA_DIR / "nested_gitignore_tests")
2007         src = Path(path / "root" / "child")
2008         expected = [src / "a.py", src / "c.py"]
2009         assert_collected_sources([src], expected)
2010
2011     def test_invalid_gitignore(self) -> None:
2012         path = THIS_DIR / "data" / "invalid_gitignore_tests"
2013         empty_config = path / "pyproject.toml"
2014         result = BlackRunner().invoke(
2015             black.main, ["--verbose", "--config", str(empty_config), str(path)]
2016         )
2017         assert result.exit_code == 1
2018         assert result.stderr_bytes is not None
2019
2020         gitignore = path / ".gitignore"
2021         assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()
2022
2023     def test_invalid_nested_gitignore(self) -> None:
2024         path = THIS_DIR / "data" / "invalid_nested_gitignore_tests"
2025         empty_config = path / "pyproject.toml"
2026         result = BlackRunner().invoke(
2027             black.main, ["--verbose", "--config", str(empty_config), str(path)]
2028         )
2029         assert result.exit_code == 1
2030         assert result.stderr_bytes is not None
2031
2032         gitignore = path / "a" / ".gitignore"
2033         assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()
2034
2035     def test_empty_include(self) -> None:
2036         path = DATA_DIR / "include_exclude_tests"
2037         src = [path]
2038         expected = [
2039             Path(path / "b/exclude/a.pie"),
2040             Path(path / "b/exclude/a.py"),
2041             Path(path / "b/exclude/a.pyi"),
2042             Path(path / "b/dont_exclude/a.pie"),
2043             Path(path / "b/dont_exclude/a.py"),
2044             Path(path / "b/dont_exclude/a.pyi"),
2045             Path(path / "b/.definitely_exclude/a.pie"),
2046             Path(path / "b/.definitely_exclude/a.py"),
2047             Path(path / "b/.definitely_exclude/a.pyi"),
2048             Path(path / ".gitignore"),
2049             Path(path / "pyproject.toml"),
2050         ]
2051         # Setting exclude explicitly to an empty string to block .gitignore usage.
2052         assert_collected_sources(src, expected, include="", exclude="")
2053
2054     def test_extend_exclude(self) -> None:
2055         path = DATA_DIR / "include_exclude_tests"
2056         src = [path]
2057         expected = [
2058             Path(path / "b/exclude/a.py"),
2059             Path(path / "b/dont_exclude/a.py"),
2060         ]
2061         assert_collected_sources(
2062             src, expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude"
2063         )
2064
2065     @pytest.mark.incompatible_with_mypyc
2066     def test_symlink_out_of_root_directory(self) -> None:
2067         path = MagicMock()
2068         root = THIS_DIR.resolve()
2069         child = MagicMock()
2070         include = re.compile(black.DEFAULT_INCLUDES)
2071         exclude = re.compile(black.DEFAULT_EXCLUDES)
2072         report = black.Report()
2073         gitignore = PathSpec.from_lines("gitwildmatch", [])
2074         # `child` should behave like a symlink which resolved path is clearly
2075         # outside of the `root` directory.
2076         path.iterdir.return_value = [child]
2077         child.resolve.return_value = Path("/a/b/c")
2078         child.as_posix.return_value = "/a/b/c"
2079         try:
2080             list(
2081                 black.gen_python_files(
2082                     path.iterdir(),
2083                     root,
2084                     include,
2085                     exclude,
2086                     None,
2087                     None,
2088                     report,
2089                     gitignore,
2090                     verbose=False,
2091                     quiet=False,
2092                 )
2093             )
2094         except ValueError as ve:
2095             pytest.fail(f"`get_python_files_in_dir()` failed: {ve}")
2096         path.iterdir.assert_called_once()
2097         child.resolve.assert_called_once()
2098
2099     @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
2100     def test_get_sources_with_stdin(self) -> None:
2101         src = ["-"]
2102         expected = ["-"]
2103         assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")
2104
2105     @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
2106     def test_get_sources_with_stdin_filename(self) -> None:
2107         src = ["-"]
2108         stdin_filename = str(THIS_DIR / "data/collections.py")
2109         expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
2110         assert_collected_sources(
2111             src,
2112             expected,
2113             exclude=r"/exclude/a\.py",
2114             stdin_filename=stdin_filename,
2115         )
2116
2117     @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
2118     def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
2119         # Exclude shouldn't exclude stdin_filename since it is mimicking the
2120         # file being passed directly. This is the same as
2121         # test_exclude_for_issue_1572
2122         path = DATA_DIR / "include_exclude_tests"
2123         src = ["-"]
2124         stdin_filename = str(path / "b/exclude/a.py")
2125         expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
2126         assert_collected_sources(
2127             src,
2128             expected,
2129             exclude=r"/exclude/|a\.py",
2130             stdin_filename=stdin_filename,
2131         )
2132
2133     @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
2134     def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
2135         # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
2136         # file being passed directly. This is the same as
2137         # test_exclude_for_issue_1572
2138         src = ["-"]
2139         path = THIS_DIR / "data" / "include_exclude_tests"
2140         stdin_filename = str(path / "b/exclude/a.py")
2141         expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
2142         assert_collected_sources(
2143             src,
2144             expected,
2145             extend_exclude=r"/exclude/|a\.py",
2146             stdin_filename=stdin_filename,
2147         )
2148
2149     @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
2150     def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
2151         # Force exclude should exclude the file when passing it through
2152         # stdin_filename
2153         path = THIS_DIR / "data" / "include_exclude_tests"
2154         stdin_filename = str(path / "b/exclude/a.py")
2155         assert_collected_sources(
2156             src=["-"],
2157             expected=[],
2158             force_exclude=r"/exclude/|a\.py",
2159             stdin_filename=stdin_filename,
2160         )
2161
2162
# Source lines of black/__init__.py, read so that `tracefunc` can print the
# signature line of each traced function.
black_source_lines: List[str]
try:
    with open(black.__file__, "r", encoding="utf-8") as _bf:
        black_source_lines = _bf.readlines()
except UnicodeDecodeError:
    # When Black is compiled (`black.COMPILED`), `black.__file__` presumably
    # points at a binary extension that is not valid UTF-8. Tolerate that
    # case, but still bind the name so referencing it cannot raise NameError.
    if not black.COMPILED:
        raise
    black_source_lines = []
2169
2170
def tracefunc(
    frame: types.FrameType, event: str, arg: Any
) -> Callable[[types.FrameType, str, Any], Any]:
    """Show function calls `from black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    For each "call" event in black/__init__.py, this prints the call's line
    number and the first non-decorator source line at or after it, indented
    by the current stack depth. Every other event or file is ignored. Always
    returns itself so tracing continues into nested scopes.
    """
    if event != "call":
        return tracefunc

    # Normalise separators so the filter also matches Windows paths.
    filename = frame.f_code.co_filename.replace(os.sep, "/")
    # `black_source_lines` mirrors black/__init__.py only. Line numbers from
    # any other file would index the wrong text or run past its end, so bail
    # out before touching it (and before the costly `inspect.stack()`).
    if "black/__init__.py" not in filename:
        return tracefunc

    # Indent by stack depth. The `- 19` offset presumably discounts frames
    # belonging to the test runner; tune it if the output is skewed.
    stack = (len(inspect.stack()) - 19) * 2
    lineno = frame.f_lineno
    # `f_lineno` is 1-based; skip leading decorator lines to reach the `def`,
    # staying inside the list in case the source snapshot is short or empty.
    funcname = ""
    idx = lineno - 1
    while 0 <= idx < len(black_source_lines):
        funcname = black_source_lines[idx].strip()
        if not funcname.startswith("@"):
            break
        idx += 1
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc