git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Fix a string merging/split issue caused by standalone comments. (#3227)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import inspect
5 import io
6 import logging
7 import multiprocessing
8 import os
9 import re
10 import sys
11 import types
12 import unittest
13 from concurrent.futures import ThreadPoolExecutor
14 from contextlib import contextmanager, redirect_stderr
15 from dataclasses import replace
16 from io import BytesIO
17 from pathlib import Path
18 from platform import system
19 from tempfile import TemporaryDirectory
20 from typing import (
21     Any,
22     Callable,
23     Dict,
24     Iterator,
25     List,
26     Optional,
27     Sequence,
28     TypeVar,
29     Union,
30 )
31 from unittest.mock import MagicMock, patch
32
33 import click
34 import pytest
35 from click import unstyle
36 from click.testing import CliRunner
37 from pathspec import PathSpec
38
39 import black
40 import black.files
41 from black import Feature, TargetVersion
42 from black import re_compile_maybe_verbose as compile_pattern
43 from black.cache import get_cache_dir, get_cache_file
44 from black.debug import DebugVisitor
45 from black.output import color_diff, diff
46 from black.report import Report
47
# Import shared test helpers, constants and the base test case from tests.util
49 from tests.util import (
50     DATA_DIR,
51     DEFAULT_MODE,
52     DETERMINISTIC_HEADER,
53     PROJECT_ROOT,
54     PY36_VERSIONS,
55     THIS_DIR,
56     BlackBaseTestCase,
57     assert_format,
58     change_directory,
59     dump_to_stderr,
60     ff,
61     fs,
62     get_case_path,
63     read_data,
64     read_data_from_file,
65 )
66
THIS_FILE = Path(__file__)
# Config file used to keep any user/project configuration out of CLI runs.
EMPTY_CONFIG = THIS_DIR.joinpath("data", "empty_pyproject.toml")
# One --target-version flag per version listed in PY36_VERSIONS.
PY36_ARGS = ["--target-version=" + v.name.lower() for v in PY36_VERSIONS]
# Compiled forms of Black's built-in default exclude/include patterns.
DEFAULT_EXCLUDE = compile_pattern(black.const.DEFAULT_EXCLUDES)
DEFAULT_INCLUDE = compile_pattern(black.const.DEFAULT_INCLUDES)
T = TypeVar("T")
R = TypeVar("R")

# Match the time output in a diff, but nothing else
DIFF_TIME = re.compile(r"\t[\d\-:+\. ]+")
77
78
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point ``black.cache.CACHE_DIR`` at a throwaway directory for the ``with`` body.

    :param exists: when False, the yielded path is a not-yet-created ``new``
        subdirectory of the temporary workspace instead of the workspace itself.
    :yields: the path that ``black.cache.CACHE_DIR`` is patched to.

    The temporary workspace (and anything written into it) is removed on exit.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
87
88
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a brand-new asyncio event loop as the current loop for the body.

    The loop comes from the active event loop policy and is closed on exit.
    It is closed but not unset, so it remains the thread's current loop afterwards.
    """
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        fresh_loop.close()
99
100
class FakeContext(click.Context):
    """Minimal click.Context stand-in for tests that call context-taking helpers."""

    def __init__(self) -> None:
        # click.Context.__init__ is deliberately not called; only these two
        # attributes are set up.
        root_obj: Dict[str, Any] = {"root": PROJECT_ROOT}  # dummy root
        self.default_map: Dict[str, Any] = dict()
        self.obj: Dict[str, Any] = root_obj
108
109
class FakeParameter(click.Parameter):
    """Minimal click.Parameter stand-in for tests that call parameter-taking helpers."""

    def __init__(self) -> None:
        """Bypass click.Parameter.__init__; the fake carries no state of its own."""
115
116
class BlackRunner(CliRunner):
    """CliRunner that captures Black's stdout and stderr into separate buffers."""

    def __init__(self) -> None:
        # Separate streams let callers inspect result.stdout_bytes and
        # result.stderr_bytes independently (invokeBlack relies on this).
        # NOTE(review): newer click releases dropped the mix_stderr argument —
        # confirm against the pinned click version.
        super().__init__(mix_stderr=False)
122
123
def invokeBlack(
    args: List[str], exit_code: int = 0, ignore_config: bool = True
) -> None:
    """Run Black's CLI in-process and assert it exits with ``exit_code``.

    :param args: command-line arguments passed to ``black.main``.
    :param exit_code: the exit status the run is expected to produce.
    :param ignore_config: when True, prepend ``--verbose`` and point
        ``--config`` at the empty test config so no project config leaks in.

    On mismatch, the assertion message includes the arguments, both captured
    streams and any exception raised inside the CLI.
    """
    if ignore_config:
        config_path = str(THIS_DIR / "empty.toml")
        args = ["--verbose", "--config", config_path, *args]
    outcome = BlackRunner().invoke(black.main, args, catch_exceptions=False)
    out_bytes = outcome.stdout_bytes
    err_bytes = outcome.stderr_bytes
    assert out_bytes is not None
    assert err_bytes is not None
    details = "\n".join(
        [
            f"Failed with args: {args}",
            f"stdout: {out_bytes.decode()!r}",
            f"stderr: {err_bytes.decode()!r}",
            f"exception: {outcome.exception}",
        ]
    )
    assert outcome.exit_code == exit_code, details
140
141
142 class BlackTestCase(BlackBaseTestCase):
    # Expose the module-level helper so test methods can call self.invokeBlack(...);
    # staticmethod keeps it from being bound with `self` as its first argument.
    invokeBlack = staticmethod(invokeBlack)
144
145     def test_empty_ff(self) -> None:
146         expected = ""
147         tmp_file = Path(black.dump_to_file())
148         try:
149             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
150             with open(tmp_file, encoding="utf8") as f:
151                 actual = f.read()
152         finally:
153             os.unlink(tmp_file)
154         self.assertFormatEqual(expected, actual)
155
156     def test_experimental_string_processing_warns(self) -> None:
157         self.assertWarns(
158             black.mode.Deprecated, black.Mode, experimental_string_processing=True
159         )
160
161     def test_piping(self) -> None:
162         source, expected = read_data_from_file(PROJECT_ROOT / "src/black/__init__.py")
163         result = BlackRunner().invoke(
164             black.main,
165             [
166                 "-",
167                 "--fast",
168                 f"--line-length={black.DEFAULT_LINE_LENGTH}",
169                 f"--config={EMPTY_CONFIG}",
170             ],
171             input=BytesIO(source.encode("utf8")),
172         )
173         self.assertEqual(result.exit_code, 0)
174         self.assertFormatEqual(expected, result.output)
175         if source != result.output:
176             black.assert_equivalent(source, result.output)
177             black.assert_stable(source, result.output, DEFAULT_MODE)
178
179     def test_piping_diff(self) -> None:
180         diff_header = re.compile(
181             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
182             r"\+\d\d\d\d"
183         )
184         source, _ = read_data("simple_cases", "expression.py")
185         expected, _ = read_data("simple_cases", "expression.diff")
186         args = [
187             "-",
188             "--fast",
189             f"--line-length={black.DEFAULT_LINE_LENGTH}",
190             "--diff",
191             f"--config={EMPTY_CONFIG}",
192         ]
193         result = BlackRunner().invoke(
194             black.main, args, input=BytesIO(source.encode("utf8"))
195         )
196         self.assertEqual(result.exit_code, 0)
197         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
198         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
199         self.assertEqual(expected, actual)
200
201     def test_piping_diff_with_color(self) -> None:
202         source, _ = read_data("simple_cases", "expression.py")
203         args = [
204             "-",
205             "--fast",
206             f"--line-length={black.DEFAULT_LINE_LENGTH}",
207             "--diff",
208             "--color",
209             f"--config={EMPTY_CONFIG}",
210         ]
211         result = BlackRunner().invoke(
212             black.main, args, input=BytesIO(source.encode("utf8"))
213         )
214         actual = result.output
215         # Again, the contents are checked in a different test, so only look for colors.
216         self.assertIn("\033[1m", actual)
217         self.assertIn("\033[36m", actual)
218         self.assertIn("\033[32m", actual)
219         self.assertIn("\033[31m", actual)
220         self.assertIn("\033[0m", actual)
221
222     @patch("black.dump_to_file", dump_to_stderr)
223     def _test_wip(self) -> None:
224         source, expected = read_data("miscellaneous", "wip")
225         sys.settrace(tracefunc)
226         mode = replace(
227             DEFAULT_MODE,
228             experimental_string_processing=False,
229             target_versions={black.TargetVersion.PY38},
230         )
231         actual = fs(source, mode=mode)
232         sys.settrace(None)
233         self.assertFormatEqual(expected, actual)
234         black.assert_equivalent(source, actual)
235         black.assert_stable(source, actual, black.FileMode())
236
237     def test_pep_572_version_detection(self) -> None:
238         source, _ = read_data("py_38", "pep_572")
239         root = black.lib2to3_parse(source)
240         features = black.get_features_used(root)
241         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
242         versions = black.detect_target_versions(root)
243         self.assertIn(black.TargetVersion.PY38, versions)
244
245     def test_expression_ff(self) -> None:
246         source, expected = read_data("simple_cases", "expression.py")
247         tmp_file = Path(black.dump_to_file(source))
248         try:
249             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
250             with open(tmp_file, encoding="utf8") as f:
251                 actual = f.read()
252         finally:
253             os.unlink(tmp_file)
254         self.assertFormatEqual(expected, actual)
255         with patch("black.dump_to_file", dump_to_stderr):
256             black.assert_equivalent(source, actual)
257             black.assert_stable(source, actual, DEFAULT_MODE)
258
259     def test_expression_diff(self) -> None:
260         source, _ = read_data("simple_cases", "expression.py")
261         expected, _ = read_data("simple_cases", "expression.diff")
262         tmp_file = Path(black.dump_to_file(source))
263         diff_header = re.compile(
264             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
265             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
266         )
267         try:
268             result = BlackRunner().invoke(
269                 black.main, ["--diff", str(tmp_file), f"--config={EMPTY_CONFIG}"]
270             )
271             self.assertEqual(result.exit_code, 0)
272         finally:
273             os.unlink(tmp_file)
274         actual = result.output
275         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
276         if expected != actual:
277             dump = black.dump_to_file(actual)
278             msg = (
279                 "Expected diff isn't equal to the actual. If you made changes to"
280                 " expression.py and this is an anticipated difference, overwrite"
281                 f" tests/data/expression.diff with {dump}"
282             )
283             self.assertEqual(expected, actual, msg)
284
285     def test_expression_diff_with_color(self) -> None:
286         source, _ = read_data("simple_cases", "expression.py")
287         expected, _ = read_data("simple_cases", "expression.diff")
288         tmp_file = Path(black.dump_to_file(source))
289         try:
290             result = BlackRunner().invoke(
291                 black.main,
292                 ["--diff", "--color", str(tmp_file), f"--config={EMPTY_CONFIG}"],
293             )
294         finally:
295             os.unlink(tmp_file)
296         actual = result.output
297         # We check the contents of the diff in `test_expression_diff`. All
298         # we need to check here is that color codes exist in the result.
299         self.assertIn("\033[1m", actual)
300         self.assertIn("\033[36m", actual)
301         self.assertIn("\033[32m", actual)
302         self.assertIn("\033[31m", actual)
303         self.assertIn("\033[0m", actual)
304
305     def test_detect_pos_only_arguments(self) -> None:
306         source, _ = read_data("py_38", "pep_570")
307         root = black.lib2to3_parse(source)
308         features = black.get_features_used(root)
309         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
310         versions = black.detect_target_versions(root)
311         self.assertIn(black.TargetVersion.PY38, versions)
312
313     def test_detect_debug_f_strings(self) -> None:
314         root = black.lib2to3_parse("""f"{x=}" """)
315         features = black.get_features_used(root)
316         self.assertIn(black.Feature.DEBUG_F_STRINGS, features)
317         versions = black.detect_target_versions(root)
318         self.assertIn(black.TargetVersion.PY38, versions)
319
320         root = black.lib2to3_parse(
321             """f"{x}"\nf'{"="}'\nf'{(x:=5)}'\nf'{f(a="3=")}'\nf'{x:=10}'\n"""
322         )
323         features = black.get_features_used(root)
324         self.assertNotIn(black.Feature.DEBUG_F_STRINGS, features)
325
326         # We don't yet support feature version detection in nested f-strings
327         root = black.lib2to3_parse(
328             """f"heard a rumour that { f'{1+1=}' } ... seems like it could be true" """
329         )
330         features = black.get_features_used(root)
331         self.assertNotIn(black.Feature.DEBUG_F_STRINGS, features)
332
333     @patch("black.dump_to_file", dump_to_stderr)
334     def test_string_quotes(self) -> None:
335         source, expected = read_data("miscellaneous", "string_quotes")
336         mode = black.Mode(preview=True)
337         assert_format(source, expected, mode)
338         mode = replace(mode, string_normalization=False)
339         not_normalized = fs(source, mode=mode)
340         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
341         black.assert_equivalent(source, not_normalized)
342         black.assert_stable(source, not_normalized, mode=mode)
343
344     def test_skip_magic_trailing_comma(self) -> None:
345         source, _ = read_data("simple_cases", "expression")
346         expected, _ = read_data(
347             "miscellaneous", "expression_skip_magic_trailing_comma.diff"
348         )
349         tmp_file = Path(black.dump_to_file(source))
350         diff_header = re.compile(
351             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
352             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
353         )
354         try:
355             result = BlackRunner().invoke(
356                 black.main, ["-C", "--diff", str(tmp_file), f"--config={EMPTY_CONFIG}"]
357             )
358             self.assertEqual(result.exit_code, 0)
359         finally:
360             os.unlink(tmp_file)
361         actual = result.output
362         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
363         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
364         if expected != actual:
365             dump = black.dump_to_file(actual)
366             msg = (
367                 "Expected diff isn't equal to the actual. If you made changes to"
368                 " expression.py and this is an anticipated difference, overwrite"
369                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
370             )
371             self.assertEqual(expected, actual, msg)
372
373     @patch("black.dump_to_file", dump_to_stderr)
374     def test_async_as_identifier(self) -> None:
375         source_path = get_case_path("miscellaneous", "async_as_identifier")
376         source, expected = read_data_from_file(source_path)
377         actual = fs(source)
378         self.assertFormatEqual(expected, actual)
379         major, minor = sys.version_info[:2]
380         if major < 3 or (major <= 3 and minor < 7):
381             black.assert_equivalent(source, actual)
382         black.assert_stable(source, actual, DEFAULT_MODE)
383         # ensure black can parse this when the target is 3.6
384         self.invokeBlack([str(source_path), "--target-version", "py36"])
385         # but not on 3.7, because async/await is no longer an identifier
386         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
387
388     @patch("black.dump_to_file", dump_to_stderr)
389     def test_python37(self) -> None:
390         source_path = get_case_path("py_37", "python37")
391         source, expected = read_data_from_file(source_path)
392         actual = fs(source)
393         self.assertFormatEqual(expected, actual)
394         major, minor = sys.version_info[:2]
395         if major > 3 or (major == 3 and minor >= 7):
396             black.assert_equivalent(source, actual)
397         black.assert_stable(source, actual, DEFAULT_MODE)
398         # ensure black can parse this when the target is 3.7
399         self.invokeBlack([str(source_path), "--target-version", "py37"])
400         # but not on 3.6, because we use async as a reserved keyword
401         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
402
403     def test_tab_comment_indentation(self) -> None:
404         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
405         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
406         self.assertFormatEqual(contents_spc, fs(contents_spc))
407         self.assertFormatEqual(contents_spc, fs(contents_tab))
408
409         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
410         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
411         self.assertFormatEqual(contents_spc, fs(contents_spc))
412         self.assertFormatEqual(contents_spc, fs(contents_tab))
413
414         # mixed tabs and spaces (valid Python 2 code)
415         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
416         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
417         self.assertFormatEqual(contents_spc, fs(contents_spc))
418         self.assertFormatEqual(contents_spc, fs(contents_tab))
419
420         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
421         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
422         self.assertFormatEqual(contents_spc, fs(contents_spc))
423         self.assertFormatEqual(contents_spc, fs(contents_tab))
424
    def test_report_verbose(self) -> None:
        """Walk a verbose Report through every outcome and check what it emits.

        ``black.output._out`` and ``black.output._err`` are patched with the
        collectors below. Each step asserts how many messages went to each
        stream, the latest message, the summary (``str(report)``) and
        ``report.return_code``. The steps build on each other, so order matters.
        """
        report = Report(verbose=True)
        out_lines: List[str] = []  # messages captured from the patched _out
        err_lines: List[str] = []  # messages captured from the patched _err

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            # Verbose mode announces unchanged files on the "out" stream.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cache hit gets its own message and counts as "left unchanged".
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled, a pending reformat turns the return code into 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures go to the "err" stream and force return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Counts are per call, not per unique path: f3 is counted again here.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are logged in verbose mode but not counted in the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
526
    def test_report_quiet(self) -> None:
        """Walk a quiet Report through every outcome and check what it emits.

        Same sequence as the verbose test, but in quiet mode nothing is written
        to the "out" stream. Only failures reach "err". The summary and return
        code still track every outcome. The steps build on each other, so order
        matters.
        """
        report = Report(quiet=True)
        out_lines: List[str] = []  # messages captured from the patched _out
        err_lines: List[str] = []  # messages captured from the patched _err

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cache hit counts as "left unchanged".
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled, a pending reformat turns the return code into 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures still reach the "err" stream in quiet mode.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Counts are per call, not per unique path: f3 is counted again here.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are neither logged nor counted in the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
620
    def test_report_normal(self) -> None:
        """Walk a default-verbosity Report through every outcome and check output.

        At default verbosity only reformatted files are announced on the "out"
        stream. Unchanged, cached and ignored paths are silent, and failures go
        to "err". The steps build on each other, so order matters.
        """
        report = black.Report()
        out_lines: List[str] = []  # messages captured from the patched _out
        err_lines: List[str] = []  # messages captured from the patched _err

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cache hit is silent (last message is still f2's) but counted.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled, a pending reformat turns the return code into 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures go to the "err" stream and force return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Counts are per call, not per unique path: f3 is counted again here.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are neither logged nor counted at this verbosity.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
717
718     def test_lib2to3_parse(self) -> None:
719         with self.assertRaises(black.InvalidInput):
720             black.lib2to3_parse("invalid syntax")
721
722         straddling = "x + y"
723         black.lib2to3_parse(straddling)
724         black.lib2to3_parse(straddling, {TargetVersion.PY36})
725
726         py2_only = "print x"
727         with self.assertRaises(black.InvalidInput):
728             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
729
730         py3_only = "exec(x, end=y)"
731         black.lib2to3_parse(py3_only)
732         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
733
734     def test_get_features_used_decorator(self) -> None:
735         # Test the feature detection of new decorator syntax
736         # since this makes some test cases of test_get_features_used()
737         # fails if it fails, this is tested first so that a useful case
738         # is identified
739         simples, relaxed = read_data("miscellaneous", "decorators")
740         # skip explanation comments at the top of the file
741         for simple_test in simples.split("##")[1:]:
742             node = black.lib2to3_parse(simple_test)
743             decorator = str(node.children[0].children[0]).strip()
744             self.assertNotIn(
745                 Feature.RELAXED_DECORATORS,
746                 black.get_features_used(node),
747                 msg=(
748                     f"decorator '{decorator}' follows python<=3.8 syntax"
749                     "but is detected as 3.9+"
750                     # f"The full node is\n{node!r}"
751                 ),
752             )
753         # skip the '# output' comment at the top of the output part
754         for relaxed_test in relaxed.split("##")[1:]:
755             node = black.lib2to3_parse(relaxed_test)
756             decorator = str(node.children[0].children[0]).strip()
757             self.assertIn(
758                 Feature.RELAXED_DECORATORS,
759                 black.get_features_used(node),
760                 msg=(
761                     f"decorator '{decorator}' uses python3.9+ syntax"
762                     "but is detected as python<=3.8"
763                     # f"The full node is\n{node!r}"
764                 ),
765             )
766
767     def test_get_features_used(self) -> None:
768         node = black.lib2to3_parse("def f(*, arg): ...\n")
769         self.assertEqual(black.get_features_used(node), set())
770         node = black.lib2to3_parse("def f(*, arg,): ...\n")
771         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
772         node = black.lib2to3_parse("f(*arg,)\n")
773         self.assertEqual(
774             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
775         )
776         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
777         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
778         node = black.lib2to3_parse("123_456\n")
779         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
780         node = black.lib2to3_parse("123456\n")
781         self.assertEqual(black.get_features_used(node), set())
782         source, expected = read_data("simple_cases", "function")
783         node = black.lib2to3_parse(source)
784         expected_features = {
785             Feature.TRAILING_COMMA_IN_CALL,
786             Feature.TRAILING_COMMA_IN_DEF,
787             Feature.F_STRINGS,
788         }
789         self.assertEqual(black.get_features_used(node), expected_features)
790         node = black.lib2to3_parse(expected)
791         self.assertEqual(black.get_features_used(node), expected_features)
792         source, expected = read_data("simple_cases", "expression")
793         node = black.lib2to3_parse(source)
794         self.assertEqual(black.get_features_used(node), set())
795         node = black.lib2to3_parse(expected)
796         self.assertEqual(black.get_features_used(node), set())
797         node = black.lib2to3_parse("lambda a, /, b: ...")
798         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
799         node = black.lib2to3_parse("def fn(a, /, b): ...")
800         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
801         node = black.lib2to3_parse("def fn(): yield a, b")
802         self.assertEqual(black.get_features_used(node), set())
803         node = black.lib2to3_parse("def fn(): return a, b")
804         self.assertEqual(black.get_features_used(node), set())
805         node = black.lib2to3_parse("def fn(): yield *b, c")
806         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
807         node = black.lib2to3_parse("def fn(): return a, *b, c")
808         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
809         node = black.lib2to3_parse("x = a, *b, c")
810         self.assertEqual(black.get_features_used(node), set())
811         node = black.lib2to3_parse("x: Any = regular")
812         self.assertEqual(black.get_features_used(node), set())
813         node = black.lib2to3_parse("x: Any = (regular, regular)")
814         self.assertEqual(black.get_features_used(node), set())
815         node = black.lib2to3_parse("x: Any = Complex(Type(1))[something]")
816         self.assertEqual(black.get_features_used(node), set())
817         node = black.lib2to3_parse("x: Tuple[int, ...] = a, b, c")
818         self.assertEqual(
819             black.get_features_used(node), {Feature.ANN_ASSIGN_EXTENDED_RHS}
820         )
821         node = black.lib2to3_parse("try: pass\nexcept Something: pass")
822         self.assertEqual(black.get_features_used(node), set())
823         node = black.lib2to3_parse("try: pass\nexcept (*Something,): pass")
824         self.assertEqual(black.get_features_used(node), set())
825         node = black.lib2to3_parse("try: pass\nexcept *Group: pass")
826         self.assertEqual(black.get_features_used(node), {Feature.EXCEPT_STAR})
827         node = black.lib2to3_parse("a[*b]")
828         self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS})
829         node = black.lib2to3_parse("a[x, *y(), z] = t")
830         self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS})
831         node = black.lib2to3_parse("def fn(*args: *T): pass")
832         self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS})
833
834     def test_get_features_used_for_future_flags(self) -> None:
835         for src, features in [
836             ("from __future__ import annotations", {Feature.FUTURE_ANNOTATIONS}),
837             (
838                 "from __future__ import (other, annotations)",
839                 {Feature.FUTURE_ANNOTATIONS},
840             ),
841             ("a = 1 + 2\nfrom something import annotations", set()),
842             ("from __future__ import x, y", set()),
843         ]:
844             with self.subTest(src=src, features=features):
845                 node = black.lib2to3_parse(src)
846                 future_imports = black.get_future_imports(node)
847                 self.assertEqual(
848                     black.get_features_used(node, future_imports=future_imports),
849                     features,
850                 )
851
852     def test_get_future_imports(self) -> None:
853         node = black.lib2to3_parse("\n")
854         self.assertEqual(set(), black.get_future_imports(node))
855         node = black.lib2to3_parse("from __future__ import black\n")
856         self.assertEqual({"black"}, black.get_future_imports(node))
857         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
858         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
859         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
860         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
861         node = black.lib2to3_parse(
862             "from __future__ import multiple\nfrom __future__ import imports\n"
863         )
864         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
865         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
866         self.assertEqual({"black"}, black.get_future_imports(node))
867         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
868         self.assertEqual({"black"}, black.get_future_imports(node))
869         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
870         self.assertEqual(set(), black.get_future_imports(node))
871         node = black.lib2to3_parse("from some.module import black\n")
872         self.assertEqual(set(), black.get_future_imports(node))
873         node = black.lib2to3_parse(
874             "from __future__ import unicode_literals as _unicode_literals"
875         )
876         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
877         node = black.lib2to3_parse(
878             "from __future__ import unicode_literals as _lol, print"
879         )
880         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
881
882     @pytest.mark.incompatible_with_mypyc
883     def test_debug_visitor(self) -> None:
884         source, _ = read_data("miscellaneous", "debug_visitor")
885         expected, _ = read_data("miscellaneous", "debug_visitor.out")
886         out_lines = []
887         err_lines = []
888
889         def out(msg: str, **kwargs: Any) -> None:
890             out_lines.append(msg)
891
892         def err(msg: str, **kwargs: Any) -> None:
893             err_lines.append(msg)
894
895         with patch("black.debug.out", out):
896             DebugVisitor.show(source)
897         actual = "\n".join(out_lines) + "\n"
898         log_name = ""
899         if expected != actual:
900             log_name = black.dump_to_file(*out_lines)
901         self.assertEqual(
902             expected,
903             actual,
904             f"AST print out is different. Actual version dumped to {log_name}",
905         )
906
907     def test_format_file_contents(self) -> None:
908         empty = ""
909         mode = DEFAULT_MODE
910         with self.assertRaises(black.NothingChanged):
911             black.format_file_contents(empty, mode=mode, fast=False)
912         just_nl = "\n"
913         with self.assertRaises(black.NothingChanged):
914             black.format_file_contents(just_nl, mode=mode, fast=False)
915         same = "j = [1, 2, 3]\n"
916         with self.assertRaises(black.NothingChanged):
917             black.format_file_contents(same, mode=mode, fast=False)
918         different = "j = [1,2,3]"
919         expected = same
920         actual = black.format_file_contents(different, mode=mode, fast=False)
921         self.assertEqual(expected, actual)
922         invalid = "return if you can"
923         with self.assertRaises(black.InvalidInput) as e:
924             black.format_file_contents(invalid, mode=mode, fast=False)
925         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
926
927     def test_endmarker(self) -> None:
928         n = black.lib2to3_parse("\n")
929         self.assertEqual(n.type, black.syms.file_input)
930         self.assertEqual(len(n.children), 1)
931         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
932
933     @pytest.mark.incompatible_with_mypyc
934     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
935     def test_assertFormatEqual(self) -> None:
936         out_lines = []
937         err_lines = []
938
939         def out(msg: str, **kwargs: Any) -> None:
940             out_lines.append(msg)
941
942         def err(msg: str, **kwargs: Any) -> None:
943             err_lines.append(msg)
944
945         with patch("black.output._out", out), patch("black.output._err", err):
946             with self.assertRaises(AssertionError):
947                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
948
949         out_str = "".join(out_lines)
950         self.assertIn("Expected tree:", out_str)
951         self.assertIn("Actual tree:", out_str)
952         self.assertEqual("".join(err_lines), "")
953
954     @event_loop()
955     @patch("concurrent.futures.ProcessPoolExecutor", MagicMock(side_effect=OSError))
956     def test_works_in_mono_process_only_environment(self) -> None:
957         with cache_dir() as workspace:
958             for f in [
959                 (workspace / "one.py").resolve(),
960                 (workspace / "two.py").resolve(),
961             ]:
962                 f.write_text('print("hello")\n')
963             self.invokeBlack([str(workspace)])
964
965     @event_loop()
966     def test_check_diff_use_together(self) -> None:
967         with cache_dir():
968             # Files which will be reformatted.
969             src1 = get_case_path("miscellaneous", "string_quotes")
970             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
971             # Files which will not be reformatted.
972             src2 = get_case_path("simple_cases", "composition")
973             self.invokeBlack([str(src2), "--diff", "--check"])
974             # Multi file command.
975             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
976
977     def test_no_src_fails(self) -> None:
978         with cache_dir():
979             self.invokeBlack([], exit_code=1)
980
981     def test_src_and_code_fails(self) -> None:
982         with cache_dir():
983             self.invokeBlack([".", "-c", "0"], exit_code=1)
984
985     def test_broken_symlink(self) -> None:
986         with cache_dir() as workspace:
987             symlink = workspace / "broken_link.py"
988             try:
989                 symlink.symlink_to("nonexistent.py")
990             except (OSError, NotImplementedError) as e:
991                 self.skipTest(f"Can't create symlinks: {e}")
992             self.invokeBlack([str(workspace.resolve())])
993
994     def test_single_file_force_pyi(self) -> None:
995         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
996         contents, expected = read_data("miscellaneous", "force_pyi")
997         with cache_dir() as workspace:
998             path = (workspace / "file.py").resolve()
999             with open(path, "w") as fh:
1000                 fh.write(contents)
1001             self.invokeBlack([str(path), "--pyi"])
1002             with open(path, "r") as fh:
1003                 actual = fh.read()
1004             # verify cache with --pyi is separate
1005             pyi_cache = black.read_cache(pyi_mode)
1006             self.assertIn(str(path), pyi_cache)
1007             normal_cache = black.read_cache(DEFAULT_MODE)
1008             self.assertNotIn(str(path), normal_cache)
1009         self.assertFormatEqual(expected, actual)
1010         black.assert_equivalent(contents, actual)
1011         black.assert_stable(contents, actual, pyi_mode)
1012
1013     @event_loop()
1014     def test_multi_file_force_pyi(self) -> None:
1015         reg_mode = DEFAULT_MODE
1016         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1017         contents, expected = read_data("miscellaneous", "force_pyi")
1018         with cache_dir() as workspace:
1019             paths = [
1020                 (workspace / "file1.py").resolve(),
1021                 (workspace / "file2.py").resolve(),
1022             ]
1023             for path in paths:
1024                 with open(path, "w") as fh:
1025                     fh.write(contents)
1026             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1027             for path in paths:
1028                 with open(path, "r") as fh:
1029                     actual = fh.read()
1030                 self.assertEqual(actual, expected)
1031             # verify cache with --pyi is separate
1032             pyi_cache = black.read_cache(pyi_mode)
1033             normal_cache = black.read_cache(reg_mode)
1034             for path in paths:
1035                 self.assertIn(str(path), pyi_cache)
1036                 self.assertNotIn(str(path), normal_cache)
1037
1038     def test_pipe_force_pyi(self) -> None:
1039         source, expected = read_data("miscellaneous", "force_pyi")
1040         result = CliRunner().invoke(
1041             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1042         )
1043         self.assertEqual(result.exit_code, 0)
1044         actual = result.output
1045         self.assertFormatEqual(actual, expected)
1046
1047     def test_single_file_force_py36(self) -> None:
1048         reg_mode = DEFAULT_MODE
1049         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1050         source, expected = read_data("miscellaneous", "force_py36")
1051         with cache_dir() as workspace:
1052             path = (workspace / "file.py").resolve()
1053             with open(path, "w") as fh:
1054                 fh.write(source)
1055             self.invokeBlack([str(path), *PY36_ARGS])
1056             with open(path, "r") as fh:
1057                 actual = fh.read()
1058             # verify cache with --target-version is separate
1059             py36_cache = black.read_cache(py36_mode)
1060             self.assertIn(str(path), py36_cache)
1061             normal_cache = black.read_cache(reg_mode)
1062             self.assertNotIn(str(path), normal_cache)
1063         self.assertEqual(actual, expected)
1064
1065     @event_loop()
1066     def test_multi_file_force_py36(self) -> None:
1067         reg_mode = DEFAULT_MODE
1068         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1069         source, expected = read_data("miscellaneous", "force_py36")
1070         with cache_dir() as workspace:
1071             paths = [
1072                 (workspace / "file1.py").resolve(),
1073                 (workspace / "file2.py").resolve(),
1074             ]
1075             for path in paths:
1076                 with open(path, "w") as fh:
1077                     fh.write(source)
1078             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1079             for path in paths:
1080                 with open(path, "r") as fh:
1081                     actual = fh.read()
1082                 self.assertEqual(actual, expected)
1083             # verify cache with --target-version is separate
1084             pyi_cache = black.read_cache(py36_mode)
1085             normal_cache = black.read_cache(reg_mode)
1086             for path in paths:
1087                 self.assertIn(str(path), pyi_cache)
1088                 self.assertNotIn(str(path), normal_cache)
1089
1090     def test_pipe_force_py36(self) -> None:
1091         source, expected = read_data("miscellaneous", "force_py36")
1092         result = CliRunner().invoke(
1093             black.main,
1094             ["-", "-q", "--target-version=py36"],
1095             input=BytesIO(source.encode("utf8")),
1096         )
1097         self.assertEqual(result.exit_code, 0)
1098         actual = result.output
1099         self.assertFormatEqual(actual, expected)
1100
1101     @pytest.mark.incompatible_with_mypyc
1102     def test_reformat_one_with_stdin(self) -> None:
1103         with patch(
1104             "black.format_stdin_to_stdout",
1105             return_value=lambda *args, **kwargs: black.Changed.YES,
1106         ) as fsts:
1107             report = MagicMock()
1108             path = Path("-")
1109             black.reformat_one(
1110                 path,
1111                 fast=True,
1112                 write_back=black.WriteBack.YES,
1113                 mode=DEFAULT_MODE,
1114                 report=report,
1115             )
1116             fsts.assert_called_once()
1117             report.done.assert_called_with(path, black.Changed.YES)
1118
1119     @pytest.mark.incompatible_with_mypyc
1120     def test_reformat_one_with_stdin_filename(self) -> None:
1121         with patch(
1122             "black.format_stdin_to_stdout",
1123             return_value=lambda *args, **kwargs: black.Changed.YES,
1124         ) as fsts:
1125             report = MagicMock()
1126             p = "foo.py"
1127             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1128             expected = Path(p)
1129             black.reformat_one(
1130                 path,
1131                 fast=True,
1132                 write_back=black.WriteBack.YES,
1133                 mode=DEFAULT_MODE,
1134                 report=report,
1135             )
1136             fsts.assert_called_once_with(
1137                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1138             )
1139             # __BLACK_STDIN_FILENAME__ should have been stripped
1140             report.done.assert_called_with(expected, black.Changed.YES)
1141
1142     @pytest.mark.incompatible_with_mypyc
1143     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1144         with patch(
1145             "black.format_stdin_to_stdout",
1146             return_value=lambda *args, **kwargs: black.Changed.YES,
1147         ) as fsts:
1148             report = MagicMock()
1149             p = "foo.pyi"
1150             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1151             expected = Path(p)
1152             black.reformat_one(
1153                 path,
1154                 fast=True,
1155                 write_back=black.WriteBack.YES,
1156                 mode=DEFAULT_MODE,
1157                 report=report,
1158             )
1159             fsts.assert_called_once_with(
1160                 fast=True,
1161                 write_back=black.WriteBack.YES,
1162                 mode=replace(DEFAULT_MODE, is_pyi=True),
1163             )
1164             # __BLACK_STDIN_FILENAME__ should have been stripped
1165             report.done.assert_called_with(expected, black.Changed.YES)
1166
1167     @pytest.mark.incompatible_with_mypyc
1168     def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
1169         with patch(
1170             "black.format_stdin_to_stdout",
1171             return_value=lambda *args, **kwargs: black.Changed.YES,
1172         ) as fsts:
1173             report = MagicMock()
1174             p = "foo.ipynb"
1175             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1176             expected = Path(p)
1177             black.reformat_one(
1178                 path,
1179                 fast=True,
1180                 write_back=black.WriteBack.YES,
1181                 mode=DEFAULT_MODE,
1182                 report=report,
1183             )
1184             fsts.assert_called_once_with(
1185                 fast=True,
1186                 write_back=black.WriteBack.YES,
1187                 mode=replace(DEFAULT_MODE, is_ipynb=True),
1188             )
1189             # __BLACK_STDIN_FILENAME__ should have been stripped
1190             report.done.assert_called_with(expected, black.Changed.YES)
1191
1192     @pytest.mark.incompatible_with_mypyc
1193     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1194         with patch(
1195             "black.format_stdin_to_stdout",
1196             return_value=lambda *args, **kwargs: black.Changed.YES,
1197         ) as fsts:
1198             report = MagicMock()
1199             # Even with an existing file, since we are forcing stdin, black
1200             # should output to stdout and not modify the file inplace
1201             p = THIS_DIR / "data" / "simple_cases" / "collections.py"
1202             # Make sure is_file actually returns True
1203             self.assertTrue(p.is_file())
1204             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1205             expected = Path(p)
1206             black.reformat_one(
1207                 path,
1208                 fast=True,
1209                 write_back=black.WriteBack.YES,
1210                 mode=DEFAULT_MODE,
1211                 report=report,
1212             )
1213             fsts.assert_called_once()
1214             # __BLACK_STDIN_FILENAME__ should have been stripped
1215             report.done.assert_called_with(expected, black.Changed.YES)
1216
1217     def test_reformat_one_with_stdin_empty(self) -> None:
1218         output = io.StringIO()
1219         with patch("io.TextIOWrapper", lambda *args, **kwargs: output):
1220             try:
1221                 black.format_stdin_to_stdout(
1222                     fast=True,
1223                     content="",
1224                     write_back=black.WriteBack.YES,
1225                     mode=DEFAULT_MODE,
1226                 )
1227             except io.UnsupportedOperation:
1228                 pass  # StringIO does not support detach
1229             assert output.getvalue() == ""
1230
1231     def test_invalid_cli_regex(self) -> None:
1232         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1233             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1234
1235     def test_required_version_matches_version(self) -> None:
1236         self.invokeBlack(
1237             ["--required-version", black.__version__, "-c", "0"],
1238             exit_code=0,
1239             ignore_config=True,
1240         )
1241
1242     def test_required_version_matches_partial_version(self) -> None:
1243         self.invokeBlack(
1244             ["--required-version", black.__version__.split(".")[0], "-c", "0"],
1245             exit_code=0,
1246             ignore_config=True,
1247         )
1248
1249     def test_required_version_does_not_match_on_minor_version(self) -> None:
1250         self.invokeBlack(
1251             ["--required-version", black.__version__.split(".")[0] + ".999", "-c", "0"],
1252             exit_code=1,
1253             ignore_config=True,
1254         )
1255
1256     def test_required_version_does_not_match_version(self) -> None:
1257         result = BlackRunner().invoke(
1258             black.main,
1259             ["--required-version", "20.99b", "-c", "0"],
1260         )
1261         self.assertEqual(result.exit_code, 1)
1262         self.assertIn("required version", result.stderr)
1263
1264     def test_preserves_line_endings(self) -> None:
1265         with TemporaryDirectory() as workspace:
1266             test_file = Path(workspace) / "test.py"
1267             for nl in ["\n", "\r\n"]:
1268                 contents = nl.join(["def f(  ):", "    pass"])
1269                 test_file.write_bytes(contents.encode())
1270                 ff(test_file, write_back=black.WriteBack.YES)
1271                 updated_contents: bytes = test_file.read_bytes()
1272                 self.assertIn(nl.encode(), updated_contents)
1273                 if nl == "\n":
1274                     self.assertNotIn(b"\r\n", updated_contents)
1275
1276     def test_preserves_line_endings_via_stdin(self) -> None:
1277         for nl in ["\n", "\r\n"]:
1278             contents = nl.join(["def f(  ):", "    pass"])
1279             runner = BlackRunner()
1280             result = runner.invoke(
1281                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1282             )
1283             self.assertEqual(result.exit_code, 0)
1284             output = result.stdout_bytes
1285             self.assertIn(nl.encode("utf8"), output)
1286             if nl == "\n":
1287                 self.assertNotIn(b"\r\n", output)
1288
1289     def test_assert_equivalent_different_asts(self) -> None:
1290         with self.assertRaises(AssertionError):
1291             black.assert_equivalent("{}", "None")
1292
1293     def test_shhh_click(self) -> None:
1294         try:
1295             from click import _unicodefun  # type: ignore
1296         except ImportError:
1297             self.skipTest("Incompatible Click version")
1298
1299         if not hasattr(_unicodefun, "_verify_python_env"):
1300             self.skipTest("Incompatible Click version")
1301
1302         # First, let's see if Click is crashing with a preferred ASCII charset.
1303         with patch("locale.getpreferredencoding") as gpe:
1304             gpe.return_value = "ASCII"
1305             with self.assertRaises(RuntimeError):
1306                 _unicodefun._verify_python_env()
1307         # Now, let's silence Click...
1308         black.patch_click()
1309         # ...and confirm it's silent.
1310         with patch("locale.getpreferredencoding") as gpe:
1311             gpe.return_value = "ASCII"
1312             try:
1313                 _unicodefun._verify_python_env()
1314             except RuntimeError as re:
1315                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1316
1317     def test_root_logger_not_used_directly(self) -> None:
1318         def fail(*args: Any, **kwargs: Any) -> None:
1319             self.fail("Record created with root logger")
1320
1321         with patch.multiple(
1322             logging.root,
1323             debug=fail,
1324             info=fail,
1325             warning=fail,
1326             error=fail,
1327             critical=fail,
1328             log=fail,
1329         ):
1330             ff(THIS_DIR / "util.py")
1331
1332     def test_invalid_config_return_code(self) -> None:
1333         tmp_file = Path(black.dump_to_file())
1334         try:
1335             tmp_config = Path(black.dump_to_file())
1336             tmp_config.unlink()
1337             args = ["--config", str(tmp_config), str(tmp_file)]
1338             self.invokeBlack(args, exit_code=2, ignore_config=False)
1339         finally:
1340             tmp_file.unlink()
1341
1342     def test_parse_pyproject_toml(self) -> None:
1343         test_toml_file = THIS_DIR / "test.toml"
1344         config = black.parse_pyproject_toml(str(test_toml_file))
1345         self.assertEqual(config["verbose"], 1)
1346         self.assertEqual(config["check"], "no")
1347         self.assertEqual(config["diff"], "y")
1348         self.assertEqual(config["color"], True)
1349         self.assertEqual(config["line_length"], 79)
1350         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1351         self.assertEqual(config["python_cell_magics"], ["custom1", "custom2"])
1352         self.assertEqual(config["exclude"], r"\.pyi?$")
1353         self.assertEqual(config["include"], r"\.py?$")
1354
1355     def test_read_pyproject_toml(self) -> None:
1356         test_toml_file = THIS_DIR / "test.toml"
1357         fake_ctx = FakeContext()
1358         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1359         config = fake_ctx.default_map
1360         self.assertEqual(config["verbose"], "1")
1361         self.assertEqual(config["check"], "no")
1362         self.assertEqual(config["diff"], "y")
1363         self.assertEqual(config["color"], "True")
1364         self.assertEqual(config["line_length"], "79")
1365         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1366         self.assertEqual(config["exclude"], r"\.pyi?$")
1367         self.assertEqual(config["include"], r"\.py?$")
1368
1369     @pytest.mark.incompatible_with_mypyc
1370     def test_find_project_root(self) -> None:
1371         with TemporaryDirectory() as workspace:
1372             root = Path(workspace)
1373             test_dir = root / "test"
1374             test_dir.mkdir()
1375
1376             src_dir = root / "src"
1377             src_dir.mkdir()
1378
1379             root_pyproject = root / "pyproject.toml"
1380             root_pyproject.touch()
1381             src_pyproject = src_dir / "pyproject.toml"
1382             src_pyproject.touch()
1383             src_python = src_dir / "foo.py"
1384             src_python.touch()
1385
1386             self.assertEqual(
1387                 black.find_project_root((src_dir, test_dir)),
1388                 (root.resolve(), "pyproject.toml"),
1389             )
1390             self.assertEqual(
1391                 black.find_project_root((src_dir,)),
1392                 (src_dir.resolve(), "pyproject.toml"),
1393             )
1394             self.assertEqual(
1395                 black.find_project_root((src_python,)),
1396                 (src_dir.resolve(), "pyproject.toml"),
1397             )
1398
1399     @patch(
1400         "black.files.find_user_pyproject_toml",
1401     )
1402     def test_find_pyproject_toml(self, find_user_pyproject_toml: MagicMock) -> None:
1403         find_user_pyproject_toml.side_effect = RuntimeError()
1404
1405         with redirect_stderr(io.StringIO()) as stderr:
1406             result = black.files.find_pyproject_toml(
1407                 path_search_start=(str(Path.cwd().root),)
1408             )
1409
1410         assert result is None
1411         err = stderr.getvalue()
1412         assert "Ignoring user configuration" in err
1413
1414     @patch(
1415         "black.files.find_user_pyproject_toml",
1416         black.files.find_user_pyproject_toml.__wrapped__,
1417     )
1418     def test_find_user_pyproject_toml_linux(self) -> None:
1419         if system() == "Windows":
1420             return
1421
1422         # Test if XDG_CONFIG_HOME is checked
1423         with TemporaryDirectory() as workspace:
1424             tmp_user_config = Path(workspace) / "black"
1425             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1426                 self.assertEqual(
1427                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
1428                 )
1429
1430         # Test fallback for XDG_CONFIG_HOME
1431         with patch.dict("os.environ"):
1432             os.environ.pop("XDG_CONFIG_HOME", None)
1433             fallback_user_config = Path("~/.config").expanduser() / "black"
1434             self.assertEqual(
1435                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
1436             )
1437
1438     def test_find_user_pyproject_toml_windows(self) -> None:
1439         if system() != "Windows":
1440             return
1441
1442         user_config_path = Path.home() / ".black"
1443         self.assertEqual(
1444             black.files.find_user_pyproject_toml(), user_config_path.resolve()
1445         )
1446
1447     def test_bpo_33660_workaround(self) -> None:
1448         if system() == "Windows":
1449             return
1450
1451         # https://bugs.python.org/issue33660
1452         root = Path("/")
1453         with change_directory(root):
1454             path = Path("workspace") / "project"
1455             report = black.Report(verbose=True)
1456             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
1457             self.assertEqual(normalized_path, "workspace/project")
1458
1459     def test_normalize_path_ignore_windows_junctions_outside_of_root(self) -> None:
1460         if system() != "Windows":
1461             return
1462
1463         with TemporaryDirectory() as workspace:
1464             root = Path(workspace)
1465             junction_dir = root / "junction"
1466             junction_target_outside_of_root = root / ".."
1467             os.system(f"mklink /J {junction_dir} {junction_target_outside_of_root}")
1468
1469             report = black.Report(verbose=True)
1470             normalized_path = black.normalize_path_maybe_ignore(
1471                 junction_dir, root, report
1472             )
1473             # Manually delete for Python < 3.8
1474             os.system(f"rmdir {junction_dir}")
1475
1476             self.assertEqual(normalized_path, None)
1477
1478     def test_newline_comment_interaction(self) -> None:
1479         source = "class A:\\\r\n# type: ignore\n pass\n"
1480         output = black.format_str(source, mode=DEFAULT_MODE)
1481         black.assert_stable(source, output, mode=DEFAULT_MODE)
1482
1483     def test_bpo_2142_workaround(self) -> None:
1484         # https://bugs.python.org/issue2142
1485
1486         source, _ = read_data("miscellaneous", "missing_final_newline")
1487         # read_data adds a trailing newline
1488         source = source.rstrip()
1489         expected, _ = read_data("miscellaneous", "missing_final_newline.diff")
1490         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
1491         diff_header = re.compile(
1492             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
1493             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
1494         )
1495         try:
1496             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
1497             self.assertEqual(result.exit_code, 0)
1498         finally:
1499             os.unlink(tmp_file)
1500         actual = result.output
1501         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
1502         self.assertEqual(actual, expected)
1503
1504     @staticmethod
1505     def compare_results(
1506         result: click.testing.Result, expected_value: str, expected_exit_code: int
1507     ) -> None:
1508         """Helper method to test the value and exit code of a click Result."""
1509         assert (
1510             result.output == expected_value
1511         ), "The output did not match the expected value."
1512         assert result.exit_code == expected_exit_code, "The exit code is incorrect."
1513
1514     def test_code_option(self) -> None:
1515         """Test the code option with no changes."""
1516         code = 'print("Hello world")\n'
1517         args = ["--code", code]
1518         result = CliRunner().invoke(black.main, args)
1519
1520         self.compare_results(result, code, 0)
1521
1522     def test_code_option_changed(self) -> None:
1523         """Test the code option when changes are required."""
1524         code = "print('hello world')"
1525         formatted = black.format_str(code, mode=DEFAULT_MODE)
1526
1527         args = ["--code", code]
1528         result = CliRunner().invoke(black.main, args)
1529
1530         self.compare_results(result, formatted, 0)
1531
1532     def test_code_option_check(self) -> None:
1533         """Test the code option when check is passed."""
1534         args = ["--check", "--code", 'print("Hello world")\n']
1535         result = CliRunner().invoke(black.main, args)
1536         self.compare_results(result, "", 0)
1537
1538     def test_code_option_check_changed(self) -> None:
1539         """Test the code option when changes are required, and check is passed."""
1540         args = ["--check", "--code", "print('hello world')"]
1541         result = CliRunner().invoke(black.main, args)
1542         self.compare_results(result, "", 1)
1543
1544     def test_code_option_diff(self) -> None:
1545         """Test the code option when diff is passed."""
1546         code = "print('hello world')"
1547         formatted = black.format_str(code, mode=DEFAULT_MODE)
1548         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1549
1550         args = ["--diff", "--code", code]
1551         result = CliRunner().invoke(black.main, args)
1552
1553         # Remove time from diff
1554         output = DIFF_TIME.sub("", result.output)
1555
1556         assert output == result_diff, "The output did not match the expected value."
1557         assert result.exit_code == 0, "The exit code is incorrect."
1558
1559     def test_code_option_color_diff(self) -> None:
1560         """Test the code option when color and diff are passed."""
1561         code = "print('hello world')"
1562         formatted = black.format_str(code, mode=DEFAULT_MODE)
1563
1564         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1565         result_diff = color_diff(result_diff)
1566
1567         args = ["--diff", "--color", "--code", code]
1568         result = CliRunner().invoke(black.main, args)
1569
1570         # Remove time from diff
1571         output = DIFF_TIME.sub("", result.output)
1572
1573         assert output == result_diff, "The output did not match the expected value."
1574         assert result.exit_code == 0, "The exit code is incorrect."
1575
1576     @pytest.mark.incompatible_with_mypyc
1577     def test_code_option_safe(self) -> None:
1578         """Test that the code option throws an error when the sanity checks fail."""
1579         # Patch black.assert_equivalent to ensure the sanity checks fail
1580         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1581             code = 'print("Hello world")'
1582             error_msg = f"{code}\nerror: cannot format <string>: \n"
1583
1584             args = ["--safe", "--code", code]
1585             result = CliRunner().invoke(black.main, args)
1586
1587             self.compare_results(result, error_msg, 123)
1588
1589     def test_code_option_fast(self) -> None:
1590         """Test that the code option ignores errors when the sanity checks fail."""
1591         # Patch black.assert_equivalent to ensure the sanity checks fail
1592         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1593             code = 'print("Hello world")'
1594             formatted = black.format_str(code, mode=DEFAULT_MODE)
1595
1596             args = ["--fast", "--code", code]
1597             result = CliRunner().invoke(black.main, args)
1598
1599             self.compare_results(result, formatted, 0)
1600
1601     @pytest.mark.incompatible_with_mypyc
1602     def test_code_option_config(self) -> None:
1603         """
1604         Test that the code option finds the pyproject.toml in the current directory.
1605         """
1606         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1607             args = ["--code", "print"]
1608             # This is the only directory known to contain a pyproject.toml
1609             with change_directory(PROJECT_ROOT):
1610                 CliRunner().invoke(black.main, args)
1611                 pyproject_path = Path(Path.cwd(), "pyproject.toml").resolve()
1612
1613             assert (
1614                 len(parse.mock_calls) >= 1
1615             ), "Expected config parse to be called with the current directory."
1616
1617             _, call_args, _ = parse.mock_calls[0]
1618             assert (
1619                 call_args[0].lower() == str(pyproject_path).lower()
1620             ), "Incorrect config loaded."
1621
1622     @pytest.mark.incompatible_with_mypyc
1623     def test_code_option_parent_config(self) -> None:
1624         """
1625         Test that the code option finds the pyproject.toml in the parent directory.
1626         """
1627         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1628             with change_directory(THIS_DIR):
1629                 args = ["--code", "print"]
1630                 CliRunner().invoke(black.main, args)
1631
1632                 pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
1633                 assert (
1634                     len(parse.mock_calls) >= 1
1635                 ), "Expected config parse to be called with the current directory."
1636
1637                 _, call_args, _ = parse.mock_calls[0]
1638                 assert (
1639                     call_args[0].lower() == str(pyproject_path).lower()
1640                 ), "Incorrect config loaded."
1641
1642     def test_for_handled_unexpected_eof_error(self) -> None:
1643         """
1644         Test that an unexpected EOF SyntaxError is nicely presented.
1645         """
1646         with pytest.raises(black.parsing.InvalidInput) as exc_info:
1647             black.lib2to3_parse("print(", {})
1648
1649         exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")
1650
1651     def test_equivalency_ast_parse_failure_includes_error(self) -> None:
1652         with pytest.raises(AssertionError) as err:
1653             black.assert_equivalent("a«»a  = 1", "a«»a  = 1")
1654
1655         err.match("--safe")
1656         # Unfortunately the SyntaxError message has changed in newer versions so we
1657         # can't match it directly.
1658         err.match("invalid character")
1659         err.match(r"\(<unknown>, line 1\)")
1660
1661
class TestCaching:
    """Tests for Black's on-disk cache of already-formatted files."""

    def test_get_cache_dir(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """`BLACK_CACHE_DIR`, when set, overrides `user_cache_dir` for the cache."""
        # Create multiple cache directories
        workspace1 = tmp_path / "ws1"
        workspace1.mkdir()
        workspace2 = tmp_path / "ws2"
        workspace2.mkdir()

        # Force user_cache_dir to use the temporary directory for easier assertions
        patch_user_cache_dir = patch(
            target="black.cache.user_cache_dir",
            autospec=True,
            return_value=str(workspace1),
        )

        # If BLACK_CACHE_DIR is not set, use user_cache_dir
        monkeypatch.delenv("BLACK_CACHE_DIR", raising=False)
        with patch_user_cache_dir:
            assert get_cache_dir() == workspace1

        # If it is set, use the path provided in the env var.
        monkeypatch.setenv("BLACK_CACHE_DIR", str(workspace2))
        assert get_cache_dir() == workspace2

    def test_cache_broken_file(self) -> None:
        """A corrupt cache file reads as empty and is repopulated by the next run."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            cache_file = get_cache_file(mode)
            cache_file.write_text("this is not a pickle")
            assert black.read_cache(mode) == {}
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            invokeBlack([str(src)])
            cache = black.read_cache(mode)
            assert str(src) in cache

    def test_cache_single_file_already_cached(self) -> None:
        """A file already recorded in the cache is not reformatted."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            black.write_cache({}, [src], mode)
            invokeBlack([str(src)])
            # Still single-quoted: the cached file was skipped.
            assert src.read_text() == "print('hello')"

    @event_loop()
    def test_cache_multiple_files(self) -> None:
        """Only the uncached file is formatted; afterwards both are cached."""
        mode = DEFAULT_MODE
        # Workers run as threads in this process (presumably so they observe the
        # test's patched cache location).
        with cache_dir() as workspace, patch(
            "concurrent.futures.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            one = (workspace / "one.py").resolve()
            one.write_text("print('hello')")
            two = (workspace / "two.py").resolve()
            two.write_text("print('hello')")
            black.write_cache({}, [one], mode)
            invokeBlack([str(workspace)])
            assert one.read_text() == "print('hello')"
            assert two.read_text() == 'print("hello")\n'
            cache = black.read_cache(mode)
            assert str(one) in cache
            assert str(two) in cache

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    def test_no_cache_when_writeback_diff(self, color: bool) -> None:
        """`--diff` (with or without `--color`) neither reads nor writes the cache."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            with patch("black.read_cache") as read_cache, patch(
                "black.write_cache"
            ) as write_cache:
                cmd = [str(src), "--diff"]
                if color:
                    cmd.append("--color")
                invokeBlack(cmd)
                cache_file = get_cache_file(mode)
                assert not cache_file.exists()
                write_cache.assert_not_called()
                read_cache.assert_not_called()

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    @event_loop()
    def test_output_locking_when_writeback_diff(self, color: bool) -> None:
        """Diffing several files goes through a `Manager` (for the output lock)."""
        with cache_dir() as workspace:
            for tag in range(4):
                src = (workspace / f"test{tag}.py").resolve()
                src.write_text("print('hello')")
            with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
                cmd = ["--diff", str(workspace)]
                if color:
                    cmd.append("--color")
                invokeBlack(cmd, exit_code=0)
                # this isn't quite doing what we want, but if it _isn't_
                # called then we cannot be using the lock it provides
                mgr.assert_called()

    def test_no_cache_when_stdin(self) -> None:
        """Formatting from stdin never creates a cache file."""
        mode = DEFAULT_MODE
        with cache_dir():
            result = CliRunner().invoke(
                black.main, ["-"], input=BytesIO(b"print('hello')")
            )
            assert not result.exit_code
            cache_file = get_cache_file(mode)
            assert not cache_file.exists()

    def test_read_cache_no_cachefile(self) -> None:
        """Reading a cache that was never written yields an empty mapping."""
        mode = DEFAULT_MODE
        with cache_dir():
            assert black.read_cache(mode) == {}

    def test_write_cache_read_cache(self) -> None:
        """A written cache entry round-trips with the file's current cache info."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.touch()
            black.write_cache({}, [src], mode)
            cache = black.read_cache(mode)
            assert str(src) in cache
            assert cache[str(src)] == black.get_cache_info(src)

    def test_filter_cached(self) -> None:
        """`filter_cached` returns (todo, done); stale cache info counts as todo."""
        with TemporaryDirectory() as workspace:
            path = Path(workspace)
            uncached = (path / "uncached").resolve()
            cached = (path / "cached").resolve()
            cached_but_changed = (path / "changed").resolve()
            uncached.touch()
            cached.touch()
            cached_but_changed.touch()
            cache = {
                str(cached): black.get_cache_info(cached),
                # deliberately wrong info, so the file looks modified since caching
                str(cached_but_changed): (0.0, 0),
            }
            todo, done = black.filter_cached(
                cache, {uncached, cached, cached_but_changed}
            )
            assert todo == {uncached, cached_but_changed}
            assert done == {cached}

    def test_write_cache_creates_directory_if_needed(self) -> None:
        """`write_cache` creates a missing cache directory."""
        mode = DEFAULT_MODE
        with cache_dir(exists=False) as workspace:
            assert not workspace.exists()
            black.write_cache({}, [], mode)
            assert workspace.exists()

    @event_loop()
    def test_failed_formatting_does_not_get_cached(self) -> None:
        """Files that fail to format are left out of the cache; clean ones are in."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace, patch(
            "concurrent.futures.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            failing = (workspace / "failing.py").resolve()
            failing.write_text("not actually python")
            clean = (workspace / "clean.py").resolve()
            clean.write_text('print("hello")\n')
            invokeBlack([str(workspace)], exit_code=123)
            cache = black.read_cache(mode)
            assert str(failing) not in cache
            assert str(clean) in cache

    def test_write_cache_write_fail(self) -> None:
        """`write_cache` does not raise when a patched `Path.open` raises OSError."""
        mode = DEFAULT_MODE
        with cache_dir(), patch.object(Path, "open") as mock:
            mock.side_effect = OSError
            black.write_cache({}, [], mode)

    def test_read_cache_line_lengths(self) -> None:
        """Cache entries are per mode: another line length does not see them."""
        mode = DEFAULT_MODE
        short_mode = replace(DEFAULT_MODE, line_length=1)
        with cache_dir() as workspace:
            path = (workspace / "file.py").resolve()
            path.touch()
            black.write_cache({}, [path], mode)
            one = black.read_cache(mode)
            assert str(path) in one
            two = black.read_cache(short_mode)
            assert str(path) not in two
1854
1855
def assert_collected_sources(
    src: Sequence[Union[str, Path]],
    expected: Sequence[Union[str, Path]],
    *,
    ctx: Optional[FakeContext] = None,
    exclude: Optional[str] = None,
    include: Optional[str] = None,
    extend_exclude: Optional[str] = None,
    force_exclude: Optional[str] = None,
    stdin_filename: Optional[str] = None,
) -> None:
    """Assert that ``black.get_sources`` collects exactly ``expected`` from ``src``.

    Each pattern argument is compiled when given; ``None`` means "not set",
    except for ``include``, which then falls back to ``DEFAULT_INCLUDE``. An
    empty string is compiled like any other pattern rather than treated as
    unset. The order of collected sources is ignored.
    """

    def compiled_or_none(pattern: Optional[str]) -> Any:
        # Only None means "unset"; "" must still be compiled.
        return None if pattern is None else compile_pattern(pattern)

    exclude_re = compiled_or_none(exclude)
    include_re = DEFAULT_INCLUDE if include is None else compile_pattern(include)
    extend_exclude_re = compiled_or_none(extend_exclude)
    force_exclude_re = compiled_or_none(force_exclude)

    collected = black.get_sources(
        ctx=ctx or FakeContext(),
        src=tuple(str(Path(item)) for item in src),
        quiet=False,
        verbose=False,
        include=include_re,
        exclude=exclude_re,
        extend_exclude=extend_exclude_re,
        force_exclude=force_exclude_re,
        report=black.Report(),
        stdin_filename=stdin_filename,
    )
    assert sorted(collected) == sorted(Path(item) for item in expected)
1888
1889
class TestFileCollection:
    """Tests for source discovery: include/exclude patterns, .gitignore, stdin."""

    def test_include_exclude(self) -> None:
        """Directory discovery applies both the `include` and `exclude` patterns."""
        path = THIS_DIR / "data" / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        assert_collected_sources(
            src,
            expected,
            include=r"\.pyi?$",
            exclude=r"/exclude/|/\.definitely_exclude/",
        )

    def test_gitignore_used_as_default(self) -> None:
        """With no explicit `exclude`, .gitignore rules act as the default."""
        base = Path(DATA_DIR / "include_exclude_tests")
        expected = [
            base / "b/.definitely_exclude/a.py",
            base / "b/.definitely_exclude/a.pyi",
        ]
        src = [base / "b/"]
        ctx = FakeContext()
        ctx.obj["root"] = base
        assert_collected_sources(src, expected, ctx=ctx, extend_exclude=r"/exclude/")

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_exclude_for_issue_1572(self) -> None:
        # Exclude shouldn't touch files that were explicitly given to Black through the
        # CLI. Exclude is supposed to only apply to the recursive discovery of files.
        # https://github.com/psf/black/issues/1572
        path = DATA_DIR / "include_exclude_tests"
        src = [path / "b/exclude/a.py"]
        expected = [path / "b/exclude/a.py"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    def test_gitignore_exclude(self) -> None:
        """`gen_python_files` skips paths matched by the given gitignore spec."""
        path = THIS_DIR / "data" / "include_exclude_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        report = black.Report()
        gitignore = PathSpec.from_lines(
            "gitwildmatch", ["exclude/", ".definitely_exclude"]
        )
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        this_abs = THIS_DIR.resolve()
        sources: List[Path] = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                gitignore,
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_nested_gitignore(self) -> None:
        """.gitignore files in subdirectories are honoured along with the root one."""
        path = Path(THIS_DIR / "data" / "nested_gitignore_tests")
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        root_gitignore = black.files.get_gitignore(path)
        report = black.Report()
        expected: List[Path] = [
            path / "x.py",
            path / "root/b.py",
            path / "root/c.py",
            path / "root/child/c.py",
        ]
        this_abs = THIS_DIR.resolve()
        sources = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                root_gitignore,
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_invalid_gitignore(self) -> None:
        """An unparsable root .gitignore exits 1 and names the file on stderr."""
        path = THIS_DIR / "data" / "invalid_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_invalid_nested_gitignore(self) -> None:
        """An unparsable nested .gitignore exits 1 and names the file on stderr."""
        path = THIS_DIR / "data" / "invalid_nested_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / "a" / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_empty_include(self) -> None:
        """Empty `include` and `exclude` collect every file, Python or not."""
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.pie",
            path / "b/exclude/a.py",
            path / "b/exclude/a.pyi",
            path / "b/dont_exclude/a.pie",
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
            path / "b/.definitely_exclude/a.pie",
            path / "b/.definitely_exclude/a.py",
            path / "b/.definitely_exclude/a.pyi",
            path / ".gitignore",
            path / "pyproject.toml",
        ]
        # Setting exclude explicitly to an empty string to block .gitignore usage.
        assert_collected_sources(src, expected, include="", exclude="")

    def test_extend_exclude(self) -> None:
        """`extend_exclude` is applied in addition to `exclude`."""
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.py",
            path / "b/dont_exclude/a.py",
        ]
        assert_collected_sources(
            src, expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude"
        )

    @pytest.mark.incompatible_with_mypyc
    def test_symlink_out_of_root_directory(self) -> None:
        """A child resolving outside `root` (e.g. a symlink) must not raise."""
        path = MagicMock()
        root = THIS_DIR.resolve()
        child = MagicMock()
        include = re.compile(black.DEFAULT_INCLUDES)
        exclude = re.compile(black.DEFAULT_EXCLUDES)
        report = black.Report()
        gitignore = PathSpec.from_lines("gitwildmatch", [])
        # `child` should behave like a symlink which resolved path is clearly
        # outside of the `root` directory.
        path.iterdir.return_value = [child]
        child.resolve.return_value = Path("/a/b/c")
        child.as_posix.return_value = "/a/b/c"
        try:
            list(
                black.gen_python_files(
                    path.iterdir(),
                    root,
                    include,
                    exclude,
                    None,
                    None,
                    report,
                    gitignore,
                    verbose=False,
                    quiet=False,
                )
            )
        except ValueError as ve:
            # Name the function actually under test in the failure message.
            pytest.fail(f"`gen_python_files()` failed: {ve}")
        path.iterdir.assert_called_once()
        child.resolve.assert_called_once()

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin(self) -> None:
        """`-` is collected as-is, regardless of include/exclude."""
        src = ["-"]
        expected = ["-"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename(self) -> None:
        """With a stdin filename, `-` is collected under a marked filename."""
        src = ["-"]
        stdin_filename = str(THIS_DIR / "data/collections.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
        # Exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        path = DATA_DIR / "include_exclude_tests"
        src = ["-"]
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
        # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        src = ["-"]
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            extend_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
        # Force exclude should exclude the file when passing it through
        # stdin_filename
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        assert_collected_sources(
            src=["-"],
            expected=[],
            force_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )
2135
2136
# Lines of black's main module (``black.__file__``), read once at import time;
# ``tracefunc`` uses them to show the signature line of each traced call.
# When black is compiled (``black.COMPILED``), decoding ``black.__file__`` as
# UTF-8 may fail. That failure is tolerated only in the compiled case, and the
# name is then bound to an empty list so it always exists for later lookups.
black_source_lines: List[str]
try:
    with open(black.__file__, "r", encoding="utf-8") as _bf:
        black_source_lines = _bf.readlines()
except UnicodeDecodeError:
    if not black.COMPILED:
        raise
    black_source_lines = []
2143
2144
def tracefunc(
    frame: types.FrameType, event: str, arg: Any
) -> Callable[[types.FrameType, str, Any], Any]:
    """Show function calls `from black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    Only "call" events whose code object lives in a file path containing
    ``black/__init__.py`` are printed, as ``<lineno>:<source line>``. The
    output is indented in proportion to the current stack depth. The source
    line is looked up in ``black_source_lines``, skipping any leading
    ``@decorator`` lines so the ``def`` line is shown.

    Args:
        frame: The frame being traced.
        event: The trace event name; only "call" is acted on.
        arg: Event-specific argument (unused).

    Returns:
        ``tracefunc`` itself, so tracing continues into nested calls.
    """
    if event != "call":
        return tracefunc

    # Filter by file *before* touching ``black_source_lines``. Line numbers of
    # other modules (e.g. this test file) can exceed the length of black's
    # source and would raise IndexError, which makes Python drop the tracer.
    # This also avoids calling the costly ``inspect.stack()`` for every
    # unrelated function call. Backslashes are normalized so the check also
    # matches Windows paths.
    filename = frame.f_code.co_filename.replace("\\", "/")
    if "black/__init__.py" not in filename:
        return tracefunc

    # Indent two spaces per frame beyond a fixed offset of 19 frames; a
    # negative count simply yields no indentation (" " * -n == "").
    stack = (len(inspect.stack()) - 19) * 2
    lineno = frame.f_lineno

    # f_lineno is 1-based while the list is 0-based. Walk past decorator lines
    # to reach the signature, never indexing outside the available source.
    func_sig_lineno = lineno - 1
    funcname = ""
    while 0 <= func_sig_lineno < len(black_source_lines):
        funcname = black_source_lines[func_sig_lineno].strip()
        if not funcname.startswith("@"):
            break
        func_sig_lineno += 1

    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc
2150     Register this with `sys.settrace()` in a test you're debugging.
2151     """
2152     if event != "call":
2153         return tracefunc
2154
2155     stack = len(inspect.stack()) - 19
2156     stack *= 2
2157     filename = frame.f_code.co_filename
2158     lineno = frame.f_lineno
2159     func_sig_lineno = lineno - 1
2160     funcname = black_source_lines[func_sig_lineno].strip()
2161     while funcname.startswith("@"):
2162         func_sig_lineno += 1
2163         funcname = black_source_lines[func_sig_lineno].strip()
2164     if "black/__init__.py" in filename:
2165         print(f"{' ' * stack}{lineno}:{funcname}")
2166     return tracefunc