]> git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

c76b3faddf556e1bb114efb59296d132eb9f7f4d
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import inspect
5 import io
6 import logging
7 import multiprocessing
8 import os
9 import re
10 import sys
11 import types
12 import unittest
13 from concurrent.futures import ThreadPoolExecutor
14 from contextlib import contextmanager, redirect_stderr
15 from dataclasses import replace
16 from io import BytesIO
17 from pathlib import Path
18 from platform import system
19 from tempfile import TemporaryDirectory
20 from typing import (
21     Any,
22     Callable,
23     Dict,
24     Iterator,
25     List,
26     Optional,
27     Sequence,
28     TypeVar,
29     Union,
30 )
31 from unittest.mock import MagicMock, patch
32
33 import click
34 import pytest
35 from click import unstyle
36 from click.testing import CliRunner
37 from pathspec import PathSpec
38
39 import black
40 import black.files
41 from black import Feature, TargetVersion
42 from black import re_compile_maybe_verbose as compile_pattern
43 from black.cache import get_cache_dir, get_cache_file
44 from black.debug import DebugVisitor
45 from black.output import color_diff, diff
46 from black.report import Report
47
48 # Import other test classes
49 from tests.util import (
50     DATA_DIR,
51     DEFAULT_MODE,
52     DETERMINISTIC_HEADER,
53     PROJECT_ROOT,
54     PY36_VERSIONS,
55     THIS_DIR,
56     BlackBaseTestCase,
57     assert_format,
58     change_directory,
59     dump_to_stderr,
60     ff,
61     fs,
62     get_case_path,
63     read_data,
64     read_data_from_file,
65 )
66
# Paths and CLI arguments shared by the tests below.
THIS_FILE = Path(__file__)
EMPTY_CONFIG = THIS_DIR.joinpath("data", "empty_pyproject.toml")
PY36_ARGS = [f"--target-version={tv.name.lower()}" for tv in PY36_VERSIONS]
# Black's stock exclude/include patterns, compiled with Black's own helper.
DEFAULT_EXCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_EXCLUDES)
DEFAULT_INCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_INCLUDES)
# Generic type variables for typed helpers.
T = TypeVar("T")
R = TypeVar("R")

# Match the time output in a diff, but nothing else
DIFF_TIME = re.compile(r"\t[\d\-:+\. ]+")
77
78
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Redirect ``black.cache.CACHE_DIR`` to a temporary directory.

    Args:
        exists: When true, the yielded path is the temporary directory itself.
            When false, it is a ``new`` subdirectory of it, which this helper
            does not create.

    Yields:
        The path patched in as ``black.cache.CACHE_DIR``. The temporary
        directory is removed when the ``with`` block ends.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
87
88
@contextmanager
def event_loop() -> Iterator[None]:
    """Run the enclosed block with a brand-new asyncio loop as the current loop.

    On exit the loop is closed and the current-loop slot is reset to ``None``.
    This means later code cannot silently pick up the closed loop through
    ``get_event_loop()``.

    Yields:
        Nothing; use ``asyncio.get_event_loop()`` inside the block if needed.
    """
    policy = asyncio.get_event_loop_policy()
    loop = policy.new_event_loop()
    try:
        asyncio.set_event_loop(loop)
        yield
    finally:
        # Uninstall before closing: a closed loop left as "current" makes any
        # later run_until_complete() fail with "Event loop is closed".
        asyncio.set_event_loop(None)
        loop.close()
99
100
class FakeContext(click.Context):
    """Stand-in click Context for calling functions that require one.

    ``__init__`` intentionally does not call ``click.Context.__init__``. It
    only sets ``obj`` and ``default_map``.
    """

    def __init__(self) -> None:
        # Dummy root, since most of the tests don't care about it.
        self.obj: Dict[str, Any] = dict(root=PROJECT_ROOT)
        self.default_map: Dict[str, Any] = dict()
106         # Dummy root, since most of the tests don't care about it
107         self.obj: Dict[str, Any] = {"root": PROJECT_ROOT}
108
109
class FakeParameter(click.Parameter):
    """Stand-in click Parameter for calling functions that require one."""

    def __init__(self) -> None:
        """Deliberately skip ``click.Parameter.__init__``; set no state at all."""
115
116
class BlackRunner(CliRunner):
    """CliRunner that captures Black's STDOUT and STDERR as separate streams."""

    def __init__(self) -> None:
        # mix_stderr=False keeps stderr out of the captured stdout, so tests can
        # assert on each stream independently.
        CliRunner.__init__(self, mix_stderr=False)
122
123
def invokeBlack(
    args: List[str], exit_code: int = 0, ignore_config: bool = True
) -> None:
    """Run Black's CLI in-process and assert on its exit status.

    Args:
        args: Command-line arguments handed to ``black.main``.
        exit_code: The exit status the run is expected to finish with.
        ignore_config: When true, prepend ``--verbose`` and
            ``--config THIS_DIR/empty.toml`` to *args*. The intent is to keep
            the repository's own configuration out of the run (assumes
            empty.toml carries no settings).

    Raises:
        AssertionError: If the exit status differs from *exit_code*. The
            message lists the arguments, captured stdout/stderr and the
            recorded exception. Exceptions raised inside Black propagate,
            since ``catch_exceptions=False``.
    """
    if ignore_config:
        args = ["--verbose", "--config", str(THIS_DIR / "empty.toml")] + list(args)
    result = BlackRunner().invoke(black.main, args, catch_exceptions=False)
    stdout_bytes, stderr_bytes = result.stdout_bytes, result.stderr_bytes
    assert stdout_bytes is not None
    assert stderr_bytes is not None
    details = "\n".join(
        [
            f"Failed with args: {args}",
            f"stdout: {stdout_bytes.decode()!r}",
            f"stderr: {stderr_bytes.decode()!r}",
            f"exception: {result.exception}",
        ]
    )
    assert result.exit_code == exit_code, details
140
141
142 class BlackTestCase(BlackBaseTestCase):
    # Expose the module-level helper as ``self.invokeBlack``; wrapping it in
    # staticmethod stops it from being bound as an instance method.
    invokeBlack = staticmethod(invokeBlack)
