git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Bump peter-evans/find-comment from 1.3.0 to 2 (#2960)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import inspect
5 import io
6 import logging
7 import multiprocessing
8 import os
9 import sys
10 import types
11 import unittest
12 from concurrent.futures import ThreadPoolExecutor
13 from contextlib import contextmanager, redirect_stderr
14 from dataclasses import replace
15 from io import BytesIO
16 from pathlib import Path
17 from platform import system
18 from tempfile import TemporaryDirectory
19 from typing import (
20     Any,
21     Callable,
22     Dict,
23     Iterator,
24     List,
25     Optional,
26     Sequence,
27     TypeVar,
28     Union,
29 )
30 from unittest.mock import MagicMock, patch
31
32 import click
33 import pytest
34 import re
35 from click import unstyle
36 from click.testing import CliRunner
37 from pathspec import PathSpec
38
39 import black
40 import black.files
41 from black import Feature, TargetVersion
42 from black import re_compile_maybe_verbose as compile_pattern
43 from black.cache import get_cache_dir, get_cache_file
44 from black.debug import DebugVisitor
45 from black.output import color_diff, diff
46 from black.report import Report
47
48 # Import other test classes
49 from tests.util import (
50     DATA_DIR,
51     DEFAULT_MODE,
52     DETERMINISTIC_HEADER,
53     PROJECT_ROOT,
54     PY36_VERSIONS,
55     THIS_DIR,
56     BlackBaseTestCase,
57     assert_format,
58     change_directory,
59     dump_to_stderr,
60     ff,
61     fs,
62     read_data,
63 )
64
# Location of this test module and of the empty configuration file that the
# CLI tests pass via --config so no project configuration leaks in.
THIS_FILE = Path(__file__)
EMPTY_CONFIG = THIS_DIR.joinpath("data", "empty_pyproject.toml")

# One --target-version flag per version listed in PY36_VERSIONS.
PY36_ARGS = [f"--target-version={tv.name.lower()}" for tv in PY36_VERSIONS]

# Compiled forms of Black's default --exclude / --include patterns.
DEFAULT_EXCLUDE = compile_pattern(black.const.DEFAULT_EXCLUDES)
DEFAULT_INCLUDE = compile_pattern(black.const.DEFAULT_INCLUDES)

# Generic type variables (presumably used by helpers further down the file).
T = TypeVar("T")
R = TypeVar("R")

# Match the time output in a diff, but nothing else
DIFF_TIME = re.compile(r"\t[\d\-:+\. ]+")
75
76
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point ``black.cache.CACHE_DIR`` at a throwaway directory.

    The patch is active only inside the ``with`` block, and the temporary
    directory is deleted on exit. When *exists* is false, the yielded path is
    a ``new`` subdirectory that this helper does not create.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
85
86
@contextmanager
def event_loop() -> Iterator[None]:
    """Run the ``with`` body with a fresh asyncio event loop set as current.

    The loop comes from the active event loop policy and is installed with
    ``asyncio.set_event_loop``. On exit it is closed, and the current loop is
    reset to ``None`` (as ``asyncio.run`` does on shutdown). Without the reset,
    later code calling ``get_event_loop()`` would receive the already-closed
    loop.
    """
    policy = asyncio.get_event_loop_policy()
    loop = policy.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        yield
    finally:
        loop.close()
        # Don't leave a closed loop registered as the thread's current loop.
        asyncio.set_event_loop(None)
97
98
class FakeContext(click.Context):
    """Stand-in click ``Context`` for calling functions that require one.

    ``click.Context.__init__`` is intentionally not invoked; only ``obj`` and
    ``default_map`` are populated.
    """

    def __init__(self) -> None:
        # Placeholder project root -- most tests never look at it.
        self.obj: Dict[str, Any] = {"root": PROJECT_ROOT}
        self.default_map: Dict[str, Any] = {}
106
107
class FakeParameter(click.Parameter):
    """Stand-in click ``Parameter`` for calling functions that require one."""

    def __init__(self) -> None:
        """Deliberately skip ``click.Parameter.__init__``; no state is set."""
113
114
class BlackRunner(CliRunner):
    """``CliRunner`` that captures Black's stdout and stderr as separate streams."""

    def __init__(self) -> None:
        # mix_stderr=False keeps result.stdout and result.stderr independent
        # instead of interleaving them into a single buffer.
        super().__init__(mix_stderr=False)
120
121
def invokeBlack(
    args: List[str], exit_code: int = 0, ignore_config: bool = True
) -> None:
    """Run Black's CLI in-process and assert on the exit status.

    Args:
        args: command-line arguments forwarded to ``black.main``.
        exit_code: the exit status the run is expected to end with.
        ignore_config: when true, prepend ``--verbose`` and point ``--config``
            at ``THIS_DIR / "empty.toml"`` so no project configuration applies.

    Raises:
        AssertionError: if the exit status differs. The message lists the
            arguments, the captured stdout/stderr and any exception raised.
    """
    if ignore_config:
        empty_config = str(THIS_DIR / "empty.toml")
        args = ["--verbose", "--config", empty_config] + list(args)
    result = BlackRunner().invoke(black.main, args, catch_exceptions=False)
    out_bytes, err_bytes = result.stdout_bytes, result.stderr_bytes
    assert out_bytes is not None
    assert err_bytes is not None
    details = "\n".join(
        [
            f"Failed with args: {args}",
            f"stdout: {out_bytes.decode()!r}",
            f"stderr: {err_bytes.decode()!r}",
            f"exception: {result.exception}",
        ]
    )
    assert result.exit_code == exit_code, details
138
139
140 class BlackTestCase(BlackBaseTestCase):
    # Expose the module-level helper on the class so tests can call
    # self.invokeBlack(...); staticmethod stops it from being bound to self.
    invokeBlack = staticmethod(invokeBlack)
142
143     def test_empty_ff(self) -> None:
144         expected = ""
145         tmp_file = Path(black.dump_to_file())
146         try:
147             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
148             with open(tmp_file, encoding="utf8") as f:
149                 actual = f.read()
150         finally:
151             os.unlink(tmp_file)
152         self.assertFormatEqual(expected, actual)
153
154     def test_experimental_string_processing_warns(self) -> None:
155         self.assertWarns(
156             black.mode.Deprecated, black.Mode, experimental_string_processing=True
157         )
158
159     def test_piping(self) -> None:
160         source, expected = read_data("src/black/__init__", data=False)
161         result = BlackRunner().invoke(
162             black.main,
163             [
164                 "-",
165                 "--fast",
166                 f"--line-length={black.DEFAULT_LINE_LENGTH}",
167                 f"--config={EMPTY_CONFIG}",
168             ],
169             input=BytesIO(source.encode("utf8")),
170         )
171         self.assertEqual(result.exit_code, 0)
172         self.assertFormatEqual(expected, result.output)
173         if source != result.output:
174             black.assert_equivalent(source, result.output)
175             black.assert_stable(source, result.output, DEFAULT_MODE)
176
177     def test_piping_diff(self) -> None:
178         diff_header = re.compile(
179             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
180             r"\+\d\d\d\d"
181         )
182         source, _ = read_data("expression.py")
183         expected, _ = read_data("expression.diff")
184         args = [
185             "-",
186             "--fast",
187             f"--line-length={black.DEFAULT_LINE_LENGTH}",
188             "--diff",
189             f"--config={EMPTY_CONFIG}",
190         ]
191         result = BlackRunner().invoke(
192             black.main, args, input=BytesIO(source.encode("utf8"))
193         )
194         self.assertEqual(result.exit_code, 0)
195         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
196         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
197         self.assertEqual(expected, actual)
198
199     def test_piping_diff_with_color(self) -> None:
200         source, _ = read_data("expression.py")
201         args = [
202             "-",
203             "--fast",
204             f"--line-length={black.DEFAULT_LINE_LENGTH}",
205             "--diff",
206             "--color",
207             f"--config={EMPTY_CONFIG}",
208         ]
209         result = BlackRunner().invoke(
210             black.main, args, input=BytesIO(source.encode("utf8"))
211         )
212         actual = result.output
213         # Again, the contents are checked in a different test, so only look for colors.
214         self.assertIn("\033[1m", actual)
215         self.assertIn("\033[36m", actual)
216         self.assertIn("\033[32m", actual)
217         self.assertIn("\033[31m", actual)
218         self.assertIn("\033[0m", actual)
219
220     @patch("black.dump_to_file", dump_to_stderr)
221     def _test_wip(self) -> None:
222         source, expected = read_data("wip")
223         sys.settrace(tracefunc)
224         mode = replace(
225             DEFAULT_MODE,
226             experimental_string_processing=False,
227             target_versions={black.TargetVersion.PY38},
228         )
229         actual = fs(source, mode=mode)
230         sys.settrace(None)
231         self.assertFormatEqual(expected, actual)
232         black.assert_equivalent(source, actual)
233         black.assert_stable(source, actual, black.FileMode())
234
235     def test_pep_572_version_detection(self) -> None:
236         source, _ = read_data("pep_572")
237         root = black.lib2to3_parse(source)
238         features = black.get_features_used(root)
239         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
240         versions = black.detect_target_versions(root)
241         self.assertIn(black.TargetVersion.PY38, versions)
242
243     def test_expression_ff(self) -> None:
244         source, expected = read_data("expression")
245         tmp_file = Path(black.dump_to_file(source))
246         try:
247             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
248             with open(tmp_file, encoding="utf8") as f:
249                 actual = f.read()
250         finally:
251             os.unlink(tmp_file)
252         self.assertFormatEqual(expected, actual)
253         with patch("black.dump_to_file", dump_to_stderr):
254             black.assert_equivalent(source, actual)
255             black.assert_stable(source, actual, DEFAULT_MODE)
256
257     def test_expression_diff(self) -> None:
258         source, _ = read_data("expression.py")
259         expected, _ = read_data("expression.diff")
260         tmp_file = Path(black.dump_to_file(source))
261         diff_header = re.compile(
262             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
263             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
264         )
265         try:
266             result = BlackRunner().invoke(
267                 black.main, ["--diff", str(tmp_file), f"--config={EMPTY_CONFIG}"]
268             )
269             self.assertEqual(result.exit_code, 0)
270         finally:
271             os.unlink(tmp_file)
272         actual = result.output
273         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
274         if expected != actual:
275             dump = black.dump_to_file(actual)
276             msg = (
277                 "Expected diff isn't equal to the actual. If you made changes to"
278                 " expression.py and this is an anticipated difference, overwrite"
279                 f" tests/data/expression.diff with {dump}"
280             )
281             self.assertEqual(expected, actual, msg)
282
283     def test_expression_diff_with_color(self) -> None:
284         source, _ = read_data("expression.py")
285         expected, _ = read_data("expression.diff")
286         tmp_file = Path(black.dump_to_file(source))
287         try:
288             result = BlackRunner().invoke(
289                 black.main,
290                 ["--diff", "--color", str(tmp_file), f"--config={EMPTY_CONFIG}"],
291             )
292         finally:
293             os.unlink(tmp_file)
294         actual = result.output
295         # We check the contents of the diff in `test_expression_diff`. All
296         # we need to check here is that color codes exist in the result.
297         self.assertIn("\033[1m", actual)
298         self.assertIn("\033[36m", actual)
299         self.assertIn("\033[32m", actual)
300         self.assertIn("\033[31m", actual)
301         self.assertIn("\033[0m", actual)
302
303     def test_detect_pos_only_arguments(self) -> None:
304         source, _ = read_data("pep_570")
305         root = black.lib2to3_parse(source)
306         features = black.get_features_used(root)
307         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
308         versions = black.detect_target_versions(root)
309         self.assertIn(black.TargetVersion.PY38, versions)
310
311     @patch("black.dump_to_file", dump_to_stderr)
312     def test_string_quotes(self) -> None:
313         source, expected = read_data("string_quotes")
314         mode = black.Mode(preview=True)
315         assert_format(source, expected, mode)
316         mode = replace(mode, string_normalization=False)
317         not_normalized = fs(source, mode=mode)
318         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
319         black.assert_equivalent(source, not_normalized)
320         black.assert_stable(source, not_normalized, mode=mode)
321
322     def test_skip_magic_trailing_comma(self) -> None:
323         source, _ = read_data("expression.py")
324         expected, _ = read_data("expression_skip_magic_trailing_comma.diff")
325         tmp_file = Path(black.dump_to_file(source))
326         diff_header = re.compile(
327             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
328             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
329         )
330         try:
331             result = BlackRunner().invoke(
332                 black.main, ["-C", "--diff", str(tmp_file), f"--config={EMPTY_CONFIG}"]
333             )
334             self.assertEqual(result.exit_code, 0)
335         finally:
336             os.unlink(tmp_file)
337         actual = result.output
338         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
339         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
340         if expected != actual:
341             dump = black.dump_to_file(actual)
342             msg = (
343                 "Expected diff isn't equal to the actual. If you made changes to"
344                 " expression.py and this is an anticipated difference, overwrite"
345                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
346             )
347             self.assertEqual(expected, actual, msg)
348
349     @patch("black.dump_to_file", dump_to_stderr)
350     def test_async_as_identifier(self) -> None:
351         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
352         source, expected = read_data("async_as_identifier")
353         actual = fs(source)
354         self.assertFormatEqual(expected, actual)
355         major, minor = sys.version_info[:2]
356         if major < 3 or (major <= 3 and minor < 7):
357             black.assert_equivalent(source, actual)
358         black.assert_stable(source, actual, DEFAULT_MODE)
359         # ensure black can parse this when the target is 3.6
360         self.invokeBlack([str(source_path), "--target-version", "py36"])
361         # but not on 3.7, because async/await is no longer an identifier
362         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
363
364     @patch("black.dump_to_file", dump_to_stderr)
365     def test_python37(self) -> None:
366         source_path = (THIS_DIR / "data" / "python37.py").resolve()
367         source, expected = read_data("python37")
368         actual = fs(source)
369         self.assertFormatEqual(expected, actual)
370         major, minor = sys.version_info[:2]
371         if major > 3 or (major == 3 and minor >= 7):
372             black.assert_equivalent(source, actual)
373         black.assert_stable(source, actual, DEFAULT_MODE)
374         # ensure black can parse this when the target is 3.7
375         self.invokeBlack([str(source_path), "--target-version", "py37"])
376         # but not on 3.6, because we use async as a reserved keyword
377         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
378
379     def test_tab_comment_indentation(self) -> None:
380         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
381         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
382         self.assertFormatEqual(contents_spc, fs(contents_spc))
383         self.assertFormatEqual(contents_spc, fs(contents_tab))
384
385         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
386         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
387         self.assertFormatEqual(contents_spc, fs(contents_spc))
388         self.assertFormatEqual(contents_spc, fs(contents_tab))
389
390         # mixed tabs and spaces (valid Python 2 code)
391         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
392         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
393         self.assertFormatEqual(contents_spc, fs(contents_spc))
394         self.assertFormatEqual(contents_spc, fs(contents_tab))
395
396         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
397         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
398         self.assertFormatEqual(contents_spc, fs(contents_spc))
399         self.assertFormatEqual(contents_spc, fs(contents_tab))
400
    def test_report_verbose(self) -> None:
        """Walk a verbose ``Report`` through every outcome, checking each step.

        With ``verbose=True`` every event writes one message: unchanged,
        reformatted, cached and ignored files go through ``_out`` and failures
        through ``_err``. The summary string and ``return_code`` are checked
        after each step.
        """
        report = Report(verbose=True)
        # Messages captured from the patched black.output._out / _err.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file is counted as "left unchanged" in the summary.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode a reformatted file turns the return code into 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure is written via err and forces return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Reporting f3 again as changed adds to the reformatted count
            # without reducing the unchanged count.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are logged in verbose mode but leave the summary as is.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff modes phrase the summary conditionally.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
502
    def test_report_quiet(self) -> None:
        """Walk a quiet ``Report`` through every outcome, checking each step.

        With ``quiet=True`` nothing is written via ``_out``. Only failures
        produce messages (via ``_err``), while the summary string and
        ``return_code`` still track every event.
        """
        report = Report(quiet=True)
        # Messages captured from the patched black.output._out / _err.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file is counted as "left unchanged" in the summary.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode a reformatted file turns the return code into 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures are still reported via err in quiet mode.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are silent and leave the summary unchanged.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff modes phrase the summary conditionally.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
596
    def test_report_normal(self) -> None:
        """Walk a default ``Report`` through every outcome, checking each step.

        With default verbosity, only reformatted files (via ``_out``) and
        failures (via ``_err``) produce messages. Unchanged, cached and ignored
        files are silent but still counted in the summary.
        """
        report = black.Report()
        # Messages captured from the patched black.output._out / _err.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file is silent here; the last message is still f2's.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode a reformatted file turns the return code into 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure is written via err and forces return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are silent and leave the summary unchanged.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff modes phrase the summary conditionally.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
693
694     def test_lib2to3_parse(self) -> None:
695         with self.assertRaises(black.InvalidInput):
696             black.lib2to3_parse("invalid syntax")
697
698         straddling = "x + y"
699         black.lib2to3_parse(straddling)
700         black.lib2to3_parse(straddling, {TargetVersion.PY36})
701
702         py2_only = "print x"
703         with self.assertRaises(black.InvalidInput):
704             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
705
706         py3_only = "exec(x, end=y)"
707         black.lib2to3_parse(py3_only)
708         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
709
710     def test_get_features_used_decorator(self) -> None:
711         # Test the feature detection of new decorator syntax
712         # since this makes some test cases of test_get_features_used()
713         # fails if it fails, this is tested first so that a useful case
714         # is identified
715         simples, relaxed = read_data("decorators")
716         # skip explanation comments at the top of the file
717         for simple_test in simples.split("##")[1:]:
718             node = black.lib2to3_parse(simple_test)
719             decorator = str(node.children[0].children[0]).strip()
720             self.assertNotIn(
721                 Feature.RELAXED_DECORATORS,
722                 black.get_features_used(node),
723                 msg=(
724                     f"decorator '{decorator}' follows python<=3.8 syntax"
725                     "but is detected as 3.9+"
726                     # f"The full node is\n{node!r}"
727                 ),
728             )
729         # skip the '# output' comment at the top of the output part
730         for relaxed_test in relaxed.split("##")[1:]:
731             node = black.lib2to3_parse(relaxed_test)
732             decorator = str(node.children[0].children[0]).strip()
733             self.assertIn(
734                 Feature.RELAXED_DECORATORS,
735                 black.get_features_used(node),
736                 msg=(
737                     f"decorator '{decorator}' uses python3.9+ syntax"
738                     "but is detected as python<=3.8"
739                     # f"The full node is\n{node!r}"
740                 ),
741             )
742
743     def test_get_features_used(self) -> None:
744         node = black.lib2to3_parse("def f(*, arg): ...\n")
745         self.assertEqual(black.get_features_used(node), set())
746         node = black.lib2to3_parse("def f(*, arg,): ...\n")
747         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
748         node = black.lib2to3_parse("f(*arg,)\n")
749         self.assertEqual(
750             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
751         )
752         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
753         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
754         node = black.lib2to3_parse("123_456\n")
755         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
756         node = black.lib2to3_parse("123456\n")
757         self.assertEqual(black.get_features_used(node), set())
758         source, expected = read_data("function")
759         node = black.lib2to3_parse(source)
760         expected_features = {
761             Feature.TRAILING_COMMA_IN_CALL,
762             Feature.TRAILING_COMMA_IN_DEF,
763             Feature.F_STRINGS,
764         }
765         self.assertEqual(black.get_features_used(node), expected_features)
766         node = black.lib2to3_parse(expected)
767         self.assertEqual(black.get_features_used(node), expected_features)
768         source, expected = read_data("expression")
769         node = black.lib2to3_parse(source)
770         self.assertEqual(black.get_features_used(node), set())
771         node = black.lib2to3_parse(expected)
772         self.assertEqual(black.get_features_used(node), set())
773         node = black.lib2to3_parse("lambda a, /, b: ...")
774         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
775         node = black.lib2to3_parse("def fn(a, /, b): ...")
776         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
777         node = black.lib2to3_parse("def fn(): yield a, b")
778         self.assertEqual(black.get_features_used(node), set())
779         node = black.lib2to3_parse("def fn(): return a, b")
780         self.assertEqual(black.get_features_used(node), set())
781         node = black.lib2to3_parse("def fn(): yield *b, c")
782         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
783         node = black.lib2to3_parse("def fn(): return a, *b, c")
784         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
785         node = black.lib2to3_parse("x = a, *b, c")
786         self.assertEqual(black.get_features_used(node), set())
787         node = black.lib2to3_parse("x: Any = regular")
788         self.assertEqual(black.get_features_used(node), set())
789         node = black.lib2to3_parse("x: Any = (regular, regular)")
790         self.assertEqual(black.get_features_used(node), set())
791         node = black.lib2to3_parse("x: Any = Complex(Type(1))[something]")
792         self.assertEqual(black.get_features_used(node), set())
793         node = black.lib2to3_parse("x: Tuple[int, ...] = a, b, c")
794         self.assertEqual(
795             black.get_features_used(node), {Feature.ANN_ASSIGN_EXTENDED_RHS}
796         )
797
798     def test_get_features_used_for_future_flags(self) -> None:
799         for src, features in [
800             ("from __future__ import annotations", {Feature.FUTURE_ANNOTATIONS}),
801             (
802                 "from __future__ import (other, annotations)",
803                 {Feature.FUTURE_ANNOTATIONS},
804             ),
805             ("a = 1 + 2\nfrom something import annotations", set()),
806             ("from __future__ import x, y", set()),
807         ]:
808             with self.subTest(src=src, features=features):
809                 node = black.lib2to3_parse(src)
810                 future_imports = black.get_future_imports(node)
811                 self.assertEqual(
812                     black.get_features_used(node, future_imports=future_imports),
813                     features,
814                 )
815
816     def test_get_future_imports(self) -> None:
817         node = black.lib2to3_parse("\n")
818         self.assertEqual(set(), black.get_future_imports(node))
819         node = black.lib2to3_parse("from __future__ import black\n")
820         self.assertEqual({"black"}, black.get_future_imports(node))
821         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
822         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
823         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
824         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
825         node = black.lib2to3_parse(
826             "from __future__ import multiple\nfrom __future__ import imports\n"
827         )
828         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
829         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
830         self.assertEqual({"black"}, black.get_future_imports(node))
831         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
832         self.assertEqual({"black"}, black.get_future_imports(node))
833         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
834         self.assertEqual(set(), black.get_future_imports(node))
835         node = black.lib2to3_parse("from some.module import black\n")
836         self.assertEqual(set(), black.get_future_imports(node))
837         node = black.lib2to3_parse(
838             "from __future__ import unicode_literals as _unicode_literals"
839         )
840         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
841         node = black.lib2to3_parse(
842             "from __future__ import unicode_literals as _lol, print"
843         )
844         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
845
846     @pytest.mark.incompatible_with_mypyc
847     def test_debug_visitor(self) -> None:
848         source, _ = read_data("debug_visitor.py")
849         expected, _ = read_data("debug_visitor.out")
850         out_lines = []
851         err_lines = []
852
853         def out(msg: str, **kwargs: Any) -> None:
854             out_lines.append(msg)
855
856         def err(msg: str, **kwargs: Any) -> None:
857             err_lines.append(msg)
858
859         with patch("black.debug.out", out):
860             DebugVisitor.show(source)
861         actual = "\n".join(out_lines) + "\n"
862         log_name = ""
863         if expected != actual:
864             log_name = black.dump_to_file(*out_lines)
865         self.assertEqual(
866             expected,
867             actual,
868             f"AST print out is different. Actual version dumped to {log_name}",
869         )
870
871     def test_format_file_contents(self) -> None:
872         empty = ""
873         mode = DEFAULT_MODE
874         with self.assertRaises(black.NothingChanged):
875             black.format_file_contents(empty, mode=mode, fast=False)
876         just_nl = "\n"
877         with self.assertRaises(black.NothingChanged):
878             black.format_file_contents(just_nl, mode=mode, fast=False)
879         same = "j = [1, 2, 3]\n"
880         with self.assertRaises(black.NothingChanged):
881             black.format_file_contents(same, mode=mode, fast=False)
882         different = "j = [1,2,3]"
883         expected = same
884         actual = black.format_file_contents(different, mode=mode, fast=False)
885         self.assertEqual(expected, actual)
886         invalid = "return if you can"
887         with self.assertRaises(black.InvalidInput) as e:
888             black.format_file_contents(invalid, mode=mode, fast=False)
889         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
890
891     def test_endmarker(self) -> None:
892         n = black.lib2to3_parse("\n")
893         self.assertEqual(n.type, black.syms.file_input)
894         self.assertEqual(len(n.children), 1)
895         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
896
897     @pytest.mark.incompatible_with_mypyc
898     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
899     def test_assertFormatEqual(self) -> None:
900         out_lines = []
901         err_lines = []
902
903         def out(msg: str, **kwargs: Any) -> None:
904             out_lines.append(msg)
905
906         def err(msg: str, **kwargs: Any) -> None:
907             err_lines.append(msg)
908
909         with patch("black.output._out", out), patch("black.output._err", err):
910             with self.assertRaises(AssertionError):
911                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
912
913         out_str = "".join(out_lines)
914         self.assertIn("Expected tree:", out_str)
915         self.assertIn("Actual tree:", out_str)
916         self.assertEqual("".join(err_lines), "")
917
918     @event_loop()
919     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
920     def test_works_in_mono_process_only_environment(self) -> None:
921         with cache_dir() as workspace:
922             for f in [
923                 (workspace / "one.py").resolve(),
924                 (workspace / "two.py").resolve(),
925             ]:
926                 f.write_text('print("hello")\n')
927             self.invokeBlack([str(workspace)])
928
929     @event_loop()
930     def test_check_diff_use_together(self) -> None:
931         with cache_dir():
932             # Files which will be reformatted.
933             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
934             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
935             # Files which will not be reformatted.
936             src2 = (THIS_DIR / "data" / "composition.py").resolve()
937             self.invokeBlack([str(src2), "--diff", "--check"])
938             # Multi file command.
939             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
940
941     def test_no_src_fails(self) -> None:
942         with cache_dir():
943             self.invokeBlack([], exit_code=1)
944
945     def test_src_and_code_fails(self) -> None:
946         with cache_dir():
947             self.invokeBlack([".", "-c", "0"], exit_code=1)
948
949     def test_broken_symlink(self) -> None:
950         with cache_dir() as workspace:
951             symlink = workspace / "broken_link.py"
952             try:
953                 symlink.symlink_to("nonexistent.py")
954             except (OSError, NotImplementedError) as e:
955                 self.skipTest(f"Can't create symlinks: {e}")
956             self.invokeBlack([str(workspace.resolve())])
957
958     def test_single_file_force_pyi(self) -> None:
959         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
960         contents, expected = read_data("force_pyi")
961         with cache_dir() as workspace:
962             path = (workspace / "file.py").resolve()
963             with open(path, "w") as fh:
964                 fh.write(contents)
965             self.invokeBlack([str(path), "--pyi"])
966             with open(path, "r") as fh:
967                 actual = fh.read()
968             # verify cache with --pyi is separate
969             pyi_cache = black.read_cache(pyi_mode)
970             self.assertIn(str(path), pyi_cache)
971             normal_cache = black.read_cache(DEFAULT_MODE)
972             self.assertNotIn(str(path), normal_cache)
973         self.assertFormatEqual(expected, actual)
974         black.assert_equivalent(contents, actual)
975         black.assert_stable(contents, actual, pyi_mode)
976
977     @event_loop()
978     def test_multi_file_force_pyi(self) -> None:
979         reg_mode = DEFAULT_MODE
980         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
981         contents, expected = read_data("force_pyi")
982         with cache_dir() as workspace:
983             paths = [
984                 (workspace / "file1.py").resolve(),
985                 (workspace / "file2.py").resolve(),
986             ]
987             for path in paths:
988                 with open(path, "w") as fh:
989                     fh.write(contents)
990             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
991             for path in paths:
992                 with open(path, "r") as fh:
993                     actual = fh.read()
994                 self.assertEqual(actual, expected)
995             # verify cache with --pyi is separate
996             pyi_cache = black.read_cache(pyi_mode)
997             normal_cache = black.read_cache(reg_mode)
998             for path in paths:
999                 self.assertIn(str(path), pyi_cache)
1000                 self.assertNotIn(str(path), normal_cache)
1001
1002     def test_pipe_force_pyi(self) -> None:
1003         source, expected = read_data("force_pyi")
1004         result = CliRunner().invoke(
1005             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1006         )
1007         self.assertEqual(result.exit_code, 0)
1008         actual = result.output
1009         self.assertFormatEqual(actual, expected)
1010
1011     def test_single_file_force_py36(self) -> None:
1012         reg_mode = DEFAULT_MODE
1013         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1014         source, expected = read_data("force_py36")
1015         with cache_dir() as workspace:
1016             path = (workspace / "file.py").resolve()
1017             with open(path, "w") as fh:
1018                 fh.write(source)
1019             self.invokeBlack([str(path), *PY36_ARGS])
1020             with open(path, "r") as fh:
1021                 actual = fh.read()
1022             # verify cache with --target-version is separate
1023             py36_cache = black.read_cache(py36_mode)
1024             self.assertIn(str(path), py36_cache)
1025             normal_cache = black.read_cache(reg_mode)
1026             self.assertNotIn(str(path), normal_cache)
1027         self.assertEqual(actual, expected)
1028
1029     @event_loop()
1030     def test_multi_file_force_py36(self) -> None:
1031         reg_mode = DEFAULT_MODE
1032         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1033         source, expected = read_data("force_py36")
1034         with cache_dir() as workspace:
1035             paths = [
1036                 (workspace / "file1.py").resolve(),
1037                 (workspace / "file2.py").resolve(),
1038             ]
1039             for path in paths:
1040                 with open(path, "w") as fh:
1041                     fh.write(source)
1042             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1043             for path in paths:
1044                 with open(path, "r") as fh:
1045                     actual = fh.read()
1046                 self.assertEqual(actual, expected)
1047             # verify cache with --target-version is separate
1048             pyi_cache = black.read_cache(py36_mode)
1049             normal_cache = black.read_cache(reg_mode)
1050             for path in paths:
1051                 self.assertIn(str(path), pyi_cache)
1052                 self.assertNotIn(str(path), normal_cache)
1053
1054     def test_pipe_force_py36(self) -> None:
1055         source, expected = read_data("force_py36")
1056         result = CliRunner().invoke(
1057             black.main,
1058             ["-", "-q", "--target-version=py36"],
1059             input=BytesIO(source.encode("utf8")),
1060         )
1061         self.assertEqual(result.exit_code, 0)
1062         actual = result.output
1063         self.assertFormatEqual(actual, expected)
1064
1065     @pytest.mark.incompatible_with_mypyc
1066     def test_reformat_one_with_stdin(self) -> None:
1067         with patch(
1068             "black.format_stdin_to_stdout",
1069             return_value=lambda *args, **kwargs: black.Changed.YES,
1070         ) as fsts:
1071             report = MagicMock()
1072             path = Path("-")
1073             black.reformat_one(
1074                 path,
1075                 fast=True,
1076                 write_back=black.WriteBack.YES,
1077                 mode=DEFAULT_MODE,
1078                 report=report,
1079             )
1080             fsts.assert_called_once()
1081             report.done.assert_called_with(path, black.Changed.YES)
1082
1083     @pytest.mark.incompatible_with_mypyc
1084     def test_reformat_one_with_stdin_filename(self) -> None:
1085         with patch(
1086             "black.format_stdin_to_stdout",
1087             return_value=lambda *args, **kwargs: black.Changed.YES,
1088         ) as fsts:
1089             report = MagicMock()
1090             p = "foo.py"
1091             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1092             expected = Path(p)
1093             black.reformat_one(
1094                 path,
1095                 fast=True,
1096                 write_back=black.WriteBack.YES,
1097                 mode=DEFAULT_MODE,
1098                 report=report,
1099             )
1100             fsts.assert_called_once_with(
1101                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1102             )
1103             # __BLACK_STDIN_FILENAME__ should have been stripped
1104             report.done.assert_called_with(expected, black.Changed.YES)
1105
1106     @pytest.mark.incompatible_with_mypyc
1107     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1108         with patch(
1109             "black.format_stdin_to_stdout",
1110             return_value=lambda *args, **kwargs: black.Changed.YES,
1111         ) as fsts:
1112             report = MagicMock()
1113             p = "foo.pyi"
1114             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1115             expected = Path(p)
1116             black.reformat_one(
1117                 path,
1118                 fast=True,
1119                 write_back=black.WriteBack.YES,
1120                 mode=DEFAULT_MODE,
1121                 report=report,
1122             )
1123             fsts.assert_called_once_with(
1124                 fast=True,
1125                 write_back=black.WriteBack.YES,
1126                 mode=replace(DEFAULT_MODE, is_pyi=True),
1127             )
1128             # __BLACK_STDIN_FILENAME__ should have been stripped
1129             report.done.assert_called_with(expected, black.Changed.YES)
1130
1131     @pytest.mark.incompatible_with_mypyc
1132     def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
1133         with patch(
1134             "black.format_stdin_to_stdout",
1135             return_value=lambda *args, **kwargs: black.Changed.YES,
1136         ) as fsts:
1137             report = MagicMock()
1138             p = "foo.ipynb"
1139             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1140             expected = Path(p)
1141             black.reformat_one(
1142                 path,
1143                 fast=True,
1144                 write_back=black.WriteBack.YES,
1145                 mode=DEFAULT_MODE,
1146                 report=report,
1147             )
1148             fsts.assert_called_once_with(
1149                 fast=True,
1150                 write_back=black.WriteBack.YES,
1151                 mode=replace(DEFAULT_MODE, is_ipynb=True),
1152             )
1153             # __BLACK_STDIN_FILENAME__ should have been stripped
1154             report.done.assert_called_with(expected, black.Changed.YES)
1155
1156     @pytest.mark.incompatible_with_mypyc
1157     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1158         with patch(
1159             "black.format_stdin_to_stdout",
1160             return_value=lambda *args, **kwargs: black.Changed.YES,
1161         ) as fsts:
1162             report = MagicMock()
1163             # Even with an existing file, since we are forcing stdin, black
1164             # should output to stdout and not modify the file inplace
1165             p = Path(str(THIS_DIR / "data/collections.py"))
1166             # Make sure is_file actually returns True
1167             self.assertTrue(p.is_file())
1168             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1169             expected = Path(p)
1170             black.reformat_one(
1171                 path,
1172                 fast=True,
1173                 write_back=black.WriteBack.YES,
1174                 mode=DEFAULT_MODE,
1175                 report=report,
1176             )
1177             fsts.assert_called_once()
1178             # __BLACK_STDIN_FILENAME__ should have been stripped
1179             report.done.assert_called_with(expected, black.Changed.YES)
1180
1181     def test_reformat_one_with_stdin_empty(self) -> None:
1182         output = io.StringIO()
1183         with patch("io.TextIOWrapper", lambda *args, **kwargs: output):
1184             try:
1185                 black.format_stdin_to_stdout(
1186                     fast=True,
1187                     content="",
1188                     write_back=black.WriteBack.YES,
1189                     mode=DEFAULT_MODE,
1190                 )
1191             except io.UnsupportedOperation:
1192                 pass  # StringIO does not support detach
1193             assert output.getvalue() == ""
1194
1195     def test_invalid_cli_regex(self) -> None:
1196         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1197             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1198
1199     def test_required_version_matches_version(self) -> None:
1200         self.invokeBlack(
1201             ["--required-version", black.__version__, "-c", "0"],
1202             exit_code=0,
1203             ignore_config=True,
1204         )
1205
1206     def test_required_version_matches_partial_version(self) -> None:
1207         self.invokeBlack(
1208             ["--required-version", black.__version__.split(".")[0], "-c", "0"],
1209             exit_code=0,
1210             ignore_config=True,
1211         )
1212
1213     def test_required_version_does_not_match_on_minor_version(self) -> None:
1214         self.invokeBlack(
1215             ["--required-version", black.__version__.split(".")[0] + ".999", "-c", "0"],
1216             exit_code=1,
1217             ignore_config=True,
1218         )
1219
1220     def test_required_version_does_not_match_version(self) -> None:
1221         result = BlackRunner().invoke(
1222             black.main,
1223             ["--required-version", "20.99b", "-c", "0"],
1224         )
1225         self.assertEqual(result.exit_code, 1)
1226         self.assertIn("required version", result.stderr)
1227
1228     def test_preserves_line_endings(self) -> None:
1229         with TemporaryDirectory() as workspace:
1230             test_file = Path(workspace) / "test.py"
1231             for nl in ["\n", "\r\n"]:
1232                 contents = nl.join(["def f(  ):", "    pass"])
1233                 test_file.write_bytes(contents.encode())
1234                 ff(test_file, write_back=black.WriteBack.YES)
1235                 updated_contents: bytes = test_file.read_bytes()
1236                 self.assertIn(nl.encode(), updated_contents)
1237                 if nl == "\n":
1238                     self.assertNotIn(b"\r\n", updated_contents)
1239
1240     def test_preserves_line_endings_via_stdin(self) -> None:
1241         for nl in ["\n", "\r\n"]:
1242             contents = nl.join(["def f(  ):", "    pass"])
1243             runner = BlackRunner()
1244             result = runner.invoke(
1245                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1246             )
1247             self.assertEqual(result.exit_code, 0)
1248             output = result.stdout_bytes
1249             self.assertIn(nl.encode("utf8"), output)
1250             if nl == "\n":
1251                 self.assertNotIn(b"\r\n", output)
1252
1253     def test_assert_equivalent_different_asts(self) -> None:
1254         with self.assertRaises(AssertionError):
1255             black.assert_equivalent("{}", "None")
1256
1257     def test_shhh_click(self) -> None:
1258         try:
1259             from click import _unicodefun
1260         except ImportError:
1261             self.skipTest("Incompatible Click version")
1262         if not hasattr(_unicodefun, "_verify_python3_env"):
1263             self.skipTest("Incompatible Click version")
1264         # First, let's see if Click is crashing with a preferred ASCII charset.
1265         with patch("locale.getpreferredencoding") as gpe:
1266             gpe.return_value = "ASCII"
1267             with self.assertRaises(RuntimeError):
1268                 _unicodefun._verify_python3_env()  # type: ignore
1269         # Now, let's silence Click...
1270         black.patch_click()
1271         # ...and confirm it's silent.
1272         with patch("locale.getpreferredencoding") as gpe:
1273             gpe.return_value = "ASCII"
1274             try:
1275                 _unicodefun._verify_python3_env()  # type: ignore
1276             except RuntimeError as re:
1277                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1278
1279     def test_root_logger_not_used_directly(self) -> None:
1280         def fail(*args: Any, **kwargs: Any) -> None:
1281             self.fail("Record created with root logger")
1282
1283         with patch.multiple(
1284             logging.root,
1285             debug=fail,
1286             info=fail,
1287             warning=fail,
1288             error=fail,
1289             critical=fail,
1290             log=fail,
1291         ):
1292             ff(THIS_DIR / "util.py")
1293
1294     def test_invalid_config_return_code(self) -> None:
1295         tmp_file = Path(black.dump_to_file())
1296         try:
1297             tmp_config = Path(black.dump_to_file())
1298             tmp_config.unlink()
1299             args = ["--config", str(tmp_config), str(tmp_file)]
1300             self.invokeBlack(args, exit_code=2, ignore_config=False)
1301         finally:
1302             tmp_file.unlink()
1303
1304     def test_parse_pyproject_toml(self) -> None:
1305         test_toml_file = THIS_DIR / "test.toml"
1306         config = black.parse_pyproject_toml(str(test_toml_file))
1307         self.assertEqual(config["verbose"], 1)
1308         self.assertEqual(config["check"], "no")
1309         self.assertEqual(config["diff"], "y")
1310         self.assertEqual(config["color"], True)
1311         self.assertEqual(config["line_length"], 79)
1312         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1313         self.assertEqual(config["python_cell_magics"], ["custom1", "custom2"])
1314         self.assertEqual(config["exclude"], r"\.pyi?$")
1315         self.assertEqual(config["include"], r"\.py?$")
1316
1317     def test_read_pyproject_toml(self) -> None:
1318         test_toml_file = THIS_DIR / "test.toml"
1319         fake_ctx = FakeContext()
1320         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1321         config = fake_ctx.default_map
1322         self.assertEqual(config["verbose"], "1")
1323         self.assertEqual(config["check"], "no")
1324         self.assertEqual(config["diff"], "y")
1325         self.assertEqual(config["color"], "True")
1326         self.assertEqual(config["line_length"], "79")
1327         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1328         self.assertEqual(config["exclude"], r"\.pyi?$")
1329         self.assertEqual(config["include"], r"\.py?$")
1330
1331     @pytest.mark.incompatible_with_mypyc
1332     def test_find_project_root(self) -> None:
1333         with TemporaryDirectory() as workspace:
1334             root = Path(workspace)
1335             test_dir = root / "test"
1336             test_dir.mkdir()
1337
1338             src_dir = root / "src"
1339             src_dir.mkdir()
1340
1341             root_pyproject = root / "pyproject.toml"
1342             root_pyproject.touch()
1343             src_pyproject = src_dir / "pyproject.toml"
1344             src_pyproject.touch()
1345             src_python = src_dir / "foo.py"
1346             src_python.touch()
1347
1348             self.assertEqual(
1349                 black.find_project_root((src_dir, test_dir)),
1350                 (root.resolve(), "pyproject.toml"),
1351             )
1352             self.assertEqual(
1353                 black.find_project_root((src_dir,)),
1354                 (src_dir.resolve(), "pyproject.toml"),
1355             )
1356             self.assertEqual(
1357                 black.find_project_root((src_python,)),
1358                 (src_dir.resolve(), "pyproject.toml"),
1359             )
1360
1361     @patch(
1362         "black.files.find_user_pyproject_toml",
1363     )
1364     def test_find_pyproject_toml(self, find_user_pyproject_toml: MagicMock) -> None:
1365         find_user_pyproject_toml.side_effect = RuntimeError()
1366
1367         with redirect_stderr(io.StringIO()) as stderr:
1368             result = black.files.find_pyproject_toml(
1369                 path_search_start=(str(Path.cwd().root),)
1370             )
1371
1372         assert result is None
1373         err = stderr.getvalue()
1374         assert "Ignoring user configuration" in err
1375
1376     @patch(
1377         "black.files.find_user_pyproject_toml",
1378         black.files.find_user_pyproject_toml.__wrapped__,
1379     )
1380     def test_find_user_pyproject_toml_linux(self) -> None:
1381         if system() == "Windows":
1382             return
1383
1384         # Test if XDG_CONFIG_HOME is checked
1385         with TemporaryDirectory() as workspace:
1386             tmp_user_config = Path(workspace) / "black"
1387             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1388                 self.assertEqual(
1389                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
1390                 )
1391
1392         # Test fallback for XDG_CONFIG_HOME
1393         with patch.dict("os.environ"):
1394             os.environ.pop("XDG_CONFIG_HOME", None)
1395             fallback_user_config = Path("~/.config").expanduser() / "black"
1396             self.assertEqual(
1397                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
1398             )
1399
1400     def test_find_user_pyproject_toml_windows(self) -> None:
1401         if system() != "Windows":
1402             return
1403
1404         user_config_path = Path.home() / ".black"
1405         self.assertEqual(
1406             black.files.find_user_pyproject_toml(), user_config_path.resolve()
1407         )
1408
1409     def test_bpo_33660_workaround(self) -> None:
1410         if system() == "Windows":
1411             return
1412
1413         # https://bugs.python.org/issue33660
1414         root = Path("/")
1415         with change_directory(root):
1416             path = Path("workspace") / "project"
1417             report = black.Report(verbose=True)
1418             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
1419             self.assertEqual(normalized_path, "workspace/project")
1420
1421     def test_normalize_path_ignore_windows_junctions_outside_of_root(self) -> None:
1422         if system() != "Windows":
1423             return
1424
1425         with TemporaryDirectory() as workspace:
1426             root = Path(workspace)
1427             junction_dir = root / "junction"
1428             junction_target_outside_of_root = root / ".."
1429             os.system(f"mklink /J {junction_dir} {junction_target_outside_of_root}")
1430
1431             report = black.Report(verbose=True)
1432             normalized_path = black.normalize_path_maybe_ignore(
1433                 junction_dir, root, report
1434             )
1435             # Manually delete for Python < 3.8
1436             os.system(f"rmdir {junction_dir}")
1437
1438             self.assertEqual(normalized_path, None)
1439
1440     def test_newline_comment_interaction(self) -> None:
1441         source = "class A:\\\r\n# type: ignore\n pass\n"
1442         output = black.format_str(source, mode=DEFAULT_MODE)
1443         black.assert_stable(source, output, mode=DEFAULT_MODE)
1444
1445     def test_bpo_2142_workaround(self) -> None:
1446
1447         # https://bugs.python.org/issue2142
1448
1449         source, _ = read_data("missing_final_newline.py")
1450         # read_data adds a trailing newline
1451         source = source.rstrip()
1452         expected, _ = read_data("missing_final_newline.diff")
1453         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
1454         diff_header = re.compile(
1455             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
1456             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
1457         )
1458         try:
1459             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
1460             self.assertEqual(result.exit_code, 0)
1461         finally:
1462             os.unlink(tmp_file)
1463         actual = result.output
1464         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
1465         self.assertEqual(actual, expected)
1466
1467     @staticmethod
1468     def compare_results(
1469         result: click.testing.Result, expected_value: str, expected_exit_code: int
1470     ) -> None:
1471         """Helper method to test the value and exit code of a click Result."""
1472         assert (
1473             result.output == expected_value
1474         ), "The output did not match the expected value."
1475         assert result.exit_code == expected_exit_code, "The exit code is incorrect."
1476
1477     def test_code_option(self) -> None:
1478         """Test the code option with no changes."""
1479         code = 'print("Hello world")\n'
1480         args = ["--code", code]
1481         result = CliRunner().invoke(black.main, args)
1482
1483         self.compare_results(result, code, 0)
1484
1485     def test_code_option_changed(self) -> None:
1486         """Test the code option when changes are required."""
1487         code = "print('hello world')"
1488         formatted = black.format_str(code, mode=DEFAULT_MODE)
1489
1490         args = ["--code", code]
1491         result = CliRunner().invoke(black.main, args)
1492
1493         self.compare_results(result, formatted, 0)
1494
1495     def test_code_option_check(self) -> None:
1496         """Test the code option when check is passed."""
1497         args = ["--check", "--code", 'print("Hello world")\n']
1498         result = CliRunner().invoke(black.main, args)
1499         self.compare_results(result, "", 0)
1500
1501     def test_code_option_check_changed(self) -> None:
1502         """Test the code option when changes are required, and check is passed."""
1503         args = ["--check", "--code", "print('hello world')"]
1504         result = CliRunner().invoke(black.main, args)
1505         self.compare_results(result, "", 1)
1506
1507     def test_code_option_diff(self) -> None:
1508         """Test the code option when diff is passed."""
1509         code = "print('hello world')"
1510         formatted = black.format_str(code, mode=DEFAULT_MODE)
1511         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1512
1513         args = ["--diff", "--code", code]
1514         result = CliRunner().invoke(black.main, args)
1515
1516         # Remove time from diff
1517         output = DIFF_TIME.sub("", result.output)
1518
1519         assert output == result_diff, "The output did not match the expected value."
1520         assert result.exit_code == 0, "The exit code is incorrect."
1521
1522     def test_code_option_color_diff(self) -> None:
1523         """Test the code option when color and diff are passed."""
1524         code = "print('hello world')"
1525         formatted = black.format_str(code, mode=DEFAULT_MODE)
1526
1527         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1528         result_diff = color_diff(result_diff)
1529
1530         args = ["--diff", "--color", "--code", code]
1531         result = CliRunner().invoke(black.main, args)
1532
1533         # Remove time from diff
1534         output = DIFF_TIME.sub("", result.output)
1535
1536         assert output == result_diff, "The output did not match the expected value."
1537         assert result.exit_code == 0, "The exit code is incorrect."
1538
1539     @pytest.mark.incompatible_with_mypyc
1540     def test_code_option_safe(self) -> None:
1541         """Test that the code option throws an error when the sanity checks fail."""
1542         # Patch black.assert_equivalent to ensure the sanity checks fail
1543         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1544             code = 'print("Hello world")'
1545             error_msg = f"{code}\nerror: cannot format <string>: \n"
1546
1547             args = ["--safe", "--code", code]
1548             result = CliRunner().invoke(black.main, args)
1549
1550             self.compare_results(result, error_msg, 123)
1551
1552     def test_code_option_fast(self) -> None:
1553         """Test that the code option ignores errors when the sanity checks fail."""
1554         # Patch black.assert_equivalent to ensure the sanity checks fail
1555         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1556             code = 'print("Hello world")'
1557             formatted = black.format_str(code, mode=DEFAULT_MODE)
1558
1559             args = ["--fast", "--code", code]
1560             result = CliRunner().invoke(black.main, args)
1561
1562             self.compare_results(result, formatted, 0)
1563
1564     @pytest.mark.incompatible_with_mypyc
1565     def test_code_option_config(self) -> None:
1566         """
1567         Test that the code option finds the pyproject.toml in the current directory.
1568         """
1569         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1570             args = ["--code", "print"]
1571             # This is the only directory known to contain a pyproject.toml
1572             with change_directory(PROJECT_ROOT):
1573                 CliRunner().invoke(black.main, args)
1574                 pyproject_path = Path(Path.cwd(), "pyproject.toml").resolve()
1575
1576             assert (
1577                 len(parse.mock_calls) >= 1
1578             ), "Expected config parse to be called with the current directory."
1579
1580             _, call_args, _ = parse.mock_calls[0]
1581             assert (
1582                 call_args[0].lower() == str(pyproject_path).lower()
1583             ), "Incorrect config loaded."
1584
1585     @pytest.mark.incompatible_with_mypyc
1586     def test_code_option_parent_config(self) -> None:
1587         """
1588         Test that the code option finds the pyproject.toml in the parent directory.
1589         """
1590         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1591             with change_directory(THIS_DIR):
1592                 args = ["--code", "print"]
1593                 CliRunner().invoke(black.main, args)
1594
1595                 pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
1596                 assert (
1597                     len(parse.mock_calls) >= 1
1598                 ), "Expected config parse to be called with the current directory."
1599
1600                 _, call_args, _ = parse.mock_calls[0]
1601                 assert (
1602                     call_args[0].lower() == str(pyproject_path).lower()
1603                 ), "Incorrect config loaded."
1604
1605     def test_for_handled_unexpected_eof_error(self) -> None:
1606         """
1607         Test that an unexpected EOF SyntaxError is nicely presented.
1608         """
1609         with pytest.raises(black.parsing.InvalidInput) as exc_info:
1610             black.lib2to3_parse("print(", {})
1611
1612         exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")
1613
1614     def test_equivalency_ast_parse_failure_includes_error(self) -> None:
1615         with pytest.raises(AssertionError) as err:
1616             black.assert_equivalent("a«»a  = 1", "a«»a  = 1")
1617
1618         err.match("--safe")
1619         # Unfortunately the SyntaxError message has changed in newer versions so we
1620         # can't match it directly.
1621         err.match("invalid character")
1622         err.match(r"\(<unknown>, line 1\)")
1623
1624
class TestCaching:
    """Tests for Black's on-disk cache of already-formatted files."""

    def test_get_cache_dir(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """``BLACK_CACHE_DIR`` takes precedence over ``user_cache_dir``."""
        # Create multiple cache directories
        workspace1 = tmp_path / "ws1"
        workspace1.mkdir()
        workspace2 = tmp_path / "ws2"
        workspace2.mkdir()

        # Force user_cache_dir to use the temporary directory for easier assertions
        patch_user_cache_dir = patch(
            target="black.cache.user_cache_dir",
            autospec=True,
            return_value=str(workspace1),
        )

        # If BLACK_CACHE_DIR is not set, use user_cache_dir
        monkeypatch.delenv("BLACK_CACHE_DIR", raising=False)
        with patch_user_cache_dir:
            assert get_cache_dir() == workspace1

        # If it is set, use the path provided in the env var.
        monkeypatch.setenv("BLACK_CACHE_DIR", str(workspace2))
        assert get_cache_dir() == workspace2

    def test_cache_broken_file(self) -> None:
        """A corrupt cache file reads as empty and is replaced after a run."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            cache_file = get_cache_file(mode)
            # Not pickle data, so read_cache must fall back to an empty cache.
            cache_file.write_text("this is not a pickle")
            assert black.read_cache(mode) == {}
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            invokeBlack([str(src)])
            cache = black.read_cache(mode)
            assert str(src) in cache

    def test_cache_single_file_already_cached(self) -> None:
        """A file already recorded in the cache is not reformatted."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            black.write_cache({}, [src], mode)
            invokeBlack([str(src)])
            # Single quotes survive because the cached file was skipped.
            assert src.read_text() == "print('hello')"

    @event_loop()
    def test_cache_multiple_files(self) -> None:
        """Only uncached files are reformatted; afterwards all are cached."""
        mode = DEFAULT_MODE
        # Run the worker pool with threads instead of processes.
        with cache_dir() as workspace, patch(
            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            one = (workspace / "one.py").resolve()
            with one.open("w") as fobj:
                fobj.write("print('hello')")
            two = (workspace / "two.py").resolve()
            with two.open("w") as fobj:
                fobj.write("print('hello')")
            black.write_cache({}, [one], mode)
            invokeBlack([str(workspace)])
            with one.open("r") as fobj:
                assert fobj.read() == "print('hello')"
            with two.open("r") as fobj:
                assert fobj.read() == 'print("hello")\n'
            cache = black.read_cache(mode)
            assert str(one) in cache
            assert str(two) in cache

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    def test_no_cache_when_writeback_diff(self, color: bool) -> None:
        """``--diff`` (with or without ``--color``) neither reads nor writes cache."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            with src.open("w") as fobj:
                fobj.write("print('hello')")
            with patch("black.read_cache") as read_cache, patch(
                "black.write_cache"
            ) as write_cache:
                cmd = [str(src), "--diff"]
                if color:
                    cmd.append("--color")
                invokeBlack(cmd)
                cache_file = get_cache_file(mode)
                assert not cache_file.exists()
                write_cache.assert_not_called()
                read_cache.assert_not_called()

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    @event_loop()
    def test_output_locking_when_writeback_diff(self, color: bool) -> None:
        """``--diff`` over several files creates a multiprocessing Manager."""
        with cache_dir() as workspace:
            for tag in range(4):
                src = (workspace / f"test{tag}.py").resolve()
                with src.open("w") as fobj:
                    fobj.write("print('hello')")
            with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
                cmd = ["--diff", str(workspace)]
                if color:
                    cmd.append("--color")
                invokeBlack(cmd, exit_code=0)
                # this isn't quite doing what we want, but if it _isn't_
                # called then we cannot be using the lock it provides
                mgr.assert_called()

    def test_no_cache_when_stdin(self) -> None:
        """Formatting from stdin does not create a cache file."""
        mode = DEFAULT_MODE
        with cache_dir():
            result = CliRunner().invoke(
                black.main, ["-"], input=BytesIO(b"print('hello')")
            )
            assert not result.exit_code
            cache_file = get_cache_file(mode)
            assert not cache_file.exists()

    def test_read_cache_no_cachefile(self) -> None:
        """Reading a cache that was never written yields an empty dict."""
        mode = DEFAULT_MODE
        with cache_dir():
            assert black.read_cache(mode) == {}

    def test_write_cache_read_cache(self) -> None:
        """A written entry is read back with the file's ``get_cache_info``."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.touch()
            black.write_cache({}, [src], mode)
            cache = black.read_cache(mode)
            assert str(src) in cache
            assert cache[str(src)] == black.get_cache_info(src)

    def test_filter_cached(self) -> None:
        """Uncached or stale files are todo; files matching the cache are done."""
        with TemporaryDirectory() as workspace:
            path = Path(workspace)
            uncached = (path / "uncached").resolve()
            cached = (path / "cached").resolve()
            cached_but_changed = (path / "changed").resolve()
            uncached.touch()
            cached.touch()
            cached_but_changed.touch()
            cache = {
                str(cached): black.get_cache_info(cached),
                # A stale entry that does not match the file's current info.
                str(cached_but_changed): (0.0, 0),
            }
            todo, done = black.filter_cached(
                cache, {uncached, cached, cached_but_changed}
            )
            assert todo == {uncached, cached_but_changed}
            assert done == {cached}

    def test_write_cache_creates_directory_if_needed(self) -> None:
        """``write_cache`` creates a missing cache directory."""
        mode = DEFAULT_MODE
        with cache_dir(exists=False) as workspace:
            assert not workspace.exists()
            black.write_cache({}, [], mode)
            assert workspace.exists()

    @event_loop()
    def test_failed_formatting_does_not_get_cached(self) -> None:
        """Files that fail to format are left out of the cache."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace, patch(
            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            failing = (workspace / "failing.py").resolve()
            with failing.open("w") as fobj:
                fobj.write("not actually python")
            clean = (workspace / "clean.py").resolve()
            with clean.open("w") as fobj:
                fobj.write('print("hello")\n')
            invokeBlack([str(workspace)], exit_code=123)
            cache = black.read_cache(mode)
            assert str(failing) not in cache
            assert str(clean) in cache

    def test_write_cache_write_fail(self) -> None:
        """An OSError while writing the cache must not propagate.

        The test passes as long as ``write_cache`` returns without raising.
        NOTE(review): this only exercises the failure path if ``write_cache``
        opens its output through ``Path.open`` -- confirm against the
        implementation.
        """
        mode = DEFAULT_MODE
        with cache_dir(), patch.object(Path, "open") as mock:
            mock.side_effect = OSError
            black.write_cache({}, [], mode)

    def test_read_cache_line_lengths(self) -> None:
        """Entries cached under one mode are not visible under another mode."""
        mode = DEFAULT_MODE
        short_mode = replace(DEFAULT_MODE, line_length=1)
        with cache_dir() as workspace:
            path = (workspace / "file.py").resolve()
            path.touch()
            black.write_cache({}, [path], mode)
            one = black.read_cache(mode)
            assert str(path) in one
            two = black.read_cache(short_mode)
            assert str(path) not in two
1817
1818
def assert_collected_sources(
    src: Sequence[Union[str, Path]],
    expected: Sequence[Union[str, Path]],
    *,
    ctx: Optional[FakeContext] = None,
    exclude: Optional[str] = None,
    include: Optional[str] = None,
    extend_exclude: Optional[str] = None,
    force_exclude: Optional[str] = None,
    stdin_filename: Optional[str] = None,
) -> None:
    """Assert that ``black.get_sources`` collects exactly ``expected`` from ``src``.

    Pattern arguments are regex strings compiled with ``compile_pattern``.
    ``None`` leaves ``exclude``/``extend_exclude``/``force_exclude`` unset and
    makes ``include`` fall back to ``DEFAULT_INCLUDE``. A fresh ``FakeContext``
    is used when ``ctx`` is not given. Ordering of results is ignored.
    """
    gs_src = tuple(str(Path(s)) for s in src)
    gs_expected = [Path(s) for s in expected]
    gs_exclude = None if exclude is None else compile_pattern(exclude)
    gs_include = DEFAULT_INCLUDE if include is None else compile_pattern(include)
    gs_extend_exclude = (
        None if extend_exclude is None else compile_pattern(extend_exclude)
    )
    gs_force_exclude = None if force_exclude is None else compile_pattern(force_exclude)
    collected = black.get_sources(
        # Explicit None check: only substitute a context when none was passed.
        ctx=FakeContext() if ctx is None else ctx,
        src=gs_src,
        quiet=False,
        verbose=False,
        include=gs_include,
        exclude=gs_exclude,
        extend_exclude=gs_extend_exclude,
        force_exclude=gs_force_exclude,
        report=black.Report(),
        stdin_filename=stdin_filename,
    )
    assert sorted(collected) == sorted(gs_expected)
1851
1852
class TestFileCollection:
    """Tests for how Black discovers the files it should format."""

    def test_include_exclude(self) -> None:
        """``include`` and ``exclude`` regexes filter discovered files."""
        path = THIS_DIR / "data" / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        assert_collected_sources(
            src,
            expected,
            include=r"\.pyi?$",
            exclude=r"/exclude/|/\.definitely_exclude/",
        )

    def test_gitignore_used_as_default(self) -> None:
        """Without an explicit ``exclude``, the root's .gitignore filters files."""
        base = DATA_DIR / "include_exclude_tests"
        expected = [
            base / "b/.definitely_exclude/a.py",
            base / "b/.definitely_exclude/a.pyi",
        ]
        src = [base / "b/"]
        ctx = FakeContext()
        ctx.obj["root"] = base
        assert_collected_sources(src, expected, ctx=ctx, extend_exclude=r"/exclude/")

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_exclude_for_issue_1572(self) -> None:
        # Exclude shouldn't touch files that were explicitly given to Black through the
        # CLI. Exclude is supposed to only apply to the recursive discovery of files.
        # https://github.com/psf/black/issues/1572
        path = DATA_DIR / "include_exclude_tests"
        src = [path / "b/exclude/a.py"]
        expected = [path / "b/exclude/a.py"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    def test_gitignore_exclude(self) -> None:
        """A PathSpec passed to ``gen_python_files`` excludes matching entries."""
        path = THIS_DIR / "data" / "include_exclude_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        report = black.Report()
        gitignore = PathSpec.from_lines(
            "gitwildmatch", ["exclude/", ".definitely_exclude"]
        )
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        this_abs = THIS_DIR.resolve()
        sources = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                gitignore,
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_nested_gitignore(self) -> None:
        """.gitignore files in subdirectories are honoured during discovery."""
        path = THIS_DIR / "data" / "nested_gitignore_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        root_gitignore = black.files.get_gitignore(path)
        report = black.Report()
        expected: List[Path] = [
            path / "x.py",
            path / "root/b.py",
            path / "root/c.py",
            path / "root/child/c.py",
        ]
        this_abs = THIS_DIR.resolve()
        sources = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                root_gitignore,
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_invalid_gitignore(self) -> None:
        """An unparsable root .gitignore is reported and exits with code 1."""
        path = THIS_DIR / "data" / "invalid_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_invalid_nested_gitignore(self) -> None:
        """An unparsable nested .gitignore is reported and exits with code 1."""
        path = THIS_DIR / "data" / "invalid_nested_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / "a" / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_empty_include(self) -> None:
        """An empty ``include`` pattern collects every file."""
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.pie",
            path / "b/exclude/a.py",
            path / "b/exclude/a.pyi",
            path / "b/dont_exclude/a.pie",
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
            path / "b/.definitely_exclude/a.pie",
            path / "b/.definitely_exclude/a.py",
            path / "b/.definitely_exclude/a.pyi",
            path / ".gitignore",
            path / "pyproject.toml",
        ]
        # Setting exclude explicitly to an empty string to block .gitignore usage.
        assert_collected_sources(src, expected, include="", exclude="")

    def test_extend_exclude(self) -> None:
        """``extend_exclude`` applies on top of ``exclude``."""
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.py",
            path / "b/dont_exclude/a.py",
        ]
        assert_collected_sources(
            src, expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude"
        )

    @pytest.mark.incompatible_with_mypyc
    def test_symlink_out_of_root_directory(self) -> None:
        """A child resolving outside ``root`` must not make discovery raise."""
        path = MagicMock()
        root = THIS_DIR.resolve()
        child = MagicMock()
        include = re.compile(black.DEFAULT_INCLUDES)
        exclude = re.compile(black.DEFAULT_EXCLUDES)
        report = black.Report()
        gitignore = PathSpec.from_lines("gitwildmatch", [])
        # `child` should behave like a symlink whose resolved path is clearly
        # outside of the `root` directory.
        path.iterdir.return_value = [child]
        child.resolve.return_value = Path("/a/b/c")
        child.as_posix.return_value = "/a/b/c"
        try:
            list(
                black.gen_python_files(
                    path.iterdir(),
                    root,
                    include,
                    exclude,
                    None,
                    None,
                    report,
                    gitignore,
                    verbose=False,
                    quiet=False,
                )
            )
        except ValueError as ve:
            pytest.fail(f"`gen_python_files()` failed: {ve}")
        path.iterdir.assert_called_once()
        child.resolve.assert_called_once()

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin(self) -> None:
        """``-`` is always collected as-is."""
        src = ["-"]
        expected = ["-"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename(self) -> None:
        """``-`` with ``stdin_filename`` is collected under the marker prefix."""
        src = ["-"]
        stdin_filename = str(THIS_DIR / "data/collections.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
        # Exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        path = DATA_DIR / "include_exclude_tests"
        src = ["-"]
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
        # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        src = ["-"]
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            extend_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
        # Force exclude should exclude the file when passing it through
        # stdin_filename
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        assert_collected_sources(
            src=["-"],
            expected=[],
            force_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )
2098
2099
# Source lines of black/__init__.py, used by ``tracefunc`` to display the
# signature of each called function. When Black is compiled, reading
# ``black.__file__`` as UTF-8 can fail (presumably because it is a binary
# extension module); fall back to an empty list so the name is always defined.
try:
    with open(black.__file__, "r", encoding="utf-8") as _bf:
        black_source_lines = _bf.readlines()
except UnicodeDecodeError:
    if not black.COMPILED:
        raise
    black_source_lines = []
2106
2107
def tracefunc(
    frame: types.FrameType, event: str, arg: Any
) -> Callable[[types.FrameType, str, Any], Any]:
    """Show function calls `from black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    Only "call" events whose code lives in black/__init__.py are printed, as
    ``<indent><lineno>:<source line>``, with decorator lines skipped so the
    ``def`` line is shown. Always returns itself so tracing continues.
    """
    if event != "call":
        return tracefunc

    # Normalize separators so the check also matches Windows paths.
    filename = Path(frame.f_code.co_filename).as_posix()
    if "black/__init__.py" not in filename:
        # Bail out before using this frame's line number to index
        # black_source_lines, which only holds Black's own source.
        return tracefunc

    # Indent by call depth; the offset of 19 presumably discounts frames
    # belonging to the test harness -- adjust if the output is misaligned.
    stack = (len(inspect.stack()) - 19) * 2
    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1
    # Skip decorator lines to reach the ``def`` line; stop at end of file.
    funcname = ""
    for line in black_source_lines[func_sig_lineno:]:
        funcname = line.strip()
        if not funcname.startswith("@"):
            break
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc