git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Bump furo from 2022.3.4 to 2022.4.7 in /docs (#3003)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import inspect
5 import io
6 import logging
7 import multiprocessing
8 import os
9 import sys
10 import types
11 import unittest
12 from concurrent.futures import ThreadPoolExecutor
13 from contextlib import contextmanager, redirect_stderr
14 from dataclasses import replace
15 from io import BytesIO
16 from pathlib import Path
17 from platform import system
18 from tempfile import TemporaryDirectory
19 from typing import (
20     Any,
21     Callable,
22     Dict,
23     Iterator,
24     List,
25     Optional,
26     Sequence,
27     TypeVar,
28     Union,
29 )
30 from unittest.mock import MagicMock, patch
31
32 import click
33 import pytest
34 import re
35 from click import unstyle
36 from click.testing import CliRunner
37 from pathspec import PathSpec
38
39 import black
40 import black.files
41 from black import Feature, TargetVersion
42 from black import re_compile_maybe_verbose as compile_pattern
43 from black.cache import get_cache_dir, get_cache_file
44 from black.debug import DebugVisitor
45 from black.output import color_diff, diff
46 from black.report import Report
47
48 # Import other test classes
49 from tests.util import (
50     DATA_DIR,
51     DEFAULT_MODE,
52     DETERMINISTIC_HEADER,
53     PROJECT_ROOT,
54     PY36_VERSIONS,
55     THIS_DIR,
56     BlackBaseTestCase,
57     assert_format,
58     change_directory,
59     dump_to_stderr,
60     ff,
61     fs,
62     read_data,
63 )
64
THIS_FILE = Path(__file__)
# Config file passed via --config; presumably empty (per its name) so that no
# project configuration influences the tests.
EMPTY_CONFIG = THIS_DIR.joinpath("data", "empty_pyproject.toml")
# One ``--target-version=<name>`` CLI flag for every entry in PY36_VERSIONS.
PY36_ARGS = [
    "--target-version=" + target.name.lower() for target in PY36_VERSIONS
]
# Black's default exclude/include patterns, compiled with Black's own helper.
DEFAULT_EXCLUDE, DEFAULT_INCLUDE = (
    black.re_compile_maybe_verbose(pattern)
    for pattern in (black.const.DEFAULT_EXCLUDES, black.const.DEFAULT_INCLUDES)
)
# Generic type variables.
T, R = TypeVar("T"), TypeVar("R")

# Match the time output in a diff, but nothing else
DIFF_TIME = re.compile(r"\t[\d\-:+\. ]+")
75
76
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Redirect Black's cache to a temporary directory for the ``with`` block.

    Args:
        exists: when True, the temporary directory itself is used. When False,
            a ``new`` subdirectory of it is used; that subdirectory has not
            been created.

    Yields:
        The path patched in as ``black.cache.CACHE_DIR``.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
85
86
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a brand-new asyncio event loop as the current loop.

    The loop is not yielded; code inside the ``with`` block reaches it through
    asyncio. The loop is closed on exit, even when the block raises. It is not
    uninstalled, so it stays the (closed) current loop afterwards.
    """
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        fresh_loop.close()
97
98
class FakeContext(click.Context):
    """Stand-in ``click.Context`` for calling functions that require one.

    ``click.Context.__init__`` is deliberately not called; only
    ``default_map`` and ``obj`` are populated.
    """

    def __init__(self) -> None:
        # No option defaults are pre-supplied.
        self.default_map: Dict[str, Any] = {}
        # Placeholder project root; most tests never look at it.
        self.obj: Dict[str, Any] = {"root": PROJECT_ROOT}
106
107
class FakeParameter(click.Parameter):
    """Stand-in ``click.Parameter`` for calling functions that require one."""

    def __init__(self) -> None:
        """Build an empty parameter, deliberately skipping click's own setup."""
113
114
class BlackRunner(CliRunner):
    """``CliRunner`` that keeps STDOUT and STDERR apart.

    Tests driving Black through its CLI can then inspect the two streams
    (``stdout_bytes`` / ``stderr_bytes``) independently.
    """

    def __init__(self) -> None:
        # mix_stderr=False stops click from merging stderr into stdout.
        super().__init__(mix_stderr=False)
120
121
def invokeBlack(
    args: List[str], exit_code: int = 0, ignore_config: bool = True
) -> None:
    """Run ``black.main`` in-process and assert its exit status.

    Args:
        args: command-line arguments for Black.
        exit_code: the exit status the run must end with.
        ignore_config: when True, prepend ``--verbose`` and
            ``--config THIS_DIR/empty.toml``. Presumably this keeps any
            surrounding project configuration out of the run.

    Raises:
        AssertionError: if the exit status differs. The message lists the
            arguments, the captured stdout/stderr and any exception.
    """
    if ignore_config:
        args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
    outcome = BlackRunner().invoke(black.main, args, catch_exceptions=False)
    stdout, stderr = outcome.stdout_bytes, outcome.stderr_bytes
    assert stdout is not None
    assert stderr is not None
    details = "\n".join(
        [
            f"Failed with args: {args}",
            f"stdout: {stdout.decode()!r}",
            f"stderr: {stderr.decode()!r}",
            f"exception: {outcome.exception}",
        ]
    )
    assert outcome.exit_code == exit_code, details
138
139
140 class BlackTestCase(BlackBaseTestCase):
    # Expose the module-level helper on the class. staticmethod stops the
    # instance from being bound as the first argument, so
    # ``self.invokeBlack(args, ...)`` passes ``args`` through unchanged.
    invokeBlack = staticmethod(invokeBlack)
142
143     def test_empty_ff(self) -> None:
144         expected = ""
145         tmp_file = Path(black.dump_to_file())
146         try:
147             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
148             with open(tmp_file, encoding="utf8") as f:
149                 actual = f.read()
150         finally:
151             os.unlink(tmp_file)
152         self.assertFormatEqual(expected, actual)
153
154     def test_experimental_string_processing_warns(self) -> None:
155         self.assertWarns(
156             black.mode.Deprecated, black.Mode, experimental_string_processing=True
157         )
158
159     def test_piping(self) -> None:
160         source, expected = read_data("src/black/__init__", data=False)
161         result = BlackRunner().invoke(
162             black.main,
163             [
164                 "-",
165                 "--fast",
166                 f"--line-length={black.DEFAULT_LINE_LENGTH}",
167                 f"--config={EMPTY_CONFIG}",
168             ],
169             input=BytesIO(source.encode("utf8")),
170         )
171         self.assertEqual(result.exit_code, 0)
172         self.assertFormatEqual(expected, result.output)
173         if source != result.output:
174             black.assert_equivalent(source, result.output)
175             black.assert_stable(source, result.output, DEFAULT_MODE)
176
177     def test_piping_diff(self) -> None:
178         diff_header = re.compile(
179             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
180             r"\+\d\d\d\d"
181         )
182         source, _ = read_data("expression.py")
183         expected, _ = read_data("expression.diff")
184         args = [
185             "-",
186             "--fast",
187             f"--line-length={black.DEFAULT_LINE_LENGTH}",
188             "--diff",
189             f"--config={EMPTY_CONFIG}",
190         ]
191         result = BlackRunner().invoke(
192             black.main, args, input=BytesIO(source.encode("utf8"))
193         )
194         self.assertEqual(result.exit_code, 0)
195         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
196         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
197         self.assertEqual(expected, actual)
198
199     def test_piping_diff_with_color(self) -> None:
200         source, _ = read_data("expression.py")
201         args = [
202             "-",
203             "--fast",
204             f"--line-length={black.DEFAULT_LINE_LENGTH}",
205             "--diff",
206             "--color",
207             f"--config={EMPTY_CONFIG}",
208         ]
209         result = BlackRunner().invoke(
210             black.main, args, input=BytesIO(source.encode("utf8"))
211         )
212         actual = result.output
213         # Again, the contents are checked in a different test, so only look for colors.
214         self.assertIn("\033[1m", actual)
215         self.assertIn("\033[36m", actual)
216         self.assertIn("\033[32m", actual)
217         self.assertIn("\033[31m", actual)
218         self.assertIn("\033[0m", actual)
219
220     @patch("black.dump_to_file", dump_to_stderr)
221     def _test_wip(self) -> None:
222         source, expected = read_data("wip")
223         sys.settrace(tracefunc)
224         mode = replace(
225             DEFAULT_MODE,
226             experimental_string_processing=False,
227             target_versions={black.TargetVersion.PY38},
228         )
229         actual = fs(source, mode=mode)
230         sys.settrace(None)
231         self.assertFormatEqual(expected, actual)
232         black.assert_equivalent(source, actual)
233         black.assert_stable(source, actual, black.FileMode())
234
235     def test_pep_572_version_detection(self) -> None:
236         source, _ = read_data("pep_572")
237         root = black.lib2to3_parse(source)
238         features = black.get_features_used(root)
239         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
240         versions = black.detect_target_versions(root)
241         self.assertIn(black.TargetVersion.PY38, versions)
242
243     def test_expression_ff(self) -> None:
244         source, expected = read_data("expression")
245         tmp_file = Path(black.dump_to_file(source))
246         try:
247             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
248             with open(tmp_file, encoding="utf8") as f:
249                 actual = f.read()
250         finally:
251             os.unlink(tmp_file)
252         self.assertFormatEqual(expected, actual)
253         with patch("black.dump_to_file", dump_to_stderr):
254             black.assert_equivalent(source, actual)
255             black.assert_stable(source, actual, DEFAULT_MODE)
256
257     def test_expression_diff(self) -> None:
258         source, _ = read_data("expression.py")
259         expected, _ = read_data("expression.diff")
260         tmp_file = Path(black.dump_to_file(source))
261         diff_header = re.compile(
262             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
263             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
264         )
265         try:
266             result = BlackRunner().invoke(
267                 black.main, ["--diff", str(tmp_file), f"--config={EMPTY_CONFIG}"]
268             )
269             self.assertEqual(result.exit_code, 0)
270         finally:
271             os.unlink(tmp_file)
272         actual = result.output
273         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
274         if expected != actual:
275             dump = black.dump_to_file(actual)
276             msg = (
277                 "Expected diff isn't equal to the actual. If you made changes to"
278                 " expression.py and this is an anticipated difference, overwrite"
279                 f" tests/data/expression.diff with {dump}"
280             )
281             self.assertEqual(expected, actual, msg)
282
283     def test_expression_diff_with_color(self) -> None:
284         source, _ = read_data("expression.py")
285         expected, _ = read_data("expression.diff")
286         tmp_file = Path(black.dump_to_file(source))
287         try:
288             result = BlackRunner().invoke(
289                 black.main,
290                 ["--diff", "--color", str(tmp_file), f"--config={EMPTY_CONFIG}"],
291             )
292         finally:
293             os.unlink(tmp_file)
294         actual = result.output
295         # We check the contents of the diff in `test_expression_diff`. All
296         # we need to check here is that color codes exist in the result.
297         self.assertIn("\033[1m", actual)
298         self.assertIn("\033[36m", actual)
299         self.assertIn("\033[32m", actual)
300         self.assertIn("\033[31m", actual)
301         self.assertIn("\033[0m", actual)
302
303     def test_detect_pos_only_arguments(self) -> None:
304         source, _ = read_data("pep_570")
305         root = black.lib2to3_parse(source)
306         features = black.get_features_used(root)
307         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
308         versions = black.detect_target_versions(root)
309         self.assertIn(black.TargetVersion.PY38, versions)
310
311     @patch("black.dump_to_file", dump_to_stderr)
312     def test_string_quotes(self) -> None:
313         source, expected = read_data("string_quotes")
314         mode = black.Mode(preview=True)
315         assert_format(source, expected, mode)
316         mode = replace(mode, string_normalization=False)
317         not_normalized = fs(source, mode=mode)
318         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
319         black.assert_equivalent(source, not_normalized)
320         black.assert_stable(source, not_normalized, mode=mode)
321
322     def test_skip_magic_trailing_comma(self) -> None:
323         source, _ = read_data("expression.py")
324         expected, _ = read_data("expression_skip_magic_trailing_comma.diff")
325         tmp_file = Path(black.dump_to_file(source))
326         diff_header = re.compile(
327             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
328             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
329         )
330         try:
331             result = BlackRunner().invoke(
332                 black.main, ["-C", "--diff", str(tmp_file), f"--config={EMPTY_CONFIG}"]
333             )
334             self.assertEqual(result.exit_code, 0)
335         finally:
336             os.unlink(tmp_file)
337         actual = result.output
338         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
339         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
340         if expected != actual:
341             dump = black.dump_to_file(actual)
342             msg = (
343                 "Expected diff isn't equal to the actual. If you made changes to"
344                 " expression.py and this is an anticipated difference, overwrite"
345                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
346             )
347             self.assertEqual(expected, actual, msg)
348
349     @patch("black.dump_to_file", dump_to_stderr)
350     def test_async_as_identifier(self) -> None:
351         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
352         source, expected = read_data("async_as_identifier")
353         actual = fs(source)
354         self.assertFormatEqual(expected, actual)
355         major, minor = sys.version_info[:2]
356         if major < 3 or (major <= 3 and minor < 7):
357             black.assert_equivalent(source, actual)
358         black.assert_stable(source, actual, DEFAULT_MODE)
359         # ensure black can parse this when the target is 3.6
360         self.invokeBlack([str(source_path), "--target-version", "py36"])
361         # but not on 3.7, because async/await is no longer an identifier
362         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
363
364     @patch("black.dump_to_file", dump_to_stderr)
365     def test_python37(self) -> None:
366         source_path = (THIS_DIR / "data" / "python37.py").resolve()
367         source, expected = read_data("python37")
368         actual = fs(source)
369         self.assertFormatEqual(expected, actual)
370         major, minor = sys.version_info[:2]
371         if major > 3 or (major == 3 and minor >= 7):
372             black.assert_equivalent(source, actual)
373         black.assert_stable(source, actual, DEFAULT_MODE)
374         # ensure black can parse this when the target is 3.7
375         self.invokeBlack([str(source_path), "--target-version", "py37"])
376         # but not on 3.6, because we use async as a reserved keyword
377         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
378
379     def test_tab_comment_indentation(self) -> None:
380         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
381         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
382         self.assertFormatEqual(contents_spc, fs(contents_spc))
383         self.assertFormatEqual(contents_spc, fs(contents_tab))
384
385         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
386         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
387         self.assertFormatEqual(contents_spc, fs(contents_spc))
388         self.assertFormatEqual(contents_spc, fs(contents_tab))
389
390         # mixed tabs and spaces (valid Python 2 code)
391         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
392         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
393         self.assertFormatEqual(contents_spc, fs(contents_spc))
394         self.assertFormatEqual(contents_spc, fs(contents_tab))
395
396         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
397         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
398         self.assertFormatEqual(contents_spc, fs(contents_spc))
399         self.assertFormatEqual(contents_spc, fs(contents_tab))
400
401     def test_report_verbose(self) -> None:
402         report = Report(verbose=True)
403         out_lines = []
404         err_lines = []
405
406         def out(msg: str, **kwargs: Any) -> None:
407             out_lines.append(msg)
408
409         def err(msg: str, **kwargs: Any) -> None:
410             err_lines.append(msg)
411
412         with patch("black.output._out", out), patch("black.output._err", err):
413             report.done(Path("f1"), black.Changed.NO)
414             self.assertEqual(len(out_lines), 1)
415             self.assertEqual(len(err_lines), 0)
416             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
417             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
418             self.assertEqual(report.return_code, 0)
419             report.done(Path("f2"), black.Changed.YES)
420             self.assertEqual(len(out_lines), 2)
421             self.assertEqual(len(err_lines), 0)
422             self.assertEqual(out_lines[-1], "reformatted f2")
423             self.assertEqual(
424                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
425             )
426             report.done(Path("f3"), black.Changed.CACHED)
427             self.assertEqual(len(out_lines), 3)
428             self.assertEqual(len(err_lines), 0)
429             self.assertEqual(
430                 out_lines[-1], "f3 wasn't modified on disk since last run."
431             )
432             self.assertEqual(
433                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
434             )
435             self.assertEqual(report.return_code, 0)
436             report.check = True
437             self.assertEqual(report.return_code, 1)
438             report.check = False
439             report.failed(Path("e1"), "boom")
440             self.assertEqual(len(out_lines), 3)
441             self.assertEqual(len(err_lines), 1)
442             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
443             self.assertEqual(
444                 unstyle(str(report)),
445                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
446                 " reformat.",
447             )
448             self.assertEqual(report.return_code, 123)
449             report.done(Path("f3"), black.Changed.YES)
450             self.assertEqual(len(out_lines), 4)
451             self.assertEqual(len(err_lines), 1)
452             self.assertEqual(out_lines[-1], "reformatted f3")
453             self.assertEqual(
454                 unstyle(str(report)),
455                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
456                 " reformat.",
457             )
458             self.assertEqual(report.return_code, 123)
459             report.failed(Path("e2"), "boom")
460             self.assertEqual(len(out_lines), 4)
461             self.assertEqual(len(err_lines), 2)
462             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
463             self.assertEqual(
464                 unstyle(str(report)),
465                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
466                 " reformat.",
467             )
468             self.assertEqual(report.return_code, 123)
469             report.path_ignored(Path("wat"), "no match")
470             self.assertEqual(len(out_lines), 5)
471             self.assertEqual(len(err_lines), 2)
472             self.assertEqual(out_lines[-1], "wat ignored: no match")
473             self.assertEqual(
474                 unstyle(str(report)),
475                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
476                 " reformat.",
477             )
478             self.assertEqual(report.return_code, 123)
479             report.done(Path("f4"), black.Changed.NO)
480             self.assertEqual(len(out_lines), 6)
481             self.assertEqual(len(err_lines), 2)
482             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
483             self.assertEqual(
484                 unstyle(str(report)),
485                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
486                 " reformat.",
487             )
488             self.assertEqual(report.return_code, 123)
489             report.check = True
490             self.assertEqual(
491                 unstyle(str(report)),
492                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
493                 " would fail to reformat.",
494             )
495             report.check = False
496             report.diff = True
497             self.assertEqual(
498                 unstyle(str(report)),
499                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
500                 " would fail to reformat.",
501             )
502
503     def test_report_quiet(self) -> None:
504         report = Report(quiet=True)
505         out_lines = []
506         err_lines = []
507
508         def out(msg: str, **kwargs: Any) -> None:
509             out_lines.append(msg)
510
511         def err(msg: str, **kwargs: Any) -> None:
512             err_lines.append(msg)
513
514         with patch("black.output._out", out), patch("black.output._err", err):
515             report.done(Path("f1"), black.Changed.NO)
516             self.assertEqual(len(out_lines), 0)
517             self.assertEqual(len(err_lines), 0)
518             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
519             self.assertEqual(report.return_code, 0)
520             report.done(Path("f2"), black.Changed.YES)
521             self.assertEqual(len(out_lines), 0)
522             self.assertEqual(len(err_lines), 0)
523             self.assertEqual(
524                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
525             )
526             report.done(Path("f3"), black.Changed.CACHED)
527             self.assertEqual(len(out_lines), 0)
528             self.assertEqual(len(err_lines), 0)
529             self.assertEqual(
530                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
531             )
532             self.assertEqual(report.return_code, 0)
533             report.check = True
534             self.assertEqual(report.return_code, 1)
535             report.check = False
536             report.failed(Path("e1"), "boom")
537             self.assertEqual(len(out_lines), 0)
538             self.assertEqual(len(err_lines), 1)
539             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
540             self.assertEqual(
541                 unstyle(str(report)),
542                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
543                 " reformat.",
544             )
545             self.assertEqual(report.return_code, 123)
546             report.done(Path("f3"), black.Changed.YES)
547             self.assertEqual(len(out_lines), 0)
548             self.assertEqual(len(err_lines), 1)
549             self.assertEqual(
550                 unstyle(str(report)),
551                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
552                 " reformat.",
553             )
554             self.assertEqual(report.return_code, 123)
555             report.failed(Path("e2"), "boom")
556             self.assertEqual(len(out_lines), 0)
557             self.assertEqual(len(err_lines), 2)
558             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
559             self.assertEqual(
560                 unstyle(str(report)),
561                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
562                 " reformat.",
563             )
564             self.assertEqual(report.return_code, 123)
565             report.path_ignored(Path("wat"), "no match")
566             self.assertEqual(len(out_lines), 0)
567             self.assertEqual(len(err_lines), 2)
568             self.assertEqual(
569                 unstyle(str(report)),
570                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
571                 " reformat.",
572             )
573             self.assertEqual(report.return_code, 123)
574             report.done(Path("f4"), black.Changed.NO)
575             self.assertEqual(len(out_lines), 0)
576             self.assertEqual(len(err_lines), 2)
577             self.assertEqual(
578                 unstyle(str(report)),
579                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
580                 " reformat.",
581             )
582             self.assertEqual(report.return_code, 123)
583             report.check = True
584             self.assertEqual(
585                 unstyle(str(report)),
586                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
587                 " would fail to reformat.",
588             )
589             report.check = False
590             report.diff = True
591             self.assertEqual(
592                 unstyle(str(report)),
593                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
594                 " would fail to reformat.",
595             )
596
597     def test_report_normal(self) -> None:
598         report = black.Report()
599         out_lines = []
600         err_lines = []
601
602         def out(msg: str, **kwargs: Any) -> None:
603             out_lines.append(msg)
604
605         def err(msg: str, **kwargs: Any) -> None:
606             err_lines.append(msg)
607
608         with patch("black.output._out", out), patch("black.output._err", err):
609             report.done(Path("f1"), black.Changed.NO)
610             self.assertEqual(len(out_lines), 0)
611             self.assertEqual(len(err_lines), 0)
612             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
613             self.assertEqual(report.return_code, 0)
614             report.done(Path("f2"), black.Changed.YES)
615             self.assertEqual(len(out_lines), 1)
616             self.assertEqual(len(err_lines), 0)
617             self.assertEqual(out_lines[-1], "reformatted f2")
618             self.assertEqual(
619                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
620             )
621             report.done(Path("f3"), black.Changed.CACHED)
622             self.assertEqual(len(out_lines), 1)
623             self.assertEqual(len(err_lines), 0)
624             self.assertEqual(out_lines[-1], "reformatted f2")
625             self.assertEqual(
626                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
627             )
628             self.assertEqual(report.return_code, 0)
629             report.check = True
630             self.assertEqual(report.return_code, 1)
631             report.check = False
632             report.failed(Path("e1"), "boom")
633             self.assertEqual(len(out_lines), 1)
634             self.assertEqual(len(err_lines), 1)
635             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
636             self.assertEqual(
637                 unstyle(str(report)),
638                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
639                 " reformat.",
640             )
641             self.assertEqual(report.return_code, 123)
642             report.done(Path("f3"), black.Changed.YES)
643             self.assertEqual(len(out_lines), 2)
644             self.assertEqual(len(err_lines), 1)
645             self.assertEqual(out_lines[-1], "reformatted f3")
646             self.assertEqual(
647                 unstyle(str(report)),
648                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
649                 " reformat.",
650             )
651             self.assertEqual(report.return_code, 123)
652             report.failed(Path("e2"), "boom")
653             self.assertEqual(len(out_lines), 2)
654             self.assertEqual(len(err_lines), 2)
655             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
656             self.assertEqual(
657                 unstyle(str(report)),
658                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
659                 " reformat.",
660             )
661             self.assertEqual(report.return_code, 123)
662             report.path_ignored(Path("wat"), "no match")
663             self.assertEqual(len(out_lines), 2)
664             self.assertEqual(len(err_lines), 2)
665             self.assertEqual(
666                 unstyle(str(report)),
667                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
668                 " reformat.",
669             )
670             self.assertEqual(report.return_code, 123)
671             report.done(Path("f4"), black.Changed.NO)
672             self.assertEqual(len(out_lines), 2)
673             self.assertEqual(len(err_lines), 2)
674             self.assertEqual(
675                 unstyle(str(report)),
676                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
677                 " reformat.",
678             )
679             self.assertEqual(report.return_code, 123)
680             report.check = True
681             self.assertEqual(
682                 unstyle(str(report)),
683                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
684                 " would fail to reformat.",
685             )
686             report.check = False
687             report.diff = True
688             self.assertEqual(
689                 unstyle(str(report)),
690                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
691                 " would fail to reformat.",
692             )
693
694     def test_lib2to3_parse(self) -> None:
695         with self.assertRaises(black.InvalidInput):
696             black.lib2to3_parse("invalid syntax")
697
698         straddling = "x + y"
699         black.lib2to3_parse(straddling)
700         black.lib2to3_parse(straddling, {TargetVersion.PY36})
701
702         py2_only = "print x"
703         with self.assertRaises(black.InvalidInput):
704             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
705
706         py3_only = "exec(x, end=y)"
707         black.lib2to3_parse(py3_only)
708         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
709
710     def test_get_features_used_decorator(self) -> None:
711         # Test the feature detection of new decorator syntax
712         # since this makes some test cases of test_get_features_used()
713         # fails if it fails, this is tested first so that a useful case
714         # is identified
715         simples, relaxed = read_data("decorators")
716         # skip explanation comments at the top of the file
717         for simple_test in simples.split("##")[1:]:
718             node = black.lib2to3_parse(simple_test)
719             decorator = str(node.children[0].children[0]).strip()
720             self.assertNotIn(
721                 Feature.RELAXED_DECORATORS,
722                 black.get_features_used(node),
723                 msg=(
724                     f"decorator '{decorator}' follows python<=3.8 syntax"
725                     "but is detected as 3.9+"
726                     # f"The full node is\n{node!r}"
727                 ),
728             )
729         # skip the '# output' comment at the top of the output part
730         for relaxed_test in relaxed.split("##")[1:]:
731             node = black.lib2to3_parse(relaxed_test)
732             decorator = str(node.children[0].children[0]).strip()
733             self.assertIn(
734                 Feature.RELAXED_DECORATORS,
735                 black.get_features_used(node),
736                 msg=(
737                     f"decorator '{decorator}' uses python3.9+ syntax"
738                     "but is detected as python<=3.8"
739                     # f"The full node is\n{node!r}"
740                 ),
741             )
742
743     def test_get_features_used(self) -> None:
744         node = black.lib2to3_parse("def f(*, arg): ...\n")
745         self.assertEqual(black.get_features_used(node), set())
746         node = black.lib2to3_parse("def f(*, arg,): ...\n")
747         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
748         node = black.lib2to3_parse("f(*arg,)\n")
749         self.assertEqual(
750             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
751         )
752         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
753         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
754         node = black.lib2to3_parse("123_456\n")
755         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
756         node = black.lib2to3_parse("123456\n")
757         self.assertEqual(black.get_features_used(node), set())
758         source, expected = read_data("function")
759         node = black.lib2to3_parse(source)
760         expected_features = {
761             Feature.TRAILING_COMMA_IN_CALL,
762             Feature.TRAILING_COMMA_IN_DEF,
763             Feature.F_STRINGS,
764         }
765         self.assertEqual(black.get_features_used(node), expected_features)
766         node = black.lib2to3_parse(expected)
767         self.assertEqual(black.get_features_used(node), expected_features)
768         source, expected = read_data("expression")
769         node = black.lib2to3_parse(source)
770         self.assertEqual(black.get_features_used(node), set())
771         node = black.lib2to3_parse(expected)
772         self.assertEqual(black.get_features_used(node), set())
773         node = black.lib2to3_parse("lambda a, /, b: ...")
774         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
775         node = black.lib2to3_parse("def fn(a, /, b): ...")
776         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
777         node = black.lib2to3_parse("def fn(): yield a, b")
778         self.assertEqual(black.get_features_used(node), set())
779         node = black.lib2to3_parse("def fn(): return a, b")
780         self.assertEqual(black.get_features_used(node), set())
781         node = black.lib2to3_parse("def fn(): yield *b, c")
782         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
783         node = black.lib2to3_parse("def fn(): return a, *b, c")
784         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
785         node = black.lib2to3_parse("x = a, *b, c")
786         self.assertEqual(black.get_features_used(node), set())
787         node = black.lib2to3_parse("x: Any = regular")
788         self.assertEqual(black.get_features_used(node), set())
789         node = black.lib2to3_parse("x: Any = (regular, regular)")
790         self.assertEqual(black.get_features_used(node), set())
791         node = black.lib2to3_parse("x: Any = Complex(Type(1))[something]")
792         self.assertEqual(black.get_features_used(node), set())
793         node = black.lib2to3_parse("x: Tuple[int, ...] = a, b, c")
794         self.assertEqual(
795             black.get_features_used(node), {Feature.ANN_ASSIGN_EXTENDED_RHS}
796         )
797
798     def test_get_features_used_for_future_flags(self) -> None:
799         for src, features in [
800             ("from __future__ import annotations", {Feature.FUTURE_ANNOTATIONS}),
801             (
802                 "from __future__ import (other, annotations)",
803                 {Feature.FUTURE_ANNOTATIONS},
804             ),
805             ("a = 1 + 2\nfrom something import annotations", set()),
806             ("from __future__ import x, y", set()),
807         ]:
808             with self.subTest(src=src, features=features):
809                 node = black.lib2to3_parse(src)
810                 future_imports = black.get_future_imports(node)
811                 self.assertEqual(
812                     black.get_features_used(node, future_imports=future_imports),
813                     features,
814                 )
815
816     def test_get_future_imports(self) -> None:
817         node = black.lib2to3_parse("\n")
818         self.assertEqual(set(), black.get_future_imports(node))
819         node = black.lib2to3_parse("from __future__ import black\n")
820         self.assertEqual({"black"}, black.get_future_imports(node))
821         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
822         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
823         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
824         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
825         node = black.lib2to3_parse(
826             "from __future__ import multiple\nfrom __future__ import imports\n"
827         )
828         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
829         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
830         self.assertEqual({"black"}, black.get_future_imports(node))
831         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
832         self.assertEqual({"black"}, black.get_future_imports(node))
833         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
834         self.assertEqual(set(), black.get_future_imports(node))
835         node = black.lib2to3_parse("from some.module import black\n")
836         self.assertEqual(set(), black.get_future_imports(node))
837         node = black.lib2to3_parse(
838             "from __future__ import unicode_literals as _unicode_literals"
839         )
840         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
841         node = black.lib2to3_parse(
842             "from __future__ import unicode_literals as _lol, print"
843         )
844         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
845
846     @pytest.mark.incompatible_with_mypyc
847     def test_debug_visitor(self) -> None:
848         source, _ = read_data("debug_visitor.py")
849         expected, _ = read_data("debug_visitor.out")
850         out_lines = []
851         err_lines = []
852
853         def out(msg: str, **kwargs: Any) -> None:
854             out_lines.append(msg)
855
856         def err(msg: str, **kwargs: Any) -> None:
857             err_lines.append(msg)
858
859         with patch("black.debug.out", out):
860             DebugVisitor.show(source)
861         actual = "\n".join(out_lines) + "\n"
862         log_name = ""
863         if expected != actual:
864             log_name = black.dump_to_file(*out_lines)
865         self.assertEqual(
866             expected,
867             actual,
868             f"AST print out is different. Actual version dumped to {log_name}",
869         )
870
871     def test_format_file_contents(self) -> None:
872         empty = ""
873         mode = DEFAULT_MODE
874         with self.assertRaises(black.NothingChanged):
875             black.format_file_contents(empty, mode=mode, fast=False)
876         just_nl = "\n"
877         with self.assertRaises(black.NothingChanged):
878             black.format_file_contents(just_nl, mode=mode, fast=False)
879         same = "j = [1, 2, 3]\n"
880         with self.assertRaises(black.NothingChanged):
881             black.format_file_contents(same, mode=mode, fast=False)
882         different = "j = [1,2,3]"
883         expected = same
884         actual = black.format_file_contents(different, mode=mode, fast=False)
885         self.assertEqual(expected, actual)
886         invalid = "return if you can"
887         with self.assertRaises(black.InvalidInput) as e:
888             black.format_file_contents(invalid, mode=mode, fast=False)
889         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
890
891     def test_endmarker(self) -> None:
892         n = black.lib2to3_parse("\n")
893         self.assertEqual(n.type, black.syms.file_input)
894         self.assertEqual(len(n.children), 1)
895         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
896
897     @pytest.mark.incompatible_with_mypyc
898     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
899     def test_assertFormatEqual(self) -> None:
900         out_lines = []
901         err_lines = []
902
903         def out(msg: str, **kwargs: Any) -> None:
904             out_lines.append(msg)
905
906         def err(msg: str, **kwargs: Any) -> None:
907             err_lines.append(msg)
908
909         with patch("black.output._out", out), patch("black.output._err", err):
910             with self.assertRaises(AssertionError):
911                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
912
913         out_str = "".join(out_lines)
914         self.assertIn("Expected tree:", out_str)
915         self.assertIn("Actual tree:", out_str)
916         self.assertEqual("".join(err_lines), "")
917
918     @event_loop()
919     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
920     def test_works_in_mono_process_only_environment(self) -> None:
921         with cache_dir() as workspace:
922             for f in [
923                 (workspace / "one.py").resolve(),
924                 (workspace / "two.py").resolve(),
925             ]:
926                 f.write_text('print("hello")\n')
927             self.invokeBlack([str(workspace)])
928
929     @event_loop()
930     def test_check_diff_use_together(self) -> None:
931         with cache_dir():
932             # Files which will be reformatted.
933             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
934             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
935             # Files which will not be reformatted.
936             src2 = (THIS_DIR / "data" / "composition.py").resolve()
937             self.invokeBlack([str(src2), "--diff", "--check"])
938             # Multi file command.
939             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
940
941     def test_no_src_fails(self) -> None:
942         with cache_dir():
943             self.invokeBlack([], exit_code=1)
944
945     def test_src_and_code_fails(self) -> None:
946         with cache_dir():
947             self.invokeBlack([".", "-c", "0"], exit_code=1)
948
949     def test_broken_symlink(self) -> None:
950         with cache_dir() as workspace:
951             symlink = workspace / "broken_link.py"
952             try:
953                 symlink.symlink_to("nonexistent.py")
954             except (OSError, NotImplementedError) as e:
955                 self.skipTest(f"Can't create symlinks: {e}")
956             self.invokeBlack([str(workspace.resolve())])
957
958     def test_single_file_force_pyi(self) -> None:
959         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
960         contents, expected = read_data("force_pyi")
961         with cache_dir() as workspace:
962             path = (workspace / "file.py").resolve()
963             with open(path, "w") as fh:
964                 fh.write(contents)
965             self.invokeBlack([str(path), "--pyi"])
966             with open(path, "r") as fh:
967                 actual = fh.read()
968             # verify cache with --pyi is separate
969             pyi_cache = black.read_cache(pyi_mode)
970             self.assertIn(str(path), pyi_cache)
971             normal_cache = black.read_cache(DEFAULT_MODE)
972             self.assertNotIn(str(path), normal_cache)
973         self.assertFormatEqual(expected, actual)
974         black.assert_equivalent(contents, actual)
975         black.assert_stable(contents, actual, pyi_mode)
976
977     @event_loop()
978     def test_multi_file_force_pyi(self) -> None:
979         reg_mode = DEFAULT_MODE
980         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
981         contents, expected = read_data("force_pyi")
982         with cache_dir() as workspace:
983             paths = [
984                 (workspace / "file1.py").resolve(),
985                 (workspace / "file2.py").resolve(),
986             ]
987             for path in paths:
988                 with open(path, "w") as fh:
989                     fh.write(contents)
990             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
991             for path in paths:
992                 with open(path, "r") as fh:
993                     actual = fh.read()
994                 self.assertEqual(actual, expected)
995             # verify cache with --pyi is separate
996             pyi_cache = black.read_cache(pyi_mode)
997             normal_cache = black.read_cache(reg_mode)
998             for path in paths:
999                 self.assertIn(str(path), pyi_cache)
1000                 self.assertNotIn(str(path), normal_cache)
1001
1002     def test_pipe_force_pyi(self) -> None:
1003         source, expected = read_data("force_pyi")
1004         result = CliRunner().invoke(
1005             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1006         )
1007         self.assertEqual(result.exit_code, 0)
1008         actual = result.output
1009         self.assertFormatEqual(actual, expected)
1010
1011     def test_single_file_force_py36(self) -> None:
1012         reg_mode = DEFAULT_MODE
1013         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1014         source, expected = read_data("force_py36")
1015         with cache_dir() as workspace:
1016             path = (workspace / "file.py").resolve()
1017             with open(path, "w") as fh:
1018                 fh.write(source)
1019             self.invokeBlack([str(path), *PY36_ARGS])
1020             with open(path, "r") as fh:
1021                 actual = fh.read()
1022             # verify cache with --target-version is separate
1023             py36_cache = black.read_cache(py36_mode)
1024             self.assertIn(str(path), py36_cache)
1025             normal_cache = black.read_cache(reg_mode)
1026             self.assertNotIn(str(path), normal_cache)
1027         self.assertEqual(actual, expected)
1028
1029     @event_loop()
1030     def test_multi_file_force_py36(self) -> None:
1031         reg_mode = DEFAULT_MODE
1032         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1033         source, expected = read_data("force_py36")
1034         with cache_dir() as workspace:
1035             paths = [
1036                 (workspace / "file1.py").resolve(),
1037                 (workspace / "file2.py").resolve(),
1038             ]
1039             for path in paths:
1040                 with open(path, "w") as fh:
1041                     fh.write(source)
1042             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1043             for path in paths:
1044                 with open(path, "r") as fh:
1045                     actual = fh.read()
1046                 self.assertEqual(actual, expected)
1047             # verify cache with --target-version is separate
1048             pyi_cache = black.read_cache(py36_mode)
1049             normal_cache = black.read_cache(reg_mode)
1050             for path in paths:
1051                 self.assertIn(str(path), pyi_cache)
1052                 self.assertNotIn(str(path), normal_cache)
1053
1054     def test_pipe_force_py36(self) -> None:
1055         source, expected = read_data("force_py36")
1056         result = CliRunner().invoke(
1057             black.main,
1058             ["-", "-q", "--target-version=py36"],
1059             input=BytesIO(source.encode("utf8")),
1060         )
1061         self.assertEqual(result.exit_code, 0)
1062         actual = result.output
1063         self.assertFormatEqual(actual, expected)
1064
1065     @pytest.mark.incompatible_with_mypyc
1066     def test_reformat_one_with_stdin(self) -> None:
1067         with patch(
1068             "black.format_stdin_to_stdout",
1069             return_value=lambda *args, **kwargs: black.Changed.YES,
1070         ) as fsts:
1071             report = MagicMock()
1072             path = Path("-")
1073             black.reformat_one(
1074                 path,
1075                 fast=True,
1076                 write_back=black.WriteBack.YES,
1077                 mode=DEFAULT_MODE,
1078                 report=report,
1079             )
1080             fsts.assert_called_once()
1081             report.done.assert_called_with(path, black.Changed.YES)
1082
1083     @pytest.mark.incompatible_with_mypyc
1084     def test_reformat_one_with_stdin_filename(self) -> None:
1085         with patch(
1086             "black.format_stdin_to_stdout",
1087             return_value=lambda *args, **kwargs: black.Changed.YES,
1088         ) as fsts:
1089             report = MagicMock()
1090             p = "foo.py"
1091             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1092             expected = Path(p)
1093             black.reformat_one(
1094                 path,
1095                 fast=True,
1096                 write_back=black.WriteBack.YES,
1097                 mode=DEFAULT_MODE,
1098                 report=report,
1099             )
1100             fsts.assert_called_once_with(
1101                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1102             )
1103             # __BLACK_STDIN_FILENAME__ should have been stripped
1104             report.done.assert_called_with(expected, black.Changed.YES)
1105
1106     @pytest.mark.incompatible_with_mypyc
1107     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1108         with patch(
1109             "black.format_stdin_to_stdout",
1110             return_value=lambda *args, **kwargs: black.Changed.YES,
1111         ) as fsts:
1112             report = MagicMock()
1113             p = "foo.pyi"
1114             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1115             expected = Path(p)
1116             black.reformat_one(
1117                 path,
1118                 fast=True,
1119                 write_back=black.WriteBack.YES,
1120                 mode=DEFAULT_MODE,
1121                 report=report,
1122             )
1123             fsts.assert_called_once_with(
1124                 fast=True,
1125                 write_back=black.WriteBack.YES,
1126                 mode=replace(DEFAULT_MODE, is_pyi=True),
1127             )
1128             # __BLACK_STDIN_FILENAME__ should have been stripped
1129             report.done.assert_called_with(expected, black.Changed.YES)
1130
1131     @pytest.mark.incompatible_with_mypyc
1132     def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
1133         with patch(
1134             "black.format_stdin_to_stdout",
1135             return_value=lambda *args, **kwargs: black.Changed.YES,
1136         ) as fsts:
1137             report = MagicMock()
1138             p = "foo.ipynb"
1139             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1140             expected = Path(p)
1141             black.reformat_one(
1142                 path,
1143                 fast=True,
1144                 write_back=black.WriteBack.YES,
1145                 mode=DEFAULT_MODE,
1146                 report=report,
1147             )
1148             fsts.assert_called_once_with(
1149                 fast=True,
1150                 write_back=black.WriteBack.YES,
1151                 mode=replace(DEFAULT_MODE, is_ipynb=True),
1152             )
1153             # __BLACK_STDIN_FILENAME__ should have been stripped
1154             report.done.assert_called_with(expected, black.Changed.YES)
1155
1156     @pytest.mark.incompatible_with_mypyc
1157     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1158         with patch(
1159             "black.format_stdin_to_stdout",
1160             return_value=lambda *args, **kwargs: black.Changed.YES,
1161         ) as fsts:
1162             report = MagicMock()
1163             # Even with an existing file, since we are forcing stdin, black
1164             # should output to stdout and not modify the file inplace
1165             p = Path(str(THIS_DIR / "data/collections.py"))
1166             # Make sure is_file actually returns True
1167             self.assertTrue(p.is_file())
1168             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1169             expected = Path(p)
1170             black.reformat_one(
1171                 path,
1172                 fast=True,
1173                 write_back=black.WriteBack.YES,
1174                 mode=DEFAULT_MODE,
1175                 report=report,
1176             )
1177             fsts.assert_called_once()
1178             # __BLACK_STDIN_FILENAME__ should have been stripped
1179             report.done.assert_called_with(expected, black.Changed.YES)
1180
1181     def test_reformat_one_with_stdin_empty(self) -> None:
1182         output = io.StringIO()
1183         with patch("io.TextIOWrapper", lambda *args, **kwargs: output):
1184             try:
1185                 black.format_stdin_to_stdout(
1186                     fast=True,
1187                     content="",
1188                     write_back=black.WriteBack.YES,
1189                     mode=DEFAULT_MODE,
1190                 )
1191             except io.UnsupportedOperation:
1192                 pass  # StringIO does not support detach
1193             assert output.getvalue() == ""
1194
1195     def test_invalid_cli_regex(self) -> None:
1196         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1197             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1198
1199     def test_required_version_matches_version(self) -> None:
1200         self.invokeBlack(
1201             ["--required-version", black.__version__, "-c", "0"],
1202             exit_code=0,
1203             ignore_config=True,
1204         )
1205
1206     def test_required_version_matches_partial_version(self) -> None:
1207         self.invokeBlack(
1208             ["--required-version", black.__version__.split(".")[0], "-c", "0"],
1209             exit_code=0,
1210             ignore_config=True,
1211         )
1212
1213     def test_required_version_does_not_match_on_minor_version(self) -> None:
1214         self.invokeBlack(
1215             ["--required-version", black.__version__.split(".")[0] + ".999", "-c", "0"],
1216             exit_code=1,
1217             ignore_config=True,
1218         )
1219
1220     def test_required_version_does_not_match_version(self) -> None:
1221         result = BlackRunner().invoke(
1222             black.main,
1223             ["--required-version", "20.99b", "-c", "0"],
1224         )
1225         self.assertEqual(result.exit_code, 1)
1226         self.assertIn("required version", result.stderr)
1227
1228     def test_preserves_line_endings(self) -> None:
1229         with TemporaryDirectory() as workspace:
1230             test_file = Path(workspace) / "test.py"
1231             for nl in ["\n", "\r\n"]:
1232                 contents = nl.join(["def f(  ):", "    pass"])
1233                 test_file.write_bytes(contents.encode())
1234                 ff(test_file, write_back=black.WriteBack.YES)
1235                 updated_contents: bytes = test_file.read_bytes()
1236                 self.assertIn(nl.encode(), updated_contents)
1237                 if nl == "\n":
1238                     self.assertNotIn(b"\r\n", updated_contents)
1239
1240     def test_preserves_line_endings_via_stdin(self) -> None:
1241         for nl in ["\n", "\r\n"]:
1242             contents = nl.join(["def f(  ):", "    pass"])
1243             runner = BlackRunner()
1244             result = runner.invoke(
1245                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1246             )
1247             self.assertEqual(result.exit_code, 0)
1248             output = result.stdout_bytes
1249             self.assertIn(nl.encode("utf8"), output)
1250             if nl == "\n":
1251                 self.assertNotIn(b"\r\n", output)
1252
1253     def test_assert_equivalent_different_asts(self) -> None:
1254         with self.assertRaises(AssertionError):
1255             black.assert_equivalent("{}", "None")
1256
1257     def test_shhh_click(self) -> None:
1258         try:
1259             from click import _unicodefun  # type: ignore
1260         except ImportError:
1261             self.skipTest("Incompatible Click version")
1262
1263         if not hasattr(_unicodefun, "_verify_python_env"):
1264             self.skipTest("Incompatible Click version")
1265
1266         # First, let's see if Click is crashing with a preferred ASCII charset.
1267         with patch("locale.getpreferredencoding") as gpe:
1268             gpe.return_value = "ASCII"
1269             with self.assertRaises(RuntimeError):
1270                 _unicodefun._verify_python_env()
1271         # Now, let's silence Click...
1272         black.patch_click()
1273         # ...and confirm it's silent.
1274         with patch("locale.getpreferredencoding") as gpe:
1275             gpe.return_value = "ASCII"
1276             try:
1277                 _unicodefun._verify_python_env()
1278             except RuntimeError as re:
1279                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1280
1281     def test_root_logger_not_used_directly(self) -> None:
1282         def fail(*args: Any, **kwargs: Any) -> None:
1283             self.fail("Record created with root logger")
1284
1285         with patch.multiple(
1286             logging.root,
1287             debug=fail,
1288             info=fail,
1289             warning=fail,
1290             error=fail,
1291             critical=fail,
1292             log=fail,
1293         ):
1294             ff(THIS_DIR / "util.py")
1295
1296     def test_invalid_config_return_code(self) -> None:
1297         tmp_file = Path(black.dump_to_file())
1298         try:
1299             tmp_config = Path(black.dump_to_file())
1300             tmp_config.unlink()
1301             args = ["--config", str(tmp_config), str(tmp_file)]
1302             self.invokeBlack(args, exit_code=2, ignore_config=False)
1303         finally:
1304             tmp_file.unlink()
1305
1306     def test_parse_pyproject_toml(self) -> None:
1307         test_toml_file = THIS_DIR / "test.toml"
1308         config = black.parse_pyproject_toml(str(test_toml_file))
1309         self.assertEqual(config["verbose"], 1)
1310         self.assertEqual(config["check"], "no")
1311         self.assertEqual(config["diff"], "y")
1312         self.assertEqual(config["color"], True)
1313         self.assertEqual(config["line_length"], 79)
1314         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1315         self.assertEqual(config["python_cell_magics"], ["custom1", "custom2"])
1316         self.assertEqual(config["exclude"], r"\.pyi?$")
1317         self.assertEqual(config["include"], r"\.py?$")
1318
1319     def test_read_pyproject_toml(self) -> None:
1320         test_toml_file = THIS_DIR / "test.toml"
1321         fake_ctx = FakeContext()
1322         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1323         config = fake_ctx.default_map
1324         self.assertEqual(config["verbose"], "1")
1325         self.assertEqual(config["check"], "no")
1326         self.assertEqual(config["diff"], "y")
1327         self.assertEqual(config["color"], "True")
1328         self.assertEqual(config["line_length"], "79")
1329         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1330         self.assertEqual(config["exclude"], r"\.pyi?$")
1331         self.assertEqual(config["include"], r"\.py?$")
1332
1333     @pytest.mark.incompatible_with_mypyc
1334     def test_find_project_root(self) -> None:
1335         with TemporaryDirectory() as workspace:
1336             root = Path(workspace)
1337             test_dir = root / "test"
1338             test_dir.mkdir()
1339
1340             src_dir = root / "src"
1341             src_dir.mkdir()
1342
1343             root_pyproject = root / "pyproject.toml"
1344             root_pyproject.touch()
1345             src_pyproject = src_dir / "pyproject.toml"
1346             src_pyproject.touch()
1347             src_python = src_dir / "foo.py"
1348             src_python.touch()
1349
1350             self.assertEqual(
1351                 black.find_project_root((src_dir, test_dir)),
1352                 (root.resolve(), "pyproject.toml"),
1353             )
1354             self.assertEqual(
1355                 black.find_project_root((src_dir,)),
1356                 (src_dir.resolve(), "pyproject.toml"),
1357             )
1358             self.assertEqual(
1359                 black.find_project_root((src_python,)),
1360                 (src_dir.resolve(), "pyproject.toml"),
1361             )
1362
1363     @patch(
1364         "black.files.find_user_pyproject_toml",
1365     )
1366     def test_find_pyproject_toml(self, find_user_pyproject_toml: MagicMock) -> None:
1367         find_user_pyproject_toml.side_effect = RuntimeError()
1368
1369         with redirect_stderr(io.StringIO()) as stderr:
1370             result = black.files.find_pyproject_toml(
1371                 path_search_start=(str(Path.cwd().root),)
1372             )
1373
1374         assert result is None
1375         err = stderr.getvalue()
1376         assert "Ignoring user configuration" in err
1377
1378     @patch(
1379         "black.files.find_user_pyproject_toml",
1380         black.files.find_user_pyproject_toml.__wrapped__,
1381     )
1382     def test_find_user_pyproject_toml_linux(self) -> None:
1383         if system() == "Windows":
1384             return
1385
1386         # Test if XDG_CONFIG_HOME is checked
1387         with TemporaryDirectory() as workspace:
1388             tmp_user_config = Path(workspace) / "black"
1389             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1390                 self.assertEqual(
1391                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
1392                 )
1393
1394         # Test fallback for XDG_CONFIG_HOME
1395         with patch.dict("os.environ"):
1396             os.environ.pop("XDG_CONFIG_HOME", None)
1397             fallback_user_config = Path("~/.config").expanduser() / "black"
1398             self.assertEqual(
1399                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
1400             )
1401
1402     def test_find_user_pyproject_toml_windows(self) -> None:
1403         if system() != "Windows":
1404             return
1405
1406         user_config_path = Path.home() / ".black"
1407         self.assertEqual(
1408             black.files.find_user_pyproject_toml(), user_config_path.resolve()
1409         )
1410
1411     def test_bpo_33660_workaround(self) -> None:
1412         if system() == "Windows":
1413             return
1414
1415         # https://bugs.python.org/issue33660
1416         root = Path("/")
1417         with change_directory(root):
1418             path = Path("workspace") / "project"
1419             report = black.Report(verbose=True)
1420             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
1421             self.assertEqual(normalized_path, "workspace/project")
1422
1423     def test_normalize_path_ignore_windows_junctions_outside_of_root(self) -> None:
1424         if system() != "Windows":
1425             return
1426
1427         with TemporaryDirectory() as workspace:
1428             root = Path(workspace)
1429             junction_dir = root / "junction"
1430             junction_target_outside_of_root = root / ".."
1431             os.system(f"mklink /J {junction_dir} {junction_target_outside_of_root}")
1432
1433             report = black.Report(verbose=True)
1434             normalized_path = black.normalize_path_maybe_ignore(
1435                 junction_dir, root, report
1436             )
1437             # Manually delete for Python < 3.8
1438             os.system(f"rmdir {junction_dir}")
1439
1440             self.assertEqual(normalized_path, None)
1441
1442     def test_newline_comment_interaction(self) -> None:
1443         source = "class A:\\\r\n# type: ignore\n pass\n"
1444         output = black.format_str(source, mode=DEFAULT_MODE)
1445         black.assert_stable(source, output, mode=DEFAULT_MODE)
1446
1447     def test_bpo_2142_workaround(self) -> None:
1448
1449         # https://bugs.python.org/issue2142
1450
1451         source, _ = read_data("missing_final_newline.py")
1452         # read_data adds a trailing newline
1453         source = source.rstrip()
1454         expected, _ = read_data("missing_final_newline.diff")
1455         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
1456         diff_header = re.compile(
1457             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
1458             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
1459         )
1460         try:
1461             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
1462             self.assertEqual(result.exit_code, 0)
1463         finally:
1464             os.unlink(tmp_file)
1465         actual = result.output
1466         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
1467         self.assertEqual(actual, expected)
1468
1469     @staticmethod
1470     def compare_results(
1471         result: click.testing.Result, expected_value: str, expected_exit_code: int
1472     ) -> None:
1473         """Helper method to test the value and exit code of a click Result."""
1474         assert (
1475             result.output == expected_value
1476         ), "The output did not match the expected value."
1477         assert result.exit_code == expected_exit_code, "The exit code is incorrect."
1478
1479     def test_code_option(self) -> None:
1480         """Test the code option with no changes."""
1481         code = 'print("Hello world")\n'
1482         args = ["--code", code]
1483         result = CliRunner().invoke(black.main, args)
1484
1485         self.compare_results(result, code, 0)
1486
1487     def test_code_option_changed(self) -> None:
1488         """Test the code option when changes are required."""
1489         code = "print('hello world')"
1490         formatted = black.format_str(code, mode=DEFAULT_MODE)
1491
1492         args = ["--code", code]
1493         result = CliRunner().invoke(black.main, args)
1494
1495         self.compare_results(result, formatted, 0)
1496
1497     def test_code_option_check(self) -> None:
1498         """Test the code option when check is passed."""
1499         args = ["--check", "--code", 'print("Hello world")\n']
1500         result = CliRunner().invoke(black.main, args)
1501         self.compare_results(result, "", 0)
1502
1503     def test_code_option_check_changed(self) -> None:
1504         """Test the code option when changes are required, and check is passed."""
1505         args = ["--check", "--code", "print('hello world')"]
1506         result = CliRunner().invoke(black.main, args)
1507         self.compare_results(result, "", 1)
1508
1509     def test_code_option_diff(self) -> None:
1510         """Test the code option when diff is passed."""
1511         code = "print('hello world')"
1512         formatted = black.format_str(code, mode=DEFAULT_MODE)
1513         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1514
1515         args = ["--diff", "--code", code]
1516         result = CliRunner().invoke(black.main, args)
1517
1518         # Remove time from diff
1519         output = DIFF_TIME.sub("", result.output)
1520
1521         assert output == result_diff, "The output did not match the expected value."
1522         assert result.exit_code == 0, "The exit code is incorrect."
1523
1524     def test_code_option_color_diff(self) -> None:
1525         """Test the code option when color and diff are passed."""
1526         code = "print('hello world')"
1527         formatted = black.format_str(code, mode=DEFAULT_MODE)
1528
1529         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1530         result_diff = color_diff(result_diff)
1531
1532         args = ["--diff", "--color", "--code", code]
1533         result = CliRunner().invoke(black.main, args)
1534
1535         # Remove time from diff
1536         output = DIFF_TIME.sub("", result.output)
1537
1538         assert output == result_diff, "The output did not match the expected value."
1539         assert result.exit_code == 0, "The exit code is incorrect."
1540
1541     @pytest.mark.incompatible_with_mypyc
1542     def test_code_option_safe(self) -> None:
1543         """Test that the code option throws an error when the sanity checks fail."""
1544         # Patch black.assert_equivalent to ensure the sanity checks fail
1545         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1546             code = 'print("Hello world")'
1547             error_msg = f"{code}\nerror: cannot format <string>: \n"
1548
1549             args = ["--safe", "--code", code]
1550             result = CliRunner().invoke(black.main, args)
1551
1552             self.compare_results(result, error_msg, 123)
1553
1554     def test_code_option_fast(self) -> None:
1555         """Test that the code option ignores errors when the sanity checks fail."""
1556         # Patch black.assert_equivalent to ensure the sanity checks fail
1557         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1558             code = 'print("Hello world")'
1559             formatted = black.format_str(code, mode=DEFAULT_MODE)
1560
1561             args = ["--fast", "--code", code]
1562             result = CliRunner().invoke(black.main, args)
1563
1564             self.compare_results(result, formatted, 0)
1565
1566     @pytest.mark.incompatible_with_mypyc
1567     def test_code_option_config(self) -> None:
1568         """
1569         Test that the code option finds the pyproject.toml in the current directory.
1570         """
1571         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1572             args = ["--code", "print"]
1573             # This is the only directory known to contain a pyproject.toml
1574             with change_directory(PROJECT_ROOT):
1575                 CliRunner().invoke(black.main, args)
1576                 pyproject_path = Path(Path.cwd(), "pyproject.toml").resolve()
1577
1578             assert (
1579                 len(parse.mock_calls) >= 1
1580             ), "Expected config parse to be called with the current directory."
1581
1582             _, call_args, _ = parse.mock_calls[0]
1583             assert (
1584                 call_args[0].lower() == str(pyproject_path).lower()
1585             ), "Incorrect config loaded."
1586
1587     @pytest.mark.incompatible_with_mypyc
1588     def test_code_option_parent_config(self) -> None:
1589         """
1590         Test that the code option finds the pyproject.toml in the parent directory.
1591         """
1592         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1593             with change_directory(THIS_DIR):
1594                 args = ["--code", "print"]
1595                 CliRunner().invoke(black.main, args)
1596
1597                 pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
1598                 assert (
1599                     len(parse.mock_calls) >= 1
1600                 ), "Expected config parse to be called with the current directory."
1601
1602                 _, call_args, _ = parse.mock_calls[0]
1603                 assert (
1604                     call_args[0].lower() == str(pyproject_path).lower()
1605                 ), "Incorrect config loaded."
1606
1607     def test_for_handled_unexpected_eof_error(self) -> None:
1608         """
1609         Test that an unexpected EOF SyntaxError is nicely presented.
1610         """
1611         with pytest.raises(black.parsing.InvalidInput) as exc_info:
1612             black.lib2to3_parse("print(", {})
1613
1614         exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")
1615
1616     def test_equivalency_ast_parse_failure_includes_error(self) -> None:
1617         with pytest.raises(AssertionError) as err:
1618             black.assert_equivalent("a«»a  = 1", "a«»a  = 1")
1619
1620         err.match("--safe")
1621         # Unfortunately the SyntaxError message has changed in newer versions so we
1622         # can't match it directly.
1623         err.match("invalid character")
1624         err.match(r"\(<unknown>, line 1\)")
1625
1626
1627 class TestCaching:
1628     def test_get_cache_dir(
1629         self,
1630         tmp_path: Path,
1631         monkeypatch: pytest.MonkeyPatch,
1632     ) -> None:
1633         # Create multiple cache directories
1634         workspace1 = tmp_path / "ws1"
1635         workspace1.mkdir()
1636         workspace2 = tmp_path / "ws2"
1637         workspace2.mkdir()
1638
1639         # Force user_cache_dir to use the temporary directory for easier assertions
1640         patch_user_cache_dir = patch(
1641             target="black.cache.user_cache_dir",
1642             autospec=True,
1643             return_value=str(workspace1),
1644         )
1645
1646         # If BLACK_CACHE_DIR is not set, use user_cache_dir
1647         monkeypatch.delenv("BLACK_CACHE_DIR", raising=False)
1648         with patch_user_cache_dir:
1649             assert get_cache_dir() == workspace1
1650
1651         # If it is set, use the path provided in the env var.
1652         monkeypatch.setenv("BLACK_CACHE_DIR", str(workspace2))
1653         assert get_cache_dir() == workspace2
1654
1655     def test_cache_broken_file(self) -> None:
1656         mode = DEFAULT_MODE
1657         with cache_dir() as workspace:
1658             cache_file = get_cache_file(mode)
1659             cache_file.write_text("this is not a pickle")
1660             assert black.read_cache(mode) == {}
1661             src = (workspace / "test.py").resolve()
1662             src.write_text("print('hello')")
1663             invokeBlack([str(src)])
1664             cache = black.read_cache(mode)
1665             assert str(src) in cache
1666
1667     def test_cache_single_file_already_cached(self) -> None:
1668         mode = DEFAULT_MODE
1669         with cache_dir() as workspace:
1670             src = (workspace / "test.py").resolve()
1671             src.write_text("print('hello')")
1672             black.write_cache({}, [src], mode)
1673             invokeBlack([str(src)])
1674             assert src.read_text() == "print('hello')"
1675
1676     @event_loop()
1677     def test_cache_multiple_files(self) -> None:
1678         mode = DEFAULT_MODE
1679         with cache_dir() as workspace, patch(
1680             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1681         ):
1682             one = (workspace / "one.py").resolve()
1683             with one.open("w") as fobj:
1684                 fobj.write("print('hello')")
1685             two = (workspace / "two.py").resolve()
1686             with two.open("w") as fobj:
1687                 fobj.write("print('hello')")
1688             black.write_cache({}, [one], mode)
1689             invokeBlack([str(workspace)])
1690             with one.open("r") as fobj:
1691                 assert fobj.read() == "print('hello')"
1692             with two.open("r") as fobj:
1693                 assert fobj.read() == 'print("hello")\n'
1694             cache = black.read_cache(mode)
1695             assert str(one) in cache
1696             assert str(two) in cache
1697
1698     @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
1699     def test_no_cache_when_writeback_diff(self, color: bool) -> None:
1700         mode = DEFAULT_MODE
1701         with cache_dir() as workspace:
1702             src = (workspace / "test.py").resolve()
1703             with src.open("w") as fobj:
1704                 fobj.write("print('hello')")
1705             with patch("black.read_cache") as read_cache, patch(
1706                 "black.write_cache"
1707             ) as write_cache:
1708                 cmd = [str(src), "--diff"]
1709                 if color:
1710                     cmd.append("--color")
1711                 invokeBlack(cmd)
1712                 cache_file = get_cache_file(mode)
1713                 assert cache_file.exists() is False
1714                 write_cache.assert_not_called()
1715                 read_cache.assert_not_called()
1716
1717     @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
1718     @event_loop()
1719     def test_output_locking_when_writeback_diff(self, color: bool) -> None:
1720         with cache_dir() as workspace:
1721             for tag in range(0, 4):
1722                 src = (workspace / f"test{tag}.py").resolve()
1723                 with src.open("w") as fobj:
1724                     fobj.write("print('hello')")
1725             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1726                 cmd = ["--diff", str(workspace)]
1727                 if color:
1728                     cmd.append("--color")
1729                 invokeBlack(cmd, exit_code=0)
1730                 # this isn't quite doing what we want, but if it _isn't_
1731                 # called then we cannot be using the lock it provides
1732                 mgr.assert_called()
1733
1734     def test_no_cache_when_stdin(self) -> None:
1735         mode = DEFAULT_MODE
1736         with cache_dir():
1737             result = CliRunner().invoke(
1738                 black.main, ["-"], input=BytesIO(b"print('hello')")
1739             )
1740             assert not result.exit_code
1741             cache_file = get_cache_file(mode)
1742             assert not cache_file.exists()
1743
1744     def test_read_cache_no_cachefile(self) -> None:
1745         mode = DEFAULT_MODE
1746         with cache_dir():
1747             assert black.read_cache(mode) == {}
1748
1749     def test_write_cache_read_cache(self) -> None:
1750         mode = DEFAULT_MODE
1751         with cache_dir() as workspace:
1752             src = (workspace / "test.py").resolve()
1753             src.touch()
1754             black.write_cache({}, [src], mode)
1755             cache = black.read_cache(mode)
1756             assert str(src) in cache
1757             assert cache[str(src)] == black.get_cache_info(src)
1758
1759     def test_filter_cached(self) -> None:
1760         with TemporaryDirectory() as workspace:
1761             path = Path(workspace)
1762             uncached = (path / "uncached").resolve()
1763             cached = (path / "cached").resolve()
1764             cached_but_changed = (path / "changed").resolve()
1765             uncached.touch()
1766             cached.touch()
1767             cached_but_changed.touch()
1768             cache = {
1769                 str(cached): black.get_cache_info(cached),
1770                 str(cached_but_changed): (0.0, 0),
1771             }
1772             todo, done = black.filter_cached(
1773                 cache, {uncached, cached, cached_but_changed}
1774             )
1775             assert todo == {uncached, cached_but_changed}
1776             assert done == {cached}
1777
1778     def test_write_cache_creates_directory_if_needed(self) -> None:
1779         mode = DEFAULT_MODE
1780         with cache_dir(exists=False) as workspace:
1781             assert not workspace.exists()
1782             black.write_cache({}, [], mode)
1783             assert workspace.exists()
1784
1785     @event_loop()
1786     def test_failed_formatting_does_not_get_cached(self) -> None:
1787         mode = DEFAULT_MODE
1788         with cache_dir() as workspace, patch(
1789             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1790         ):
1791             failing = (workspace / "failing.py").resolve()
1792             with failing.open("w") as fobj:
1793                 fobj.write("not actually python")
1794             clean = (workspace / "clean.py").resolve()
1795             with clean.open("w") as fobj:
1796                 fobj.write('print("hello")\n')
1797             invokeBlack([str(workspace)], exit_code=123)
1798             cache = black.read_cache(mode)
1799             assert str(failing) not in cache
1800             assert str(clean) in cache
1801
1802     def test_write_cache_write_fail(self) -> None:
1803         mode = DEFAULT_MODE
1804         with cache_dir(), patch.object(Path, "open") as mock:
1805             mock.side_effect = OSError
1806             black.write_cache({}, [], mode)
1807
1808     def test_read_cache_line_lengths(self) -> None:
1809         mode = DEFAULT_MODE
1810         short_mode = replace(DEFAULT_MODE, line_length=1)
1811         with cache_dir() as workspace:
1812             path = (workspace / "file.py").resolve()
1813             path.touch()
1814             black.write_cache({}, [path], mode)
1815             one = black.read_cache(mode)
1816             assert str(path) in one
1817             two = black.read_cache(short_mode)
1818             assert str(path) not in two
1819
1820
1821 def assert_collected_sources(
1822     src: Sequence[Union[str, Path]],
1823     expected: Sequence[Union[str, Path]],
1824     *,
1825     ctx: Optional[FakeContext] = None,
1826     exclude: Optional[str] = None,
1827     include: Optional[str] = None,
1828     extend_exclude: Optional[str] = None,
1829     force_exclude: Optional[str] = None,
1830     stdin_filename: Optional[str] = None,
1831 ) -> None:
1832     gs_src = tuple(str(Path(s)) for s in src)
1833     gs_expected = [Path(s) for s in expected]
1834     gs_exclude = None if exclude is None else compile_pattern(exclude)
1835     gs_include = DEFAULT_INCLUDE if include is None else compile_pattern(include)
1836     gs_extend_exclude = (
1837         None if extend_exclude is None else compile_pattern(extend_exclude)
1838     )
1839     gs_force_exclude = None if force_exclude is None else compile_pattern(force_exclude)
1840     collected = black.get_sources(
1841         ctx=ctx or FakeContext(),
1842         src=gs_src,
1843         quiet=False,
1844         verbose=False,
1845         include=gs_include,
1846         exclude=gs_exclude,
1847         extend_exclude=gs_extend_exclude,
1848         force_exclude=gs_force_exclude,
1849         report=black.Report(),
1850         stdin_filename=stdin_filename,
1851     )
1852     assert sorted(collected) == sorted(gs_expected)
1853
1854
1855 class TestFileCollection:
1856     def test_include_exclude(self) -> None:
1857         path = THIS_DIR / "data" / "include_exclude_tests"
1858         src = [path]
1859         expected = [
1860             Path(path / "b/dont_exclude/a.py"),
1861             Path(path / "b/dont_exclude/a.pyi"),
1862         ]
1863         assert_collected_sources(
1864             src,
1865             expected,
1866             include=r"\.pyi?$",
1867             exclude=r"/exclude/|/\.definitely_exclude/",
1868         )
1869
1870     def test_gitignore_used_as_default(self) -> None:
1871         base = Path(DATA_DIR / "include_exclude_tests")
1872         expected = [
1873             base / "b/.definitely_exclude/a.py",
1874             base / "b/.definitely_exclude/a.pyi",
1875         ]
1876         src = [base / "b/"]
1877         ctx = FakeContext()
1878         ctx.obj["root"] = base
1879         assert_collected_sources(src, expected, ctx=ctx, extend_exclude=r"/exclude/")
1880
1881     @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
1882     def test_exclude_for_issue_1572(self) -> None:
1883         # Exclude shouldn't touch files that were explicitly given to Black through the
1884         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
1885         # https://github.com/psf/black/issues/1572
1886         path = DATA_DIR / "include_exclude_tests"
1887         src = [path / "b/exclude/a.py"]
1888         expected = [path / "b/exclude/a.py"]
1889         assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")
1890
1891     def test_gitignore_exclude(self) -> None:
1892         path = THIS_DIR / "data" / "include_exclude_tests"
1893         include = re.compile(r"\.pyi?$")
1894         exclude = re.compile(r"")
1895         report = black.Report()
1896         gitignore = PathSpec.from_lines(
1897             "gitwildmatch", ["exclude/", ".definitely_exclude"]
1898         )
1899         sources: List[Path] = []
1900         expected = [
1901             Path(path / "b/dont_exclude/a.py"),
1902             Path(path / "b/dont_exclude/a.pyi"),
1903         ]
1904         this_abs = THIS_DIR.resolve()
1905         sources.extend(
1906             black.gen_python_files(
1907                 path.iterdir(),
1908                 this_abs,
1909                 include,
1910                 exclude,
1911                 None,
1912                 None,
1913                 report,
1914                 gitignore,
1915                 verbose=False,
1916                 quiet=False,
1917             )
1918         )
1919         assert sorted(expected) == sorted(sources)
1920
1921     def test_nested_gitignore(self) -> None:
1922         path = Path(THIS_DIR / "data" / "nested_gitignore_tests")
1923         include = re.compile(r"\.pyi?$")
1924         exclude = re.compile(r"")
1925         root_gitignore = black.files.get_gitignore(path)
1926         report = black.Report()
1927         expected: List[Path] = [
1928             Path(path / "x.py"),
1929             Path(path / "root/b.py"),
1930             Path(path / "root/c.py"),
1931             Path(path / "root/child/c.py"),
1932         ]
1933         this_abs = THIS_DIR.resolve()
1934         sources = list(
1935             black.gen_python_files(
1936                 path.iterdir(),
1937                 this_abs,
1938                 include,
1939                 exclude,
1940                 None,
1941                 None,
1942                 report,
1943                 root_gitignore,
1944                 verbose=False,
1945                 quiet=False,
1946             )
1947         )
1948         assert sorted(expected) == sorted(sources)
1949
1950     def test_invalid_gitignore(self) -> None:
1951         path = THIS_DIR / "data" / "invalid_gitignore_tests"
1952         empty_config = path / "pyproject.toml"
1953         result = BlackRunner().invoke(
1954             black.main, ["--verbose", "--config", str(empty_config), str(path)]
1955         )
1956         assert result.exit_code == 1
1957         assert result.stderr_bytes is not None
1958
1959         gitignore = path / ".gitignore"
1960         assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()
1961
1962     def test_invalid_nested_gitignore(self) -> None:
1963         path = THIS_DIR / "data" / "invalid_nested_gitignore_tests"
1964         empty_config = path / "pyproject.toml"
1965         result = BlackRunner().invoke(
1966             black.main, ["--verbose", "--config", str(empty_config), str(path)]
1967         )
1968         assert result.exit_code == 1
1969         assert result.stderr_bytes is not None
1970
1971         gitignore = path / "a" / ".gitignore"
1972         assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()
1973
1974     def test_empty_include(self) -> None:
1975         path = DATA_DIR / "include_exclude_tests"
1976         src = [path]
1977         expected = [
1978             Path(path / "b/exclude/a.pie"),
1979             Path(path / "b/exclude/a.py"),
1980             Path(path / "b/exclude/a.pyi"),
1981             Path(path / "b/dont_exclude/a.pie"),
1982             Path(path / "b/dont_exclude/a.py"),
1983             Path(path / "b/dont_exclude/a.pyi"),
1984             Path(path / "b/.definitely_exclude/a.pie"),
1985             Path(path / "b/.definitely_exclude/a.py"),
1986             Path(path / "b/.definitely_exclude/a.pyi"),
1987             Path(path / ".gitignore"),
1988             Path(path / "pyproject.toml"),
1989         ]
1990         # Setting exclude explicitly to an empty string to block .gitignore usage.
1991         assert_collected_sources(src, expected, include="", exclude="")
1992
1993     def test_extend_exclude(self) -> None:
1994         path = DATA_DIR / "include_exclude_tests"
1995         src = [path]
1996         expected = [
1997             Path(path / "b/exclude/a.py"),
1998             Path(path / "b/dont_exclude/a.py"),
1999         ]
2000         assert_collected_sources(
2001             src, expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude"
2002         )
2003
2004     @pytest.mark.incompatible_with_mypyc
2005     def test_symlink_out_of_root_directory(self) -> None:
2006         path = MagicMock()
2007         root = THIS_DIR.resolve()
2008         child = MagicMock()
2009         include = re.compile(black.DEFAULT_INCLUDES)
2010         exclude = re.compile(black.DEFAULT_EXCLUDES)
2011         report = black.Report()
2012         gitignore = PathSpec.from_lines("gitwildmatch", [])
2013         # `child` should behave like a symlink which resolved path is clearly
2014         # outside of the `root` directory.
2015         path.iterdir.return_value = [child]
2016         child.resolve.return_value = Path("/a/b/c")
2017         child.as_posix.return_value = "/a/b/c"
2018         try:
2019             list(
2020                 black.gen_python_files(
2021                     path.iterdir(),
2022                     root,
2023                     include,
2024                     exclude,
2025                     None,
2026                     None,
2027                     report,
2028                     gitignore,
2029                     verbose=False,
2030                     quiet=False,
2031                 )
2032             )
2033         except ValueError as ve:
2034             pytest.fail(f"`get_python_files_in_dir()` failed: {ve}")
2035         path.iterdir.assert_called_once()
2036         child.resolve.assert_called_once()
2037
2038     @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
2039     def test_get_sources_with_stdin(self) -> None:
2040         src = ["-"]
2041         expected = ["-"]
2042         assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")
2043
2044     @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
2045     def test_get_sources_with_stdin_filename(self) -> None:
2046         src = ["-"]
2047         stdin_filename = str(THIS_DIR / "data/collections.py")
2048         expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
2049         assert_collected_sources(
2050             src,
2051             expected,
2052             exclude=r"/exclude/a\.py",
2053             stdin_filename=stdin_filename,
2054         )
2055
2056     @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
2057     def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
2058         # Exclude shouldn't exclude stdin_filename since it is mimicking the
2059         # file being passed directly. This is the same as
2060         # test_exclude_for_issue_1572
2061         path = DATA_DIR / "include_exclude_tests"
2062         src = ["-"]
2063         stdin_filename = str(path / "b/exclude/a.py")
2064         expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
2065         assert_collected_sources(
2066             src,
2067             expected,
2068             exclude=r"/exclude/|a\.py",
2069             stdin_filename=stdin_filename,
2070         )
2071
2072     @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
2073     def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
2074         # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
2075         # file being passed directly. This is the same as
2076         # test_exclude_for_issue_1572
2077         src = ["-"]
2078         path = THIS_DIR / "data" / "include_exclude_tests"
2079         stdin_filename = str(path / "b/exclude/a.py")
2080         expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
2081         assert_collected_sources(
2082             src,
2083             expected,
2084             extend_exclude=r"/exclude/|a\.py",
2085             stdin_filename=stdin_filename,
2086         )
2087
2088     @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
2089     def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
2090         # Force exclude should exclude the file when passing it through
2091         # stdin_filename
2092         path = THIS_DIR / "data" / "include_exclude_tests"
2093         stdin_filename = str(path / "b/exclude/a.py")
2094         assert_collected_sources(
2095             src=["-"],
2096             expected=[],
2097             force_exclude=r"/exclude/|a\.py",
2098             stdin_filename=stdin_filename,
2099         )
2100
2101
2102 try:
2103     with open(black.__file__, "r", encoding="utf-8") as _bf:
2104         black_source_lines = _bf.readlines()
2105 except UnicodeDecodeError:
2106     if not black.COMPILED:
2107         raise
2108
2109
2110 def tracefunc(
2111     frame: types.FrameType, event: str, arg: Any
2112 ) -> Callable[[types.FrameType, str, Any], Any]:
2113     """Show function calls `from black/__init__.py` as they happen.
2114
2115     Register this with `sys.settrace()` in a test you're debugging.
2116     """
2117     if event != "call":
2118         return tracefunc
2119
2120     stack = len(inspect.stack()) - 19
2121     stack *= 2
2122     filename = frame.f_code.co_filename
2123     lineno = frame.f_lineno
2124     func_sig_lineno = lineno - 1
2125     funcname = black_source_lines[func_sig_lineno].strip()
2126     while funcname.startswith("@"):
2127         func_sig_lineno += 1
2128         funcname = black_source_lines[func_sig_lineno].strip()
2129     if "black/__init__.py" in filename:
2130         print(f"{' ' * stack}{lineno}:{funcname}")
2131     return tracefunc