144
145     def test_empty_ff(self) -> None:
146         expected = ""
147         tmp_file = Path(black.dump_to_file())
148         try:
149             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
150             with open(tmp_file, encoding="utf8") as f:
151                 actual = f.read()
152         finally:
153             os.unlink(tmp_file)
154         self.assertFormatEqual(expected, actual)
155
156     def test_experimental_string_processing_warns(self) -> None:
157         self.assertWarns(
158             black.mode.Deprecated, black.Mode, experimental_string_processing=True
159         )
160
161     def test_piping(self) -> None:
162         source, expected = read_data_from_file(PROJECT_ROOT / "src/black/__init__.py")
163         result = BlackRunner().invoke(
164             black.main,
165             [
166                 "-",
167                 "--fast",
168                 f"--line-length={black.DEFAULT_LINE_LENGTH}",
169                 f"--config={EMPTY_CONFIG}",
170             ],
171             input=BytesIO(source.encode("utf8")),
172         )
173         self.assertEqual(result.exit_code, 0)
174         self.assertFormatEqual(expected, result.output)
175         if source != result.output:
176             black.assert_equivalent(source, result.output)
177             black.assert_stable(source, result.output, DEFAULT_MODE)
178
179     def test_piping_diff(self) -> None:
180         diff_header = re.compile(
181             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
182             r"\+\d\d\d\d"
183         )
184         source, _ = read_data("simple_cases", "expression.py")
185         expected, _ = read_data("simple_cases", "expression.diff")
186         args = [
187             "-",
188             "--fast",
189             f"--line-length={black.DEFAULT_LINE_LENGTH}",
190             "--diff",
191             f"--config={EMPTY_CONFIG}",
192         ]
193         result = BlackRunner().invoke(
194             black.main, args, input=BytesIO(source.encode("utf8"))
195         )
196         self.assertEqual(result.exit_code, 0)
197         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
198         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
199         self.assertEqual(expected, actual)
200
201     def test_piping_diff_with_color(self) -> None:
202         source, _ = read_data("simple_cases", "expression.py")
203         args = [
204             "-",
205             "--fast",
206             f"--line-length={black.DEFAULT_LINE_LENGTH}",
207             "--diff",
208             "--color",
209             f"--config={EMPTY_CONFIG}",
210         ]
211         result = BlackRunner().invoke(
212             black.main, args, input=BytesIO(source.encode("utf8"))
213         )
214         actual = result.output
215         # Again, the contents are checked in a different test, so only look for colors.
216         self.assertIn("\033[1m", actual)
217         self.assertIn("\033[36m", actual)
218         self.assertIn("\033[32m", actual)
219         self.assertIn("\033[31m", actual)
220         self.assertIn("\033[0m", actual)
221
222     @patch("black.dump_to_file", dump_to_stderr)
223     def _test_wip(self) -> None:
224         source, expected = read_data("miscellaneous", "wip")
225         sys.settrace(tracefunc)
226         mode = replace(
227             DEFAULT_MODE,
228             experimental_string_processing=False,
229             target_versions={black.TargetVersion.PY38},
230         )
231         actual = fs(source, mode=mode)
232         sys.settrace(None)
233         self.assertFormatEqual(expected, actual)
234         black.assert_equivalent(source, actual)
235         black.assert_stable(source, actual, black.FileMode())
236
237     def test_pep_572_version_detection(self) -> None:
238         source, _ = read_data("py_38", "pep_572")
239         root = black.lib2to3_parse(source)
240         features = black.get_features_used(root)
241         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
242         versions = black.detect_target_versions(root)
243         self.assertIn(black.TargetVersion.PY38, versions)
244
245     def test_expression_ff(self) -> None:
246         source, expected = read_data("simple_cases", "expression.py")
247         tmp_file = Path(black.dump_to_file(source))
248         try:
249             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
250             with open(tmp_file, encoding="utf8") as f:
251                 actual = f.read()
252         finally:
253             os.unlink(tmp_file)
254         self.assertFormatEqual(expected, actual)
255         with patch("black.dump_to_file", dump_to_stderr):
256             black.assert_equivalent(source, actual)
257             black.assert_stable(source, actual, DEFAULT_MODE)
258
259     def test_expression_diff(self) -> None:
260         source, _ = read_data("simple_cases", "expression.py")
261         expected, _ = read_data("simple_cases", "expression.diff")
262         tmp_file = Path(black.dump_to_file(source))
263         diff_header = re.compile(
264             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
265             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
266         )
267         try:
268             result = BlackRunner().invoke(
269                 black.main, ["--diff", str(tmp_file), f"--config={EMPTY_CONFIG}"]
270             )
271             self.assertEqual(result.exit_code, 0)
272         finally:
273             os.unlink(tmp_file)
274         actual = result.output
275         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
276         if expected != actual:
277             dump = black.dump_to_file(actual)
278             msg = (
279                 "Expected diff isn't equal to the actual. If you made changes to"
280                 " expression.py and this is an anticipated difference, overwrite"
281                 f" tests/data/expression.diff with {dump}"
282             )
283             self.assertEqual(expected, actual, msg)
284
285     def test_expression_diff_with_color(self) -> None:
286         source, _ = read_data("simple_cases", "expression.py")
287         expected, _ = read_data("simple_cases", "expression.diff")
288         tmp_file = Path(black.dump_to_file(source))
289         try:
290             result = BlackRunner().invoke(
291                 black.main,
292                 ["--diff", "--color", str(tmp_file), f"--config={EMPTY_CONFIG}"],
293             )
294         finally:
295             os.unlink(tmp_file)
296         actual = result.output
297         # We check the contents of the diff in `test_expression_diff`. All
298         # we need to check here is that color codes exist in the result.
299         self.assertIn("\033[1m", actual)
300         self.assertIn("\033[36m", actual)
301         self.assertIn("\033[32m", actual)
302         self.assertIn("\033[31m", actual)
303         self.assertIn("\033[0m", actual)
304
305     def test_detect_pos_only_arguments(self) -> None:
306         source, _ = read_data("py_38", "pep_570")
307         root = black.lib2to3_parse(source)
308         features = black.get_features_used(root)
309         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
310         versions = black.detect_target_versions(root)
311         self.assertIn(black.TargetVersion.PY38, versions)
312
313     def test_detect_debug_f_strings(self) -> None:
314         root = black.lib2to3_parse("""f"{x=}" """)
315         features = black.get_features_used(root)
316         self.assertIn(black.Feature.DEBUG_F_STRINGS, features)
317         versions = black.detect_target_versions(root)
318         self.assertIn(black.TargetVersion.PY38, versions)
319
320         root = black.lib2to3_parse(
321             """f"{x}"\nf'{"="}'\nf'{(x:=5)}'\nf'{f(a="3=")}'\nf'{x:=10}'\n"""
322         )
323         features = black.get_features_used(root)
324         self.assertNotIn(black.Feature.DEBUG_F_STRINGS, features)
325
326         # We don't yet support feature version detection in nested f-strings
327         root = black.lib2to3_parse(
328             """f"heard a rumour that { f'{1+1=}' } ... seems like it could be true" """
329         )
330         features = black.get_features_used(root)
331         self.assertNotIn(black.Feature.DEBUG_F_STRINGS, features)
332
333     @patch("black.dump_to_file", dump_to_stderr)
334     def test_string_quotes(self) -> None:
335         source, expected = read_data("miscellaneous", "string_quotes")
336         mode = black.Mode(preview=True)
337         assert_format(source, expected, mode)
338         mode = replace(mode, string_normalization=False)
339         not_normalized = fs(source, mode=mode)
340         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
341         black.assert_equivalent(source, not_normalized)
342         black.assert_stable(source, not_normalized, mode=mode)
343
344     def test_skip_magic_trailing_comma(self) -> None:
345         source, _ = read_data("simple_cases", "expression")
346         expected, _ = read_data(
347             "miscellaneous", "expression_skip_magic_trailing_comma.diff"
348         )
349         tmp_file = Path(black.dump_to_file(source))
350         diff_header = re.compile(
351             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
352             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
353         )
354         try:
355             result = BlackRunner().invoke(
356                 black.main, ["-C", "--diff", str(tmp_file), f"--config={EMPTY_CONFIG}"]
357             )
358             self.assertEqual(result.exit_code, 0)
359         finally:
360             os.unlink(tmp_file)
361         actual = result.output
362         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
363         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
364         if expected != actual:
365             dump = black.dump_to_file(actual)
366             msg = (
367                 "Expected diff isn't equal to the actual. If you made changes to"
368                 " expression.py and this is an anticipated difference, overwrite"
369                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
370             )
371             self.assertEqual(expected, actual, msg)
372
373     @patch("black.dump_to_file", dump_to_stderr)
374     def test_async_as_identifier(self) -> None:
375         source_path = get_case_path("miscellaneous", "async_as_identifier")
376         source, expected = read_data_from_file(source_path)
377         actual = fs(source)
378         self.assertFormatEqual(expected, actual)
379         major, minor = sys.version_info[:2]
380         if major < 3 or (major <= 3 and minor < 7):
381             black.assert_equivalent(source, actual)
382         black.assert_stable(source, actual, DEFAULT_MODE)
383         # ensure black can parse this when the target is 3.6
384         self.invokeBlack([str(source_path), "--target-version", "py36"])
385         # but not on 3.7, because async/await is no longer an identifier
386         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
387
388     @patch("black.dump_to_file", dump_to_stderr)
389     def test_python37(self) -> None:
390         source_path = get_case_path("py_37", "python37")
391         source, expected = read_data_from_file(source_path)
392         actual = fs(source)
393         self.assertFormatEqual(expected, actual)
394         major, minor = sys.version_info[:2]
395         if major > 3 or (major == 3 and minor >= 7):
396             black.assert_equivalent(source, actual)
397         black.assert_stable(source, actual, DEFAULT_MODE)
398         # ensure black can parse this when the target is 3.7
399         self.invokeBlack([str(source_path), "--target-version", "py37"])
400         # but not on 3.6, because we use async as a reserved keyword
401         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
402
403     def test_tab_comment_indentation(self) -> None:
404         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
405         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
406         self.assertFormatEqual(contents_spc, fs(contents_spc))
407         self.assertFormatEqual(contents_spc, fs(contents_tab))
408
409         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
410         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
411         self.assertFormatEqual(contents_spc, fs(contents_spc))
412         self.assertFormatEqual(contents_spc, fs(contents_tab))
413
414         # mixed tabs and spaces (valid Python 2 code)
415         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
416         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
417         self.assertFormatEqual(contents_spc, fs(contents_spc))
418         self.assertFormatEqual(contents_spc, fs(contents_tab))
419
420         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
421         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
422         self.assertFormatEqual(contents_spc, fs(contents_spc))
423         self.assertFormatEqual(contents_spc, fs(contents_tab))
424
    def test_report_verbose(self) -> None:
        """Drive a verbose ``Report`` through every outcome and check its output.

        ``black.output._out`` / ``_err`` are patched to collect messages. After
        each step the test asserts:
        - the number of lines written to each stream and the last line written;
        - the ``str(report)`` summary and ``report.return_code``.

        In verbose mode every non-failure outcome (reformatted, unchanged,
        cached, ignored) writes a line to the "out" stream. Failures go to
        "err".
        """
        report = Report(verbose=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        # Collect messages instead of printing them.
        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # Cached files count as "left unchanged" in the summary.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled and a reformatted file, the return code is 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Any failure goes to the err stream and makes the return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are reported but do not change the summary counts.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff switch the summary to "would be" phrasing.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
526
    def test_report_quiet(self) -> None:
        """Drive a quiet ``Report`` through every outcome and check its output.

        ``black.output._out`` / ``_err`` are patched to collect messages. In
        quiet mode nothing is written to the "out" stream; only failures
        produce messages, on "err". The ``str(report)`` summary and
        ``report.return_code`` behave as in the other report tests.
        """
        report = Report(quiet=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        # Collect messages instead of printing them.
        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # Cached files count as "left unchanged" in the summary.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled and a reformatted file, the return code is 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures are still reported on err, and set the return code to 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are silent and do not change the summary counts.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff switch the summary to "would be" phrasing.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
620
    def test_report_normal(self) -> None:
        """Drive a default-verbosity ``Report`` through every outcome.

        ``black.output._out`` / ``_err`` are patched to collect messages. At
        default verbosity:
        - only reformatted files write a line to "out";
        - failures write to "err";
        - unchanged, cached and ignored files are silent.
        """
        report = black.Report()
        out_lines: List[str] = []
        err_lines: List[str] = []

        # Collect messages instead of printing them.
        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file writes nothing; the last out line is still f2's.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled and a reformatted file, the return code is 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Any failure goes to the err stream and makes the return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are silent and do not change the summary counts.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff switch the summary to "would be" phrasing.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
717
718     def test_lib2to3_parse(self) -> None:
719         with self.assertRaises(black.InvalidInput):
720             black.lib2to3_parse("invalid syntax")
721
722         straddling = "x + y"
723         black.lib2to3_parse(straddling)
724         black.lib2to3_parse(straddling, {TargetVersion.PY36})
725
726         py2_only = "print x"
727         with self.assertRaises(black.InvalidInput):
728             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
729
730         py3_only = "exec(x, end=y)"
731         black.lib2to3_parse(py3_only)
732         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
733
734     def test_get_features_used_decorator(self) -> None:
735         # Test the feature detection of new decorator syntax
736         # since this makes some test cases of test_get_features_used()
737         # fails if it fails, this is tested first so that a useful case
738         # is identified
739         simples, relaxed = read_data("miscellaneous", "decorators")
740         # skip explanation comments at the top of the file
741         for simple_test in simples.split("##")[1:]:
742             node = black.lib2to3_parse(simple_test)
743             decorator = str(node.children[0].children[0]).strip()
744             self.assertNotIn(
745                 Feature.RELAXED_DECORATORS,
746                 black.get_features_used(node),
747                 msg=(
748                     f"decorator '{decorator}' follows python<=3.8 syntax"
749                     "but is detected as 3.9+"
750                     # f"The full node is\n{node!r}"
751                 ),
752             )
753         # skip the '# output' comment at the top of the output part
754         for relaxed_test in relaxed.split("##")[1:]:
755             node = black.lib2to3_parse(relaxed_test)
756             decorator = str(node.children[0].children[0]).strip()
757             self.assertIn(
758                 Feature.RELAXED_DECORATORS,
759                 black.get_features_used(node),
760                 msg=(
761                     f"decorator '{decorator}' uses python3.9+ syntax"
762                     "but is detected as python<=3.8"
763                     # f"The full node is\n{node!r}"
764                 ),
765             )
766
767     def test_get_features_used(self) -> None:
768         node = black.lib2to3_parse("def f(*, arg): ...\n")
769         self.assertEqual(black.get_features_used(node), set())
770         node = black.lib2to3_parse("def f(*, arg,): ...\n")
771         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
772         node = black.lib2to3_parse("f(*arg,)\n")
773         self.assertEqual(
774             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
775         )
776         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
777         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
778         node = black.lib2to3_parse("123_456\n")
779         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
780         node = black.lib2to3_parse("123456\n")
781         self.assertEqual(black.get_features_used(node), set())
782         source, expected = read_data("simple_cases", "function")
783         node = black.lib2to3_parse(source)
784         expected_features = {
785             Feature.TRAILING_COMMA_IN_CALL,
786             Feature.TRAILING_COMMA_IN_DEF,
787             Feature.F_STRINGS,
788         }
789         self.assertEqual(black.get_features_used(node), expected_features)
790         node = black.lib2to3_parse(expected)
791         self.assertEqual(black.get_features_used(node), expected_features)
792         source, expected = read_data("simple_cases", "expression")
793         node = black.lib2to3_parse(source)
794         self.assertEqual(black.get_features_used(node), set())
795         node = black.lib2to3_parse(expected)
796         self.assertEqual(black.get_features_used(node), set())
797         node = black.lib2to3_parse("lambda a, /, b: ...")
798         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
799         node = black.lib2to3_parse("def fn(a, /, b): ...")
800         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
801         node = black.lib2to3_parse("def fn(): yield a, b")
802         self.assertEqual(black.get_features_used(node), set())
803         node = black.lib2to3_parse("def fn(): return a, b")
804         self.assertEqual(black.get_features_used(node), set())
805         node = black.lib2to3_parse("def fn(): yield *b, c")
806         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
807         node = black.lib2to3_parse("def fn(): return a, *b, c")
808         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
809         node = black.lib2to3_parse("x = a, *b, c")
810         self.assertEqual(black.get_features_used(node), set())
811         node = black.lib2to3_parse("x: Any = regular")
812         self.assertEqual(black.get_features_used(node), set())
813         node = black.lib2to3_parse("x: Any = (regular, regular)")
814         self.assertEqual(black.get_features_used(node), set())
815         node = black.lib2to3_parse("x: Any = Complex(Type(1))[something]")
816         self.assertEqual(black.get_features_used(node), set())
817         node = black.lib2to3_parse("x: Tuple[int, ...] = a, b, c")
818         self.assertEqual(
819             black.get_features_used(node), {Feature.ANN_ASSIGN_EXTENDED_RHS}
820         )
821         node = black.lib2to3_parse("try: pass\nexcept Something: pass")
822         self.assertEqual(black.get_features_used(node), set())
823         node = black.lib2to3_parse("try: pass\nexcept (*Something,): pass")
824         self.assertEqual(black.get_features_used(node), set())
825         node = black.lib2to3_parse("try: pass\nexcept *Group: pass")
826         self.assertEqual(black.get_features_used(node), {Feature.EXCEPT_STAR})
827         node = black.lib2to3_parse("a[*b]")
828         self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS})
829         node = black.lib2to3_parse("a[x, *y(), z] = t")
830         self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS})
831         node = black.lib2to3_parse("def fn(*args: *T): pass")
832         self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS})
833
834     def test_get_features_used_for_future_flags(self) -> None:
835         for src, features in [
836             ("from __future__ import annotations", {Feature.FUTURE_ANNOTATIONS}),
837             (
838                 "from __future__ import (other, annotations)",
839                 {Feature.FUTURE_ANNOTATIONS},
840             ),
841             ("a = 1 + 2\nfrom something import annotations", set()),
842             ("from __future__ import x, y", set()),
843         ]:
844             with self.subTest(src=src, features=features):
845                 node = black.lib2to3_parse(src)
846                 future_imports = black.get_future_imports(node)
847                 self.assertEqual(
848                     black.get_features_used(node, future_imports=future_imports),
849                     features,
850                 )
851
852     def test_get_future_imports(self) -> None:
853         node = black.lib2to3_parse("\n")
854         self.assertEqual(set(), black.get_future_imports(node))
855         node = black.lib2to3_parse("from __future__ import black\n")
856         self.assertEqual({"black"}, black.get_future_imports(node))
857         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
858         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
859         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
860         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
861         node = black.lib2to3_parse(
862             "from __future__ import multiple\nfrom __future__ import imports\n"
863         )
864         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
865         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
866         self.assertEqual({"black"}, black.get_future_imports(node))
867         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
868         self.assertEqual({"black"}, black.get_future_imports(node))
869         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
870         self.assertEqual(set(), black.get_future_imports(node))
871         node = black.lib2to3_parse("from some.module import black\n")
872         self.assertEqual(set(), black.get_future_imports(node))
873         node = black.lib2to3_parse(
874             "from __future__ import unicode_literals as _unicode_literals"
875         )
876         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
877         node = black.lib2to3_parse(
878             "from __future__ import unicode_literals as _lol, print"
879         )
880         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
881
882     @pytest.mark.incompatible_with_mypyc
883     def test_debug_visitor(self) -> None:
884         source, _ = read_data("miscellaneous", "debug_visitor")
885         expected, _ = read_data("miscellaneous", "debug_visitor.out")
886         out_lines = []
887         err_lines = []
888
889         def out(msg: str, **kwargs: Any) -> None:
890             out_lines.append(msg)
891
892         def err(msg: str, **kwargs: Any) -> None:
893             err_lines.append(msg)
894
895         with patch("black.debug.out", out):
896             DebugVisitor.show(source)
897         actual = "\n".join(out_lines) + "\n"
898         log_name = ""
899         if expected != actual:
900             log_name = black.dump_to_file(*out_lines)
901         self.assertEqual(
902             expected,
903             actual,
904             f"AST print out is different. Actual version dumped to {log_name}",
905         )
906
907     def test_format_file_contents(self) -> None:
908         empty = ""
909         mode = DEFAULT_MODE
910         with self.assertRaises(black.NothingChanged):
911             black.format_file_contents(empty, mode=mode, fast=False)
912         just_nl = "\n"
913         with self.assertRaises(black.NothingChanged):
914             black.format_file_contents(just_nl, mode=mode, fast=False)
915         same = "j = [1, 2, 3]\n"
916         with self.assertRaises(black.NothingChanged):
917             black.format_file_contents(same, mode=mode, fast=False)
918         different = "j = [1,2,3]"
919         expected = same
920         actual = black.format_file_contents(different, mode=mode, fast=False)
921         self.assertEqual(expected, actual)
922         invalid = "return if you can"
923         with self.assertRaises(black.InvalidInput) as e:
924             black.format_file_contents(invalid, mode=mode, fast=False)
925         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
926
927     def test_endmarker(self) -> None:
928         n = black.lib2to3_parse("\n")
929         self.assertEqual(n.type, black.syms.file_input)
930         self.assertEqual(len(n.children), 1)
931         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
932
933     @pytest.mark.incompatible_with_mypyc
934     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
935     def test_assertFormatEqual(self) -> None:
936         out_lines = []
937         err_lines = []
938
939         def out(msg: str, **kwargs: Any) -> None:
940             out_lines.append(msg)
941
942         def err(msg: str, **kwargs: Any) -> None:
943             err_lines.append(msg)
944
945         with patch("black.output._out", out), patch("black.output._err", err):
946             with self.assertRaises(AssertionError):
947                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
948
949         out_str = "".join(out_lines)
950         self.assertIn("Expected tree:", out_str)
951         self.assertIn("Actual tree:", out_str)
952         self.assertEqual("".join(err_lines), "")
953
954     @event_loop()
955     @patch("concurrent.futures.ProcessPoolExecutor", MagicMock(side_effect=OSError))
956     def test_works_in_mono_process_only_environment(self) -> None:
957         with cache_dir() as workspace:
958             for f in [
959                 (workspace / "one.py").resolve(),
960                 (workspace / "two.py").resolve(),
961             ]:
962                 f.write_text('print("hello")\n')
963             self.invokeBlack([str(workspace)])
964
965     @event_loop()
966     def test_check_diff_use_together(self) -> None:
967         with cache_dir():
968             # Files which will be reformatted.
969             src1 = get_case_path("miscellaneous", "string_quotes")
970             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
971             # Files which will not be reformatted.
972             src2 = get_case_path("simple_cases", "composition")
973             self.invokeBlack([str(src2), "--diff", "--check"])
974             # Multi file command.
975             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
976
977     def test_no_src_fails(self) -> None:
978         with cache_dir():
979             self.invokeBlack([], exit_code=1)
980
981     def test_src_and_code_fails(self) -> None:
982         with cache_dir():
983             self.invokeBlack([".", "-c", "0"], exit_code=1)
984
985     def test_broken_symlink(self) -> None:
986         with cache_dir() as workspace:
987             symlink = workspace / "broken_link.py"
988             try:
989                 symlink.symlink_to("nonexistent.py")
990             except (OSError, NotImplementedError) as e:
991                 self.skipTest(f"Can't create symlinks: {e}")
992             self.invokeBlack([str(workspace.resolve())])
993
994     def test_single_file_force_pyi(self) -> None:
995         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
996         contents, expected = read_data("miscellaneous", "force_pyi")
997         with cache_dir() as workspace:
998             path = (workspace / "file.py").resolve()
999             with open(path, "w") as fh:
1000                 fh.write(contents)
1001             self.invokeBlack([str(path), "--pyi"])
1002             with open(path, "r") as fh:
1003                 actual = fh.read()
1004             # verify cache with --pyi is separate
1005             pyi_cache = black.read_cache(pyi_mode)
1006             self.assertIn(str(path), pyi_cache)
1007             normal_cache = black.read_cache(DEFAULT_MODE)
1008             self.assertNotIn(str(path), normal_cache)
1009         self.assertFormatEqual(expected, actual)
1010         black.assert_equivalent(contents, actual)
1011         black.assert_stable(contents, actual, pyi_mode)
1012
1013     @event_loop()
1014     def test_multi_file_force_pyi(self) -> None:
1015         reg_mode = DEFAULT_MODE
1016         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1017         contents, expected = read_data("miscellaneous", "force_pyi")
1018         with cache_dir() as workspace:
1019             paths = [
1020                 (workspace / "file1.py").resolve(),
1021                 (workspace / "file2.py").resolve(),
1022             ]
1023             for path in paths:
1024                 with open(path, "w") as fh:
1025                     fh.write(contents)
1026             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1027             for path in paths:
1028                 with open(path, "r") as fh:
1029                     actual = fh.read()
1030                 self.assertEqual(actual, expected)
1031             # verify cache with --pyi is separate
1032             pyi_cache = black.read_cache(pyi_mode)
1033             normal_cache = black.read_cache(reg_mode)
1034             for path in paths:
1035                 self.assertIn(str(path), pyi_cache)
1036                 self.assertNotIn(str(path), normal_cache)
1037
1038     def test_pipe_force_pyi(self) -> None:
1039         source, expected = read_data("miscellaneous", "force_pyi")
1040         result = CliRunner().invoke(
1041             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1042         )
1043         self.assertEqual(result.exit_code, 0)
1044         actual = result.output
1045         self.assertFormatEqual(actual, expected)
1046
1047     def test_single_file_force_py36(self) -> None:
1048         reg_mode = DEFAULT_MODE
1049         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1050         source, expected = read_data("miscellaneous", "force_py36")
1051         with cache_dir() as workspace:
1052             path = (workspace / "file.py").resolve()
1053             with open(path, "w") as fh:
1054                 fh.write(source)
1055             self.invokeBlack([str(path), *PY36_ARGS])
1056             with open(path, "r") as fh:
1057                 actual = fh.read()
1058             # verify cache with --target-version is separate
1059             py36_cache = black.read_cache(py36_mode)
1060             self.assertIn(str(path), py36_cache)
1061             normal_cache = black.read_cache(reg_mode)
1062             self.assertNotIn(str(path), normal_cache)
1063         self.assertEqual(actual, expected)
1064
1065     @event_loop()
1066     def test_multi_file_force_py36(self) -> None:
1067         reg_mode = DEFAULT_MODE
1068         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1069         source, expected = read_data("miscellaneous", "force_py36")
1070         with cache_dir() as workspace:
1071             paths = [
1072                 (workspace / "file1.py").resolve(),
1073                 (workspace / "file2.py").resolve(),
1074             ]
1075             for path in paths:
1076                 with open(path, "w") as fh:
1077                     fh.write(source)
1078             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1079             for path in paths:
1080                 with open(path, "r") as fh:
1081                     actual = fh.read()
1082                 self.assertEqual(actual, expected)
1083             # verify cache with --target-version is separate
1084             pyi_cache = black.read_cache(py36_mode)
1085             normal_cache = black.read_cache(reg_mode)
1086             for path in paths:
1087                 self.assertIn(str(path), pyi_cache)
1088                 self.assertNotIn(str(path), normal_cache)
1089
1090     def test_pipe_force_py36(self) -> None:
1091         source, expected = read_data("miscellaneous", "force_py36")
1092         result = CliRunner().invoke(
1093             black.main,
1094             ["-", "-q", "--target-version=py36"],
1095             input=BytesIO(source.encode("utf8")),
1096         )
1097         self.assertEqual(result.exit_code, 0)
1098         actual = result.output
1099         self.assertFormatEqual(actual, expected)
1100
1101     @pytest.mark.incompatible_with_mypyc
1102     def test_reformat_one_with_stdin(self) -> None:
1103         with patch(
1104             "black.format_stdin_to_stdout",
1105             return_value=lambda *args, **kwargs: black.Changed.YES,
1106         ) as fsts:
1107             report = MagicMock()
1108             path = Path("-")
1109             black.reformat_one(
1110                 path,
1111                 fast=True,
1112                 write_back=black.WriteBack.YES,
1113                 mode=DEFAULT_MODE,
1114                 report=report,
1115             )
1116             fsts.assert_called_once()
1117             report.done.assert_called_with(path, black.Changed.YES)
1118
1119     @pytest.mark.incompatible_with_mypyc
1120     def test_reformat_one_with_stdin_filename(self) -> None:
1121         with patch(
1122             "black.format_stdin_to_stdout",
1123             return_value=lambda *args, **kwargs: black.Changed.YES,
1124         ) as fsts:
1125             report = MagicMock()
1126             p = "foo.py"
1127             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1128             expected = Path(p)
1129             black.reformat_one(
1130                 path,
1131                 fast=True,
1132                 write_back=black.WriteBack.YES,
1133                 mode=DEFAULT_MODE,
1134                 report=report,
1135             )
1136             fsts.assert_called_once_with(
1137                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1138             )
1139             # __BLACK_STDIN_FILENAME__ should have been stripped
1140             report.done.assert_called_with(expected, black.Changed.YES)
1141
1142     @pytest.mark.incompatible_with_mypyc
1143     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1144         with patch(
1145             "black.format_stdin_to_stdout",
1146             return_value=lambda *args, **kwargs: black.Changed.YES,
1147         ) as fsts:
1148             report = MagicMock()
1149             p = "foo.pyi"
1150             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1151             expected = Path(p)
1152             black.reformat_one(
1153                 path,
1154                 fast=True,
1155                 write_back=black.WriteBack.YES,
1156                 mode=DEFAULT_MODE,
1157                 report=report,
1158             )
1159             fsts.assert_called_once_with(
1160                 fast=True,
1161                 write_back=black.WriteBack.YES,
1162                 mode=replace(DEFAULT_MODE, is_pyi=True),
1163             )
1164             # __BLACK_STDIN_FILENAME__ should have been stripped
1165             report.done.assert_called_with(expected, black.Changed.YES)
1166
1167     @pytest.mark.incompatible_with_mypyc
1168     def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
1169         with patch(
1170             "black.format_stdin_to_stdout",
1171             return_value=lambda *args, **kwargs: black.Changed.YES,
1172         ) as fsts:
1173             report = MagicMock()
1174             p = "foo.ipynb"
1175             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1176             expected = Path(p)
1177             black.reformat_one(
1178                 path,
1179                 fast=True,
1180                 write_back=black.WriteBack.YES,
1181                 mode=DEFAULT_MODE,
1182                 report=report,
1183             )
1184             fsts.assert_called_once_with(
1185                 fast=True,
1186                 write_back=black.WriteBack.YES,
1187                 mode=replace(DEFAULT_MODE, is_ipynb=True),
1188             )
1189             # __BLACK_STDIN_FILENAME__ should have been stripped
1190             report.done.assert_called_with(expected, black.Changed.YES)
1191
1192     @pytest.mark.incompatible_with_mypyc
1193     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1194         with patch(
1195             "black.format_stdin_to_stdout",
1196             return_value=lambda *args, **kwargs: black.Changed.YES,
1197         ) as fsts:
1198             report = MagicMock()
1199             # Even with an existing file, since we are forcing stdin, black
1200             # should output to stdout and not modify the file inplace
1201             p = THIS_DIR / "data" / "simple_cases" / "collections.py"
1202             # Make sure is_file actually returns True
1203             self.assertTrue(p.is_file())
1204             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1205             expected = Path(p)
1206             black.reformat_one(
1207                 path,
1208                 fast=True,
1209                 write_back=black.WriteBack.YES,
1210                 mode=DEFAULT_MODE,
1211                 report=report,
1212             )
1213             fsts.assert_called_once()
1214             # __BLACK_STDIN_FILENAME__ should have been stripped
1215             report.done.assert_called_with(expected, black.Changed.YES)
1216
1217     def test_reformat_one_with_stdin_empty(self) -> None:
1218         output = io.StringIO()
1219         with patch("io.TextIOWrapper", lambda *args, **kwargs: output):
1220             try:
1221                 black.format_stdin_to_stdout(
1222                     fast=True,
1223                     content="",
1224                     write_back=black.WriteBack.YES,
1225                     mode=DEFAULT_MODE,
1226                 )
1227             except io.UnsupportedOperation:
1228                 pass  # StringIO does not support detach
1229             assert output.getvalue() == ""
1230
1231     def test_invalid_cli_regex(self) -> None:
1232         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1233             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1234
1235     def test_required_version_matches_version(self) -> None:
1236         self.invokeBlack(
1237             ["--required-version", black.__version__, "-c", "0"],
1238             exit_code=0,
1239             ignore_config=True,
1240         )
1241
1242     def test_required_version_matches_partial_version(self) -> None:
1243         self.invokeBlack(
1244             ["--required-version", black.__version__.split(".")[0], "-c", "0"],
1245             exit_code=0,
1246             ignore_config=True,
1247         )
1248
1249     def test_required_version_does_not_match_on_minor_version(self) -> None:
1250         self.invokeBlack(
1251             ["--required-version", black.__version__.split(".")[0] + ".999", "-c", "0"],
1252             exit_code=1,
1253             ignore_config=True,
1254         )
1255
1256     def test_required_version_does_not_match_version(self) -> None:
1257         result = BlackRunner().invoke(
1258             black.main,
1259             ["--required-version", "20.99b", "-c", "0"],
1260         )
1261         self.assertEqual(result.exit_code, 1)
1262         self.assertIn("required version", result.stderr)
1263
1264     def test_preserves_line_endings(self) -> None:
1265         with TemporaryDirectory() as workspace:
1266             test_file = Path(workspace) / "test.py"
1267             for nl in ["\n", "\r\n"]:
1268                 contents = nl.join(["def f(  ):", "    pass"])
1269                 test_file.write_bytes(contents.encode())
1270                 ff(test_file, write_back=black.WriteBack.YES)
1271                 updated_contents: bytes = test_file.read_bytes()
1272                 self.assertIn(nl.encode(), updated_contents)
1273                 if nl == "\n":
1274                     self.assertNotIn(b"\r\n", updated_contents)
1275
1276     def test_preserves_line_endings_via_stdin(self) -> None:
1277         for nl in ["\n", "\r\n"]:
1278             contents = nl.join(["def f(  ):", "    pass"])
1279             runner = BlackRunner()
1280             result = runner.invoke(
1281                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1282             )
1283             self.assertEqual(result.exit_code, 0)
1284             output = result.stdout_bytes
1285             self.assertIn(nl.encode("utf8"), output)
1286             if nl == "\n":
1287                 self.assertNotIn(b"\r\n", output)
1288
1289     def test_assert_equivalent_different_asts(self) -> None:
1290         with self.assertRaises(AssertionError):
1291             black.assert_equivalent("{}", "None")
1292
1293     def test_shhh_click(self) -> None:
1294         try:
1295             from click import _unicodefun  # type: ignore
1296         except ImportError:
1297             self.skipTest("Incompatible Click version")
1298
1299         if not hasattr(_unicodefun, "_verify_python_env"):
1300             self.skipTest("Incompatible Click version")
1301
1302         # First, let's see if Click is crashing with a preferred ASCII charset.
1303         with patch("locale.getpreferredencoding") as gpe:
1304             gpe.return_value = "ASCII"
1305             with self.assertRaises(RuntimeError):
1306                 _unicodefun._verify_python_env()
1307         # Now, let's silence Click...
1308         black.patch_click()
1309         # ...and confirm it's silent.
1310         with patch("locale.getpreferredencoding") as gpe:
1311             gpe.return_value = "ASCII"
1312             try:
1313                 _unicodefun._verify_python_env()
1314             except RuntimeError as re:
1315                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1316
1317     def test_root_logger_not_used_directly(self) -> None:
1318         def fail(*args: Any, **kwargs: Any) -> None:
1319             self.fail("Record created with root logger")
1320
1321         with patch.multiple(
1322             logging.root,
1323             debug=fail,
1324             info=fail,
1325             warning=fail,
1326             error=fail,
1327             critical=fail,
1328             log=fail,
1329         ):
1330             ff(THIS_DIR / "util.py")
1331
1332     def test_invalid_config_return_code(self) -> None:
1333         tmp_file = Path(black.dump_to_file())
1334         try:
1335             tmp_config = Path(black.dump_to_file())
1336             tmp_config.unlink()
1337             args = ["--config", str(tmp_config), str(tmp_file)]
1338             self.invokeBlack(args, exit_code=2, ignore_config=False)
1339         finally:
1340             tmp_file.unlink()
1341
1342     def test_parse_pyproject_toml(self) -> None:
1343         test_toml_file = THIS_DIR / "test.toml"
1344         config = black.parse_pyproject_toml(str(test_toml_file))
1345         self.assertEqual(config["verbose"], 1)
1346         self.assertEqual(config["check"], "no")
1347         self.assertEqual(config["diff"], "y")
1348         self.assertEqual(config["color"], True)
1349         self.assertEqual(config["line_length"], 79)
1350         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1351         self.assertEqual(config["python_cell_magics"], ["custom1", "custom2"])
1352         self.assertEqual(config["exclude"], r"\.pyi?$")
1353         self.assertEqual(config["include"], r"\.py?$")
1354
1355     def test_read_pyproject_toml(self) -> None:
1356         test_toml_file = THIS_DIR / "test.toml"
1357         fake_ctx = FakeContext()
1358         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1359         config = fake_ctx.default_map
1360         self.assertEqual(config["verbose"], "1")
1361         self.assertEqual(config["check"], "no")
1362         self.assertEqual(config["diff"], "y")
1363         self.assertEqual(config["color"], "True")
1364         self.assertEqual(config["line_length"], "79")
1365         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1366         self.assertEqual(config["exclude"], r"\.pyi?$")
1367         self.assertEqual(config["include"], r"\.py?$")
1368
1369     @pytest.mark.incompatible_with_mypyc
1370     def test_find_project_root(self) -> None:
1371         with TemporaryDirectory() as workspace:
1372             root = Path(workspace)
1373             test_dir = root / "test"
1374             test_dir.mkdir()
1375
1376             src_dir = root / "src"
1377             src_dir.mkdir()
1378
1379             root_pyproject = root / "pyproject.toml"
1380             root_pyproject.touch()
1381             src_pyproject = src_dir / "pyproject.toml"
1382             src_pyproject.touch()
1383             src_python = src_dir / "foo.py"
1384             src_python.touch()
1385
1386             self.assertEqual(
1387                 black.find_project_root((src_dir, test_dir)),
1388                 (root.resolve(), "pyproject.toml"),
1389             )
1390             self.assertEqual(
1391                 black.find_project_root((src_dir,)),
1392                 (src_dir.resolve(), "pyproject.toml"),
1393             )
1394             self.assertEqual(
1395                 black.find_project_root((src_python,)),
1396                 (src_dir.resolve(), "pyproject.toml"),
1397             )
1398
1399             with change_directory(test_dir):
1400                 self.assertEqual(
1401                     black.find_project_root(("-",), stdin_filename="../src/a.py"),
1402                     (src_dir.resolve(), "pyproject.toml"),
1403                 )
1404
1405     @patch(
1406         "black.files.find_user_pyproject_toml",
1407     )
1408     def test_find_pyproject_toml(self, find_user_pyproject_toml: MagicMock) -> None:
1409         find_user_pyproject_toml.side_effect = RuntimeError()
1410
1411         with redirect_stderr(io.StringIO()) as stderr:
1412             result = black.files.find_pyproject_toml(
1413                 path_search_start=(str(Path.cwd().root),)
1414             )
1415
1416         assert result is None
1417         err = stderr.getvalue()
1418         assert "Ignoring user configuration" in err
1419
1420     @patch(
1421         "black.files.find_user_pyproject_toml",
1422         black.files.find_user_pyproject_toml.__wrapped__,
1423     )
1424     def test_find_user_pyproject_toml_linux(self) -> None:
1425         if system() == "Windows":
1426             return
1427
1428         # Test if XDG_CONFIG_HOME is checked
1429         with TemporaryDirectory() as workspace:
1430             tmp_user_config = Path(workspace) / "black"
1431             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1432                 self.assertEqual(
1433                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
1434                 )
1435
1436         # Test fallback for XDG_CONFIG_HOME
1437         with patch.dict("os.environ"):
1438             os.environ.pop("XDG_CONFIG_HOME", None)
1439             fallback_user_config = Path("~/.config").expanduser() / "black"
1440             self.assertEqual(
1441                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
1442             )
1443
1444     def test_find_user_pyproject_toml_windows(self) -> None:
1445         if system() != "Windows":
1446             return
1447
1448         user_config_path = Path.home() / ".black"
1449         self.assertEqual(
1450             black.files.find_user_pyproject_toml(), user_config_path.resolve()
1451         )
1452
1453     def test_bpo_33660_workaround(self) -> None:
1454         if system() == "Windows":
1455             return
1456
1457         # https://bugs.python.org/issue33660
1458         root = Path("/")
1459         with change_directory(root):
1460             path = Path("workspace") / "project"
1461             report = black.Report(verbose=True)
1462             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
1463             self.assertEqual(normalized_path, "workspace/project")
1464
1465     def test_normalize_path_ignore_windows_junctions_outside_of_root(self) -> None:
1466         if system() != "Windows":
1467             return
1468
1469         with TemporaryDirectory() as workspace:
1470             root = Path(workspace)
1471             junction_dir = root / "junction"
1472             junction_target_outside_of_root = root / ".."
1473             os.system(f"mklink /J {junction_dir} {junction_target_outside_of_root}")
1474
1475             report = black.Report(verbose=True)
1476             normalized_path = black.normalize_path_maybe_ignore(
1477                 junction_dir, root, report
1478             )
1479             # Manually delete for Python < 3.8
1480             os.system(f"rmdir {junction_dir}")
1481
1482             self.assertEqual(normalized_path, None)
1483
1484     def test_newline_comment_interaction(self) -> None:
1485         source = "class A:\\\r\n# type: ignore\n pass\n"
1486         output = black.format_str(source, mode=DEFAULT_MODE)
1487         black.assert_stable(source, output, mode=DEFAULT_MODE)
1488
1489     def test_bpo_2142_workaround(self) -> None:
1490         # https://bugs.python.org/issue2142
1491
1492         source, _ = read_data("miscellaneous", "missing_final_newline")
1493         # read_data adds a trailing newline
1494         source = source.rstrip()
1495         expected, _ = read_data("miscellaneous", "missing_final_newline.diff")
1496         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
1497         diff_header = re.compile(
1498             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
1499             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
1500         )
1501         try:
1502             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
1503             self.assertEqual(result.exit_code, 0)
1504         finally:
1505             os.unlink(tmp_file)
1506         actual = result.output
1507         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
1508         self.assertEqual(actual, expected)
1509
1510     @staticmethod
1511     def compare_results(
1512         result: click.testing.Result, expected_value: str, expected_exit_code: int
1513     ) -> None:
1514         """Helper method to test the value and exit code of a click Result."""
1515         assert (
1516             result.output == expected_value
1517         ), "The output did not match the expected value."
1518         assert result.exit_code == expected_exit_code, "The exit code is incorrect."
1519
1520     def test_code_option(self) -> None:
1521         """Test the code option with no changes."""
1522         code = 'print("Hello world")\n'
1523         args = ["--code", code]
1524         result = CliRunner().invoke(black.main, args)
1525
1526         self.compare_results(result, code, 0)
1527
1528     def test_code_option_changed(self) -> None:
1529         """Test the code option when changes are required."""
1530         code = "print('hello world')"
1531         formatted = black.format_str(code, mode=DEFAULT_MODE)
1532
1533         args = ["--code", code]
1534         result = CliRunner().invoke(black.main, args)
1535
1536         self.compare_results(result, formatted, 0)
1537
1538     def test_code_option_check(self) -> None:
1539         """Test the code option when check is passed."""
1540         args = ["--check", "--code", 'print("Hello world")\n']
1541         result = CliRunner().invoke(black.main, args)
1542         self.compare_results(result, "", 0)
1543
1544     def test_code_option_check_changed(self) -> None:
1545         """Test the code option when changes are required, and check is passed."""
1546         args = ["--check", "--code", "print('hello world')"]
1547         result = CliRunner().invoke(black.main, args)
1548         self.compare_results(result, "", 1)
1549
1550     def test_code_option_diff(self) -> None:
1551         """Test the code option when diff is passed."""
1552         code = "print('hello world')"
1553         formatted = black.format_str(code, mode=DEFAULT_MODE)
1554         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1555
1556         args = ["--diff", "--code", code]
1557         result = CliRunner().invoke(black.main, args)
1558
1559         # Remove time from diff
1560         output = DIFF_TIME.sub("", result.output)
1561
1562         assert output == result_diff, "The output did not match the expected value."
1563         assert result.exit_code == 0, "The exit code is incorrect."
1564
1565     def test_code_option_color_diff(self) -> None:
1566         """Test the code option when color and diff are passed."""
1567         code = "print('hello world')"
1568         formatted = black.format_str(code, mode=DEFAULT_MODE)
1569
1570         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1571         result_diff = color_diff(result_diff)
1572
1573         args = ["--diff", "--color", "--code", code]
1574         result = CliRunner().invoke(black.main, args)
1575
1576         # Remove time from diff
1577         output = DIFF_TIME.sub("", result.output)
1578
1579         assert output == result_diff, "The output did not match the expected value."
1580         assert result.exit_code == 0, "The exit code is incorrect."
1581
1582     @pytest.mark.incompatible_with_mypyc
1583     def test_code_option_safe(self) -> None:
1584         """Test that the code option throws an error when the sanity checks fail."""
1585         # Patch black.assert_equivalent to ensure the sanity checks fail
1586         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1587             code = 'print("Hello world")'
1588             error_msg = f"{code}\nerror: cannot format <string>: \n"
1589
1590             args = ["--safe", "--code", code]
1591             result = CliRunner().invoke(black.main, args)
1592
1593             self.compare_results(result, error_msg, 123)
1594
1595     def test_code_option_fast(self) -> None:
1596         """Test that the code option ignores errors when the sanity checks fail."""
1597         # Patch black.assert_equivalent to ensure the sanity checks fail
1598         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1599             code = 'print("Hello world")'
1600             formatted = black.format_str(code, mode=DEFAULT_MODE)
1601
1602             args = ["--fast", "--code", code]
1603             result = CliRunner().invoke(black.main, args)
1604
1605             self.compare_results(result, formatted, 0)
1606
1607     @pytest.mark.incompatible_with_mypyc
1608     def test_code_option_config(self) -> None:
1609         """
1610         Test that the code option finds the pyproject.toml in the current directory.
1611         """
1612         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1613             args = ["--code", "print"]
1614             # This is the only directory known to contain a pyproject.toml
1615             with change_directory(PROJECT_ROOT):
1616                 CliRunner().invoke(black.main, args)
1617                 pyproject_path = Path(Path.cwd(), "pyproject.toml").resolve()
1618
1619             assert (
1620                 len(parse.mock_calls) >= 1
1621             ), "Expected config parse to be called with the current directory."
1622
1623             _, call_args, _ = parse.mock_calls[0]
1624             assert (
1625                 call_args[0].lower() == str(pyproject_path).lower()
1626             ), "Incorrect config loaded."
1627
1628     @pytest.mark.incompatible_with_mypyc
1629     def test_code_option_parent_config(self) -> None:
1630         """
1631         Test that the code option finds the pyproject.toml in the parent directory.
1632         """
1633         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1634             with change_directory(THIS_DIR):
1635                 args = ["--code", "print"]
1636                 CliRunner().invoke(black.main, args)
1637
1638                 pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
1639                 assert (
1640                     len(parse.mock_calls) >= 1
1641                 ), "Expected config parse to be called with the current directory."
1642
1643                 _, call_args, _ = parse.mock_calls[0]
1644                 assert (
1645                     call_args[0].lower() == str(pyproject_path).lower()
1646                 ), "Incorrect config loaded."
1647
1648     def test_for_handled_unexpected_eof_error(self) -> None:
1649         """
1650         Test that an unexpected EOF SyntaxError is nicely presented.
1651         """
1652         with pytest.raises(black.parsing.InvalidInput) as exc_info:
1653             black.lib2to3_parse("print(", {})
1654
1655         exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")
1656
1657     def test_equivalency_ast_parse_failure_includes_error(self) -> None:
1658         with pytest.raises(AssertionError) as err:
1659             black.assert_equivalent("a«»a  = 1", "a«»a  = 1")
1660
1661         err.match("--safe")
1662         # Unfortunately the SyntaxError message has changed in newer versions so we
1663         # can't match it directly.
1664         err.match("invalid character")
1665         err.match(r"\(<unknown>, line 1\)")
1666
1667
class TestCaching:
    """Tests for black's on-disk cache of already-formatted files."""

    def test_get_cache_dir(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """BLACK_CACHE_DIR, when set, overrides the user cache directory."""
        first_dir = tmp_path / "ws1"
        first_dir.mkdir()
        second_dir = tmp_path / "ws2"
        second_dir.mkdir()

        # Pin user_cache_dir to a temporary directory for easier assertions.
        fake_user_cache_dir = patch(
            target="black.cache.user_cache_dir",
            autospec=True,
            return_value=str(first_dir),
        )

        # Without BLACK_CACHE_DIR, the user cache directory is used ...
        monkeypatch.delenv("BLACK_CACHE_DIR", raising=False)
        with fake_user_cache_dir:
            assert get_cache_dir() == first_dir

        # ... and once it is set, the environment variable wins.
        monkeypatch.setenv("BLACK_CACHE_DIR", str(second_dir))
        assert get_cache_dir() == second_dir

    def test_cache_broken_file(self) -> None:
        """A corrupt cache file reads as empty and is repopulated by a run."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            get_cache_file(mode).write_text("this is not a pickle")
            assert black.read_cache(mode) == {}

            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            invokeBlack([str(src)])
            assert str(src) in black.read_cache(mode)

    def test_cache_single_file_already_cached(self) -> None:
        """A file recorded in the cache is not reformatted."""
        mode = DEFAULT_MODE
        unformatted = "print('hello')"
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.write_text(unformatted)
            black.write_cache({}, [src], mode)
            invokeBlack([str(src)])
            assert src.read_text() == unformatted

    @event_loop()
    def test_cache_multiple_files(self) -> None:
        """Only uncached files in a directory get formatted; all end up cached."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace, patch(
            "concurrent.futures.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            one = (workspace / "one.py").resolve()
            two = (workspace / "two.py").resolve()
            for source_file in (one, two):
                source_file.write_text("print('hello')")

            black.write_cache({}, [one], mode)
            invokeBlack([str(workspace)])

            assert one.read_text() == "print('hello')"
            assert two.read_text() == 'print("hello")\n'
            cache = black.read_cache(mode)
            assert str(one) in cache
            assert str(two) in cache

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    def test_no_cache_when_writeback_diff(self, color: bool) -> None:
        """With --diff the cache is neither read nor written."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            with patch("black.read_cache") as read_cache, patch(
                "black.write_cache"
            ) as write_cache:
                cmd = [str(src), "--diff"] + (["--color"] if color else [])
                invokeBlack(cmd)
                assert get_cache_file(mode).exists() is False
                write_cache.assert_not_called()
                read_cache.assert_not_called()

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    @event_loop()
    def test_output_locking_when_writeback_diff(self, color: bool) -> None:
        """Diff output over many files goes through a multiprocessing Manager."""
        with cache_dir() as workspace:
            for tag in range(4):
                (workspace / f"test{tag}.py").resolve().write_text("print('hello')")
            with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
                cmd = ["--diff", str(workspace)]
                if color:
                    cmd.append("--color")
                invokeBlack(cmd, exit_code=0)
                # Not a direct check of the lock, but if Manager is never
                # called then the lock it provides cannot be in use either.
                mgr.assert_called()

    def test_no_cache_when_stdin(self) -> None:
        """Formatting stdin does not create a cache file."""
        mode = DEFAULT_MODE
        with cache_dir():
            result = CliRunner().invoke(
                black.main, ["-"], input=BytesIO(b"print('hello')")
            )
            assert result.exit_code == 0
            assert not get_cache_file(mode).exists()

    def test_read_cache_no_cachefile(self) -> None:
        """A missing cache file reads as an empty cache."""
        with cache_dir():
            assert black.read_cache(DEFAULT_MODE) == {}

    def test_write_cache_read_cache(self) -> None:
        """A written entry round-trips with the file's cache info."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.touch()
            black.write_cache({}, [src], mode)
            cache = black.read_cache(mode)
            key = str(src)
            assert key in cache
            assert cache[key] == black.get_cache_info(src)

    def test_filter_cached(self) -> None:
        """Unknown or changed files are still to do; matching entries are done."""
        with TemporaryDirectory() as workspace:
            base = Path(workspace)
            uncached = (base / "uncached").resolve()
            cached = (base / "cached").resolve()
            cached_but_changed = (base / "changed").resolve()
            for entry in (uncached, cached, cached_but_changed):
                entry.touch()
            cache = {
                str(cached): black.get_cache_info(cached),
                # Deliberately stale info so the file counts as changed.
                str(cached_but_changed): (0.0, 0),
            }
            todo, done = black.filter_cached(
                cache, {uncached, cached, cached_but_changed}
            )
            assert todo == {uncached, cached_but_changed}
            assert done == {cached}

    def test_write_cache_creates_directory_if_needed(self) -> None:
        """write_cache creates a missing cache directory."""
        with cache_dir(exists=False) as workspace:
            assert not workspace.exists()
            black.write_cache({}, [], DEFAULT_MODE)
            assert workspace.exists()

    @event_loop()
    def test_failed_formatting_does_not_get_cached(self) -> None:
        """Files that fail to format stay out of the cache."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace, patch(
            "concurrent.futures.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            failing = (workspace / "failing.py").resolve()
            failing.write_text("not actually python")
            clean = (workspace / "clean.py").resolve()
            clean.write_text('print("hello")\n')
            invokeBlack([str(workspace)], exit_code=123)
            cache = black.read_cache(mode)
            assert str(failing) not in cache
            assert str(clean) in cache

    def test_write_cache_write_fail(self) -> None:
        """With Path.open patched to raise OSError, write_cache does not raise."""
        with cache_dir(), patch.object(Path, "open", side_effect=OSError):
            black.write_cache({}, [], DEFAULT_MODE)

    def test_read_cache_line_lengths(self) -> None:
        """Entries cached under one mode are not visible under another."""
        mode = DEFAULT_MODE
        short_mode = replace(DEFAULT_MODE, line_length=1)
        with cache_dir() as workspace:
            path = (workspace / "file.py").resolve()
            path.touch()
            black.write_cache({}, [path], mode)
            assert str(path) in black.read_cache(mode)
            assert str(path) not in black.read_cache(short_mode)
1860
1861
def assert_collected_sources(
    src: Sequence[Union[str, Path]],
    expected: Sequence[Union[str, Path]],
    *,
    ctx: Optional[FakeContext] = None,
    exclude: Optional[str] = None,
    include: Optional[str] = None,
    extend_exclude: Optional[str] = None,
    force_exclude: Optional[str] = None,
    stdin_filename: Optional[str] = None,
) -> None:
    """Assert that ``black.get_sources`` collects exactly ``expected`` from ``src``.

    Pattern strings are compiled with ``compile_pattern``. A missing
    ``include`` falls back to ``DEFAULT_INCLUDE``; the exclusion patterns
    default to ``None``. Ordering of the collected sources is ignored.
    """

    def compile_or_none(pattern: Optional[str]) -> Optional["re.Pattern[str]"]:
        return None if pattern is None else compile_pattern(pattern)

    collected = black.get_sources(
        ctx=ctx or FakeContext(),
        src=tuple(str(Path(s)) for s in src),
        quiet=False,
        verbose=False,
        include=DEFAULT_INCLUDE if include is None else compile_pattern(include),
        exclude=compile_or_none(exclude),
        extend_exclude=compile_or_none(extend_exclude),
        force_exclude=compile_or_none(force_exclude),
        report=black.Report(),
        stdin_filename=stdin_filename,
    )
    assert sorted(collected) == sorted(Path(s) for s in expected)
1894
1895
class TestFileCollection:
    """Tests for how black discovers the source files it should format."""

    def test_include_exclude(self) -> None:
        """Only files matching include and not matching exclude are collected."""
        path = THIS_DIR / "data" / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        assert_collected_sources(
            src,
            expected,
            include=r"\.pyi?$",
            exclude=r"/exclude/|/\.definitely_exclude/",
        )

    def test_gitignore_used_as_default(self) -> None:
        """With no explicit exclude, gitignore rules from the root apply."""
        base = DATA_DIR / "include_exclude_tests"
        expected = [
            base / "b/.definitely_exclude/a.py",
            base / "b/.definitely_exclude/a.pyi",
        ]
        src = [base / "b/"]
        ctx = FakeContext()
        ctx.obj["root"] = base
        assert_collected_sources(src, expected, ctx=ctx, extend_exclude=r"/exclude/")

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_exclude_for_issue_1572(self) -> None:
        # Exclude shouldn't touch files that were explicitly given to Black through the
        # CLI. Exclude is supposed to only apply to the recursive discovery of files.
        # https://github.com/psf/black/issues/1572
        path = DATA_DIR / "include_exclude_tests"
        src = [path / "b/exclude/a.py"]
        expected = [path / "b/exclude/a.py"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    def test_gitignore_exclude(self) -> None:
        """gen_python_files skips paths matched by the given gitignore spec."""
        path = THIS_DIR / "data" / "include_exclude_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        report = black.Report()
        gitignore = PathSpec.from_lines(
            "gitwildmatch", ["exclude/", ".definitely_exclude"]
        )
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        this_abs = THIS_DIR.resolve()
        sources = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                gitignore,
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_nested_gitignore(self) -> None:
        """gen_python_files also honours .gitignore files in subdirectories."""
        path = THIS_DIR / "data" / "nested_gitignore_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        root_gitignore = black.files.get_gitignore(path)
        report = black.Report()
        expected: List[Path] = [
            path / "x.py",
            path / "root/b.py",
            path / "root/c.py",
            path / "root/child/c.py",
        ]
        this_abs = THIS_DIR.resolve()
        sources = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                root_gitignore,
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_invalid_gitignore(self) -> None:
        """An unparseable root .gitignore is reported and black exits with 1."""
        path = THIS_DIR / "data" / "invalid_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_invalid_nested_gitignore(self) -> None:
        """An unparseable nested .gitignore is reported and black exits with 1."""
        path = THIS_DIR / "data" / "invalid_nested_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / "a" / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_empty_include(self) -> None:
        """An empty include pattern collects every file."""
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.pie",
            path / "b/exclude/a.py",
            path / "b/exclude/a.pyi",
            path / "b/dont_exclude/a.pie",
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
            path / "b/.definitely_exclude/a.pie",
            path / "b/.definitely_exclude/a.py",
            path / "b/.definitely_exclude/a.pyi",
            path / ".gitignore",
            path / "pyproject.toml",
        ]
        # Setting exclude explicitly to an empty string to block .gitignore usage.
        assert_collected_sources(src, expected, include="", exclude="")

    def test_extend_exclude(self) -> None:
        """extend_exclude is applied in addition to exclude."""
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.py",
            path / "b/dont_exclude/a.py",
        ]
        assert_collected_sources(
            src, expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude"
        )

    @pytest.mark.incompatible_with_mypyc
    def test_symlink_out_of_root_directory(self) -> None:
        """A child resolving outside ``root`` must not make gen_python_files fail."""
        path = MagicMock()
        root = THIS_DIR.resolve()
        child = MagicMock()
        include = re.compile(black.DEFAULT_INCLUDES)
        exclude = re.compile(black.DEFAULT_EXCLUDES)
        report = black.Report()
        gitignore = PathSpec.from_lines("gitwildmatch", [])
        # `child` should behave like a symlink which resolved path is clearly
        # outside of the `root` directory.
        path.iterdir.return_value = [child]
        child.resolve.return_value = Path("/a/b/c")
        child.as_posix.return_value = "/a/b/c"
        try:
            list(
                black.gen_python_files(
                    path.iterdir(),
                    root,
                    include,
                    exclude,
                    None,
                    None,
                    report,
                    gitignore,
                    verbose=False,
                    quiet=False,
                )
            )
        except ValueError as ve:
            pytest.fail(f"`gen_python_files()` failed: {ve}")
        path.iterdir.assert_called_once()
        child.resolve.assert_called_once()

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin(self) -> None:
        """A bare "-" source is passed through unchanged."""
        src = ["-"]
        expected = ["-"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename(self) -> None:
        """With stdin_filename, "-" becomes a marker-prefixed pseudo path."""
        src = ["-"]
        stdin_filename = str(THIS_DIR / "data/collections.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
        # Exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        path = DATA_DIR / "include_exclude_tests"
        src = ["-"]
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
        # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        src = ["-"]
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            extend_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
        # Force exclude should exclude the file when passing it through
        # stdin_filename
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        assert_collected_sources(
            src=["-"],
            expected=[],
            force_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )
2141
2142
# Source lines of `black.__file__`, consumed by `tracefunc` for debugging.
# Pre-bound to an empty list so the name always exists. When black is compiled
# (`black.COMPILED`), `black.__file__` may not be decodable as UTF-8 text. That
# error is tolerated, and previously it left this name unbound, so later use
# raised NameError.
black_source_lines: List[str] = []
try:
    with open(black.__file__, "r", encoding="utf-8") as _bf:
        black_source_lines = _bf.readlines()
except UnicodeDecodeError:
    # Only the compiled build is allowed to have an undecodable module file.
    if not black.COMPILED:
        raise
2149
2150
def tracefunc(
    frame: types.FrameType, event: str, arg: Any
) -> Callable[[types.FrameType, str, Any], Any]:
    """Show function calls `from black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    Only "call" events whose code object lives in ``black/__init__.py`` are
    printed. Each is shown as ``<lineno>:<source line>``, indented according
    to the current stack depth. All other events and files are ignored.

    Args:
        frame: The frame being traced.
        event: The trace event name; only "call" is acted upon.
        arg: Event-specific argument; unused.

    Returns:
        ``tracefunc`` itself, so tracing stays enabled for the new frame.
    """
    if event != "call":
        return tracefunc

    filename = frame.f_code.co_filename
    # Filter by file *before* touching `black_source_lines`. That list only
    # holds the lines of black/__init__.py, so indexing it with another file's
    # line number is meaningless and can raise IndexError. Separators are
    # normalized so that the check also matches Windows paths.
    if "black/__init__.py" not in filename.replace(os.sep, "/"):
        return tracefunc

    # NOTE(review): 19 looks like the frame depth of the test-runner machinery
    # beneath the code under test; TODO confirm. A negative result just means
    # no indentation, because `" " * n` is "" for n <= 0.
    stack = (len(inspect.stack()) - 19) * 2
    lineno = frame.f_lineno
    line_count = len(black_source_lines)
    func_sig_lineno = lineno - 1  # f_lineno is 1-based; the list is 0-based
    if not 0 <= func_sig_lineno < line_count:
        # The cached source does not cover this line; nothing sensible to show.
        return tracefunc
    funcname = black_source_lines[func_sig_lineno].strip()
    # The reported line may be a decorator. Advance past decorator lines to
    # reach the signature, without running off the end of the file.
    while funcname.startswith("@") and func_sig_lineno + 1 < line_count:
        func_sig_lineno += 1
        funcname = black_source_lines[func_sig_lineno].strip()
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc