git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

281019a0bfa6d90252b42d630e4b12d429f18109
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import inspect
5 import io
6 import logging
7 import multiprocessing
8 import os
9 import sys
10 import types
11 import unittest
12 from concurrent.futures import ThreadPoolExecutor
13 from contextlib import contextmanager, redirect_stderr
14 from dataclasses import replace
15 from io import BytesIO
16 from pathlib import Path
17 from platform import system
18 from tempfile import TemporaryDirectory
19 from typing import (
20     Any,
21     Callable,
22     Dict,
23     Iterator,
24     List,
25     Optional,
26     Sequence,
27     TypeVar,
28     Union,
29 )
30 from unittest.mock import MagicMock, patch
31
32 import click
33 import pytest
34 import re
35 from click import unstyle
36 from click.testing import CliRunner
37 from pathspec import PathSpec
38
39 import black
40 import black.files
41 from black import Feature, TargetVersion
42 from black import re_compile_maybe_verbose as compile_pattern
43 from black.cache import get_cache_dir, get_cache_file
44 from black.debug import DebugVisitor
45 from black.output import color_diff, diff
46 from black.report import Report
47
48 # Import other test classes
49 from tests.util import (
50     DATA_DIR,
51     DEFAULT_MODE,
52     DETERMINISTIC_HEADER,
53     PROJECT_ROOT,
54     PY36_VERSIONS,
55     THIS_DIR,
56     BlackBaseTestCase,
57     assert_format,
58     change_directory,
59     dump_to_stderr,
60     ff,
61     fs,
62     read_data,
63 )
64
THIS_FILE = Path(__file__)
# Config file handed to CLI tests via ``--config`` (empty, judging by its name).
EMPTY_CONFIG = THIS_DIR.joinpath("data", "empty_pyproject.toml")
# One ``--target-version=<name>`` flag for each version in PY36_VERSIONS.
PY36_ARGS = [
    "--target-version=" + target.name.lower() for target in PY36_VERSIONS
]
# Black's default include/exclude patterns, compiled once for reuse.
DEFAULT_EXCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_EXCLUDES)
DEFAULT_INCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_INCLUDES)
T = TypeVar("T")
R = TypeVar("R")

# Match the time output in a diff, but nothing else
DIFF_TIME = re.compile(r"\t[\d\-:+\. ]+")
75
76
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch ``black.cache.CACHE_DIR`` to a temporary location and yield it.

    With *exists* true, the yielded path is the temporary workspace itself.
    Otherwise it is a ``new`` child of that workspace, which is never created
    here. The workspace is deleted when the context exits.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
85
86
@contextmanager
def event_loop() -> Iterator[None]:
    """Run the enclosed code with a fresh asyncio loop installed as current.

    On entry the active policy creates a new loop, which becomes the current
    event loop. On exit the loop is closed and the current-loop slot is cleared.
    Without that last step, later ``get_event_loop()`` calls would receive a
    loop that is already closed.
    """
    policy = asyncio.get_event_loop_policy()
    loop = policy.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        yield
    finally:
        try:
            loop.close()
        finally:
            # Unregister the (now closed) loop even if close() raised.
            asyncio.set_event_loop(None)
97
98
class FakeContext(click.Context):
    """Lightweight ``click.Context`` substitute for helpers that expect one."""

    def __init__(self) -> None:
        # click.Context.__init__ is not called; only these two attributes
        # are populated.
        self.default_map: Dict[str, Any] = {}
        # Placeholder root: most tests never look at it.
        root_info: Dict[str, Any] = {"root": PROJECT_ROOT}
        self.obj: Dict[str, Any] = root_info
106
107
class FakeParameter(click.Parameter):
    """Placeholder ``click.Parameter`` for helpers that demand one."""

    def __init__(self) -> None:
        """Skip ``click.Parameter.__init__``; the instance carries no state."""
113
114
class BlackRunner(CliRunner):
    """CliRunner whose results keep Black's STDOUT and STDERR apart."""

    def __init__(self) -> None:
        # Without mix_stderr=False, stderr would be folded into the output.
        # NOTE(review): newer click releases dropped ``mix_stderr`` — confirm
        # against the pinned click version.
        separate_streams = {"mix_stderr": False}
        super().__init__(**separate_streams)
120
121
def invokeBlack(
    args: List[str], exit_code: int = 0, ignore_config: bool = True
) -> None:
    """Invoke ``black.main`` in-process and assert it exits with *exit_code*.

    If *ignore_config* is true, ``--verbose --config <THIS_DIR>/empty.toml`` is
    prepended to *args* (presumably so project settings cannot affect the run).
    Exceptions from Black propagate rather than being captured by the runner.
    On a mismatch, the assertion message shows the arguments, both output
    streams, and any recorded exception.
    """
    if ignore_config:
        args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
    outcome = BlackRunner().invoke(black.main, args, catch_exceptions=False)
    out_bytes = outcome.stdout_bytes
    err_bytes = outcome.stderr_bytes
    assert out_bytes is not None
    assert err_bytes is not None
    details = "\n".join(
        [
            f"Failed with args: {args}",
            f"stdout: {out_bytes.decode()!r}",
            f"stderr: {err_bytes.decode()!r}",
            f"exception: {outcome.exception}",
        ]
    )
    assert outcome.exit_code == exit_code, details
138
139
140 class BlackTestCase(BlackBaseTestCase):
141     invokeBlack = staticmethod(invokeBlack)
142
143     def test_empty_ff(self) -> None:
144         expected = ""
145         tmp_file = Path(black.dump_to_file())
146         try:
147             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
148             with open(tmp_file, encoding="utf8") as f:
149                 actual = f.read()
150         finally:
151             os.unlink(tmp_file)
152         self.assertFormatEqual(expected, actual)
153
154     def test_experimental_string_processing_warns(self) -> None:
155         self.assertWarns(
156             black.mode.Deprecated, black.Mode, experimental_string_processing=True
157         )
158
159     def test_piping(self) -> None:
160         source, expected = read_data("src/black/__init__", data=False)
161         result = BlackRunner().invoke(
162             black.main,
163             [
164                 "-",
165                 "--fast",
166                 f"--line-length={black.DEFAULT_LINE_LENGTH}",
167                 f"--config={EMPTY_CONFIG}",
168             ],
169             input=BytesIO(source.encode("utf8")),
170         )
171         self.assertEqual(result.exit_code, 0)
172         self.assertFormatEqual(expected, result.output)
173         if source != result.output:
174             black.assert_equivalent(source, result.output)
175             black.assert_stable(source, result.output, DEFAULT_MODE)
176
177     def test_piping_diff(self) -> None:
178         diff_header = re.compile(
179             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
180             r"\+\d\d\d\d"
181         )
182         source, _ = read_data("simple_cases/expression.py")
183         expected, _ = read_data("simple_cases/expression.diff")
184         args = [
185             "-",
186             "--fast",
187             f"--line-length={black.DEFAULT_LINE_LENGTH}",
188             "--diff",
189             f"--config={EMPTY_CONFIG}",
190         ]
191         result = BlackRunner().invoke(
192             black.main, args, input=BytesIO(source.encode("utf8"))
193         )
194         self.assertEqual(result.exit_code, 0)
195         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
196         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
197         self.assertEqual(expected, actual)
198
199     def test_piping_diff_with_color(self) -> None:
200         source, _ = read_data("simple_cases/expression.py")
201         args = [
202             "-",
203             "--fast",
204             f"--line-length={black.DEFAULT_LINE_LENGTH}",
205             "--diff",
206             "--color",
207             f"--config={EMPTY_CONFIG}",
208         ]
209         result = BlackRunner().invoke(
210             black.main, args, input=BytesIO(source.encode("utf8"))
211         )
212         actual = result.output
213         # Again, the contents are checked in a different test, so only look for colors.
214         self.assertIn("\033[1m", actual)
215         self.assertIn("\033[36m", actual)
216         self.assertIn("\033[32m", actual)
217         self.assertIn("\033[31m", actual)
218         self.assertIn("\033[0m", actual)
219
220     @patch("black.dump_to_file", dump_to_stderr)
221     def _test_wip(self) -> None:
222         source, expected = read_data("wip")
223         sys.settrace(tracefunc)
224         mode = replace(
225             DEFAULT_MODE,
226             experimental_string_processing=False,
227             target_versions={black.TargetVersion.PY38},
228         )
229         actual = fs(source, mode=mode)
230         sys.settrace(None)
231         self.assertFormatEqual(expected, actual)
232         black.assert_equivalent(source, actual)
233         black.assert_stable(source, actual, black.FileMode())
234
235     def test_pep_572_version_detection(self) -> None:
236         source, _ = read_data("pep_572")
237         root = black.lib2to3_parse(source)
238         features = black.get_features_used(root)
239         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
240         versions = black.detect_target_versions(root)
241         self.assertIn(black.TargetVersion.PY38, versions)
242
243     def test_expression_ff(self) -> None:
244         source, expected = read_data("simple_cases/expression.py")
245         tmp_file = Path(black.dump_to_file(source))
246         try:
247             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
248             with open(tmp_file, encoding="utf8") as f:
249                 actual = f.read()
250         finally:
251             os.unlink(tmp_file)
252         self.assertFormatEqual(expected, actual)
253         with patch("black.dump_to_file", dump_to_stderr):
254             black.assert_equivalent(source, actual)
255             black.assert_stable(source, actual, DEFAULT_MODE)
256
257     def test_expression_diff(self) -> None:
258         source, _ = read_data("simple_cases/expression.py")
259         expected, _ = read_data("simple_cases/expression.diff")
260         tmp_file = Path(black.dump_to_file(source))
261         diff_header = re.compile(
262             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
263             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
264         )
265         try:
266             result = BlackRunner().invoke(
267                 black.main, ["--diff", str(tmp_file), f"--config={EMPTY_CONFIG}"]
268             )
269             self.assertEqual(result.exit_code, 0)
270         finally:
271             os.unlink(tmp_file)
272         actual = result.output
273         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
274         if expected != actual:
275             dump = black.dump_to_file(actual)
276             msg = (
277                 "Expected diff isn't equal to the actual. If you made changes to"
278                 " expression.py and this is an anticipated difference, overwrite"
279                 f" tests/data/expression.diff with {dump}"
280             )
281             self.assertEqual(expected, actual, msg)
282
283     def test_expression_diff_with_color(self) -> None:
284         source, _ = read_data("simple_cases/expression.py")
285         expected, _ = read_data("simple_cases/expression.diff")
286         tmp_file = Path(black.dump_to_file(source))
287         try:
288             result = BlackRunner().invoke(
289                 black.main,
290                 ["--diff", "--color", str(tmp_file), f"--config={EMPTY_CONFIG}"],
291             )
292         finally:
293             os.unlink(tmp_file)
294         actual = result.output
295         # We check the contents of the diff in `test_expression_diff`. All
296         # we need to check here is that color codes exist in the result.
297         self.assertIn("\033[1m", actual)
298         self.assertIn("\033[36m", actual)
299         self.assertIn("\033[32m", actual)
300         self.assertIn("\033[31m", actual)
301         self.assertIn("\033[0m", actual)
302
303     def test_detect_pos_only_arguments(self) -> None:
304         source, _ = read_data("pep_570")
305         root = black.lib2to3_parse(source)
306         features = black.get_features_used(root)
307         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
308         versions = black.detect_target_versions(root)
309         self.assertIn(black.TargetVersion.PY38, versions)
310
311     @patch("black.dump_to_file", dump_to_stderr)
312     def test_string_quotes(self) -> None:
313         source, expected = read_data("string_quotes")
314         mode = black.Mode(preview=True)
315         assert_format(source, expected, mode)
316         mode = replace(mode, string_normalization=False)
317         not_normalized = fs(source, mode=mode)
318         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
319         black.assert_equivalent(source, not_normalized)
320         black.assert_stable(source, not_normalized, mode=mode)
321
322     def test_skip_magic_trailing_comma(self) -> None:
323         source, _ = read_data("simple_cases/expression.py")
324         expected, _ = read_data("expression_skip_magic_trailing_comma.diff")
325         tmp_file = Path(black.dump_to_file(source))
326         diff_header = re.compile(
327             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
328             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
329         )
330         try:
331             result = BlackRunner().invoke(
332                 black.main, ["-C", "--diff", str(tmp_file), f"--config={EMPTY_CONFIG}"]
333             )
334             self.assertEqual(result.exit_code, 0)
335         finally:
336             os.unlink(tmp_file)
337         actual = result.output
338         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
339         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
340         if expected != actual:
341             dump = black.dump_to_file(actual)
342             msg = (
343                 "Expected diff isn't equal to the actual. If you made changes to"
344                 " expression.py and this is an anticipated difference, overwrite"
345                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
346             )
347             self.assertEqual(expected, actual, msg)
348
349     @patch("black.dump_to_file", dump_to_stderr)
350     def test_async_as_identifier(self) -> None:
351         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
352         source, expected = read_data("async_as_identifier")
353         actual = fs(source)
354         self.assertFormatEqual(expected, actual)
355         major, minor = sys.version_info[:2]
356         if major < 3 or (major <= 3 and minor < 7):
357             black.assert_equivalent(source, actual)
358         black.assert_stable(source, actual, DEFAULT_MODE)
359         # ensure black can parse this when the target is 3.6
360         self.invokeBlack([str(source_path), "--target-version", "py36"])
361         # but not on 3.7, because async/await is no longer an identifier
362         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
363
364     @patch("black.dump_to_file", dump_to_stderr)
365     def test_python37(self) -> None:
366         source_path = (THIS_DIR / "data" / "python37.py").resolve()
367         source, expected = read_data("python37")
368         actual = fs(source)
369         self.assertFormatEqual(expected, actual)
370         major, minor = sys.version_info[:2]
371         if major > 3 or (major == 3 and minor >= 7):
372             black.assert_equivalent(source, actual)
373         black.assert_stable(source, actual, DEFAULT_MODE)
374         # ensure black can parse this when the target is 3.7
375         self.invokeBlack([str(source_path), "--target-version", "py37"])
376         # but not on 3.6, because we use async as a reserved keyword
377         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
378
379     def test_tab_comment_indentation(self) -> None:
380         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
381         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
382         self.assertFormatEqual(contents_spc, fs(contents_spc))
383         self.assertFormatEqual(contents_spc, fs(contents_tab))
384
385         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
386         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
387         self.assertFormatEqual(contents_spc, fs(contents_spc))
388         self.assertFormatEqual(contents_spc, fs(contents_tab))
389
390         # mixed tabs and spaces (valid Python 2 code)
391         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
392         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
393         self.assertFormatEqual(contents_spc, fs(contents_spc))
394         self.assertFormatEqual(contents_spc, fs(contents_tab))
395
396         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
397         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
398         self.assertFormatEqual(contents_spc, fs(contents_spc))
399         self.assertFormatEqual(contents_spc, fs(contents_tab))
400
    def test_report_verbose(self) -> None:
        """Check Report in verbose mode, where every processed path is echoed.

        ``black.output._out`` and ``black.output._err`` are patched with
        collectors. Each step can then assert which stream got a message,
        what the latest message was, the summary text (``str(report)``) and
        the resulting ``return_code``.
        """
        report = Report(verbose=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        # Collectors standing in for the real writers; extra keyword
        # arguments are accepted and ignored.
        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            # In verbose mode even unchanged files are echoed on out.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # Cached files are echoed and counted as unchanged.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled, a reformatted file makes the return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures go to err and force return code 123 from here on.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Reporting f3 again is counted again (now as reformatted).
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are echoed in verbose mode but not counted.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # In check mode, and likewise in diff mode, the summary uses
            # "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
502
    def test_report_quiet(self) -> None:
        """Check Report in quiet mode: only failures are written, to err.

        ``black.output._out`` and ``black.output._err`` are patched with
        collectors. The tests show that out never receives anything, while the
        summary text and ``return_code`` are tallied exactly as in other modes.
        """
        report = Report(quiet=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        # Collectors standing in for the real writers; extra keyword
        # arguments are accepted and ignored.
        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # Even reformatted files are not echoed in quiet mode.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled, a reformatted file makes the return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures are still written to err and force return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are neither echoed nor counted.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # In check mode, and likewise in diff mode, the summary uses
            # "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
596
    def test_report_normal(self) -> None:
        """Check Report with default verbosity.

        Only reformatted files are echoed (to out) and failures are written
        to err. Unchanged, cached and ignored paths produce no output but are
        still reflected in the summary text and ``return_code``.
        """
        report = black.Report()
        out_lines: List[str] = []
        err_lines: List[str] = []

        # Collectors standing in for the real writers; extra keyword
        # arguments are accepted and ignored.
        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # Cached files are silent; the last out message is still f2's.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled, a reformatted file makes the return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures go to err and force return code 123 from here on.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are neither echoed nor counted.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # In check mode, and likewise in diff mode, the summary uses
            # "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
693
694     def test_lib2to3_parse(self) -> None:
695         with self.assertRaises(black.InvalidInput):
696             black.lib2to3_parse("invalid syntax")
697
698         straddling = "x + y"
699         black.lib2to3_parse(straddling)
700         black.lib2to3_parse(straddling, {TargetVersion.PY36})
701
702         py2_only = "print x"
703         with self.assertRaises(black.InvalidInput):
704             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
705
706         py3_only = "exec(x, end=y)"
707         black.lib2to3_parse(py3_only)
708         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
709
710     def test_get_features_used_decorator(self) -> None:
711         # Test the feature detection of new decorator syntax
712         # since this makes some test cases of test_get_features_used()
713         # fails if it fails, this is tested first so that a useful case
714         # is identified
715         simples, relaxed = read_data("decorators")
716         # skip explanation comments at the top of the file
717         for simple_test in simples.split("##")[1:]:
718             node = black.lib2to3_parse(simple_test)
719             decorator = str(node.children[0].children[0]).strip()
720             self.assertNotIn(
721                 Feature.RELAXED_DECORATORS,
722                 black.get_features_used(node),
723                 msg=(
724                     f"decorator '{decorator}' follows python<=3.8 syntax"
725                     "but is detected as 3.9+"
726                     # f"The full node is\n{node!r}"
727                 ),
728             )
729         # skip the '# output' comment at the top of the output part
730         for relaxed_test in relaxed.split("##")[1:]:
731             node = black.lib2to3_parse(relaxed_test)
732             decorator = str(node.children[0].children[0]).strip()
733             self.assertIn(
734                 Feature.RELAXED_DECORATORS,
735                 black.get_features_used(node),
736                 msg=(
737                     f"decorator '{decorator}' uses python3.9+ syntax"
738                     "but is detected as python<=3.8"
739                     # f"The full node is\n{node!r}"
740                 ),
741             )
742
743     def test_get_features_used(self) -> None:
744         node = black.lib2to3_parse("def f(*, arg): ...\n")
745         self.assertEqual(black.get_features_used(node), set())
746         node = black.lib2to3_parse("def f(*, arg,): ...\n")
747         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
748         node = black.lib2to3_parse("f(*arg,)\n")
749         self.assertEqual(
750             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
751         )
752         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
753         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
754         node = black.lib2to3_parse("123_456\n")
755         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
756         node = black.lib2to3_parse("123456\n")
757         self.assertEqual(black.get_features_used(node), set())
758         source, expected = read_data("simple_cases/function.py")
759         node = black.lib2to3_parse(source)
760         expected_features = {
761             Feature.TRAILING_COMMA_IN_CALL,
762             Feature.TRAILING_COMMA_IN_DEF,
763             Feature.F_STRINGS,
764         }
765         self.assertEqual(black.get_features_used(node), expected_features)
766         node = black.lib2to3_parse(expected)
767         self.assertEqual(black.get_features_used(node), expected_features)
768         source, expected = read_data("simple_cases/expression.py")
769         node = black.lib2to3_parse(source)
770         self.assertEqual(black.get_features_used(node), set())
771         node = black.lib2to3_parse(expected)
772         self.assertEqual(black.get_features_used(node), set())
773         node = black.lib2to3_parse("lambda a, /, b: ...")
774         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
775         node = black.lib2to3_parse("def fn(a, /, b): ...")
776         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
777         node = black.lib2to3_parse("def fn(): yield a, b")
778         self.assertEqual(black.get_features_used(node), set())
779         node = black.lib2to3_parse("def fn(): return a, b")
780         self.assertEqual(black.get_features_used(node), set())
781         node = black.lib2to3_parse("def fn(): yield *b, c")
782         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
783         node = black.lib2to3_parse("def fn(): return a, *b, c")
784         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
785         node = black.lib2to3_parse("x = a, *b, c")
786         self.assertEqual(black.get_features_used(node), set())
787         node = black.lib2to3_parse("x: Any = regular")
788         self.assertEqual(black.get_features_used(node), set())
789         node = black.lib2to3_parse("x: Any = (regular, regular)")
790         self.assertEqual(black.get_features_used(node), set())
791         node = black.lib2to3_parse("x: Any = Complex(Type(1))[something]")
792         self.assertEqual(black.get_features_used(node), set())
793         node = black.lib2to3_parse("x: Tuple[int, ...] = a, b, c")
794         self.assertEqual(
795             black.get_features_used(node), {Feature.ANN_ASSIGN_EXTENDED_RHS}
796         )
797         node = black.lib2to3_parse("try: pass\nexcept Something: pass")
798         self.assertEqual(black.get_features_used(node), set())
799         node = black.lib2to3_parse("try: pass\nexcept (*Something,): pass")
800         self.assertEqual(black.get_features_used(node), set())
801         node = black.lib2to3_parse("try: pass\nexcept *Group: pass")
802         self.assertEqual(black.get_features_used(node), {Feature.EXCEPT_STAR})
803
804     def test_get_features_used_for_future_flags(self) -> None:
805         for src, features in [
806             ("from __future__ import annotations", {Feature.FUTURE_ANNOTATIONS}),
807             (
808                 "from __future__ import (other, annotations)",
809                 {Feature.FUTURE_ANNOTATIONS},
810             ),
811             ("a = 1 + 2\nfrom something import annotations", set()),
812             ("from __future__ import x, y", set()),
813         ]:
814             with self.subTest(src=src, features=features):
815                 node = black.lib2to3_parse(src)
816                 future_imports = black.get_future_imports(node)
817                 self.assertEqual(
818                     black.get_features_used(node, future_imports=future_imports),
819                     features,
820                 )
821
822     def test_get_future_imports(self) -> None:
823         node = black.lib2to3_parse("\n")
824         self.assertEqual(set(), black.get_future_imports(node))
825         node = black.lib2to3_parse("from __future__ import black\n")
826         self.assertEqual({"black"}, black.get_future_imports(node))
827         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
828         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
829         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
830         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
831         node = black.lib2to3_parse(
832             "from __future__ import multiple\nfrom __future__ import imports\n"
833         )
834         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
835         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
836         self.assertEqual({"black"}, black.get_future_imports(node))
837         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
838         self.assertEqual({"black"}, black.get_future_imports(node))
839         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
840         self.assertEqual(set(), black.get_future_imports(node))
841         node = black.lib2to3_parse("from some.module import black\n")
842         self.assertEqual(set(), black.get_future_imports(node))
843         node = black.lib2to3_parse(
844             "from __future__ import unicode_literals as _unicode_literals"
845         )
846         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
847         node = black.lib2to3_parse(
848             "from __future__ import unicode_literals as _lol, print"
849         )
850         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
851
852     @pytest.mark.incompatible_with_mypyc
853     def test_debug_visitor(self) -> None:
854         source, _ = read_data("debug_visitor.py")
855         expected, _ = read_data("debug_visitor.out")
856         out_lines = []
857         err_lines = []
858
859         def out(msg: str, **kwargs: Any) -> None:
860             out_lines.append(msg)
861
862         def err(msg: str, **kwargs: Any) -> None:
863             err_lines.append(msg)
864
865         with patch("black.debug.out", out):
866             DebugVisitor.show(source)
867         actual = "\n".join(out_lines) + "\n"
868         log_name = ""
869         if expected != actual:
870             log_name = black.dump_to_file(*out_lines)
871         self.assertEqual(
872             expected,
873             actual,
874             f"AST print out is different. Actual version dumped to {log_name}",
875         )
876
877     def test_format_file_contents(self) -> None:
878         empty = ""
879         mode = DEFAULT_MODE
880         with self.assertRaises(black.NothingChanged):
881             black.format_file_contents(empty, mode=mode, fast=False)
882         just_nl = "\n"
883         with self.assertRaises(black.NothingChanged):
884             black.format_file_contents(just_nl, mode=mode, fast=False)
885         same = "j = [1, 2, 3]\n"
886         with self.assertRaises(black.NothingChanged):
887             black.format_file_contents(same, mode=mode, fast=False)
888         different = "j = [1,2,3]"
889         expected = same
890         actual = black.format_file_contents(different, mode=mode, fast=False)
891         self.assertEqual(expected, actual)
892         invalid = "return if you can"
893         with self.assertRaises(black.InvalidInput) as e:
894             black.format_file_contents(invalid, mode=mode, fast=False)
895         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
896
897     def test_endmarker(self) -> None:
898         n = black.lib2to3_parse("\n")
899         self.assertEqual(n.type, black.syms.file_input)
900         self.assertEqual(len(n.children), 1)
901         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
902
903     @pytest.mark.incompatible_with_mypyc
904     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
905     def test_assertFormatEqual(self) -> None:
906         out_lines = []
907         err_lines = []
908
909         def out(msg: str, **kwargs: Any) -> None:
910             out_lines.append(msg)
911
912         def err(msg: str, **kwargs: Any) -> None:
913             err_lines.append(msg)
914
915         with patch("black.output._out", out), patch("black.output._err", err):
916             with self.assertRaises(AssertionError):
917                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
918
919         out_str = "".join(out_lines)
920         self.assertIn("Expected tree:", out_str)
921         self.assertIn("Actual tree:", out_str)
922         self.assertEqual("".join(err_lines), "")
923
924     @event_loop()
925     @patch("concurrent.futures.ProcessPoolExecutor", MagicMock(side_effect=OSError))
926     def test_works_in_mono_process_only_environment(self) -> None:
927         with cache_dir() as workspace:
928             for f in [
929                 (workspace / "one.py").resolve(),
930                 (workspace / "two.py").resolve(),
931             ]:
932                 f.write_text('print("hello")\n')
933             self.invokeBlack([str(workspace)])
934
935     @event_loop()
936     def test_check_diff_use_together(self) -> None:
937         with cache_dir():
938             # Files which will be reformatted.
939             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
940             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
941             # Files which will not be reformatted.
942             src2 = (THIS_DIR / "data" / "simple_cases" / "composition.py").resolve()
943             self.invokeBlack([str(src2), "--diff", "--check"])
944             # Multi file command.
945             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
946
947     def test_no_src_fails(self) -> None:
948         with cache_dir():
949             self.invokeBlack([], exit_code=1)
950
951     def test_src_and_code_fails(self) -> None:
952         with cache_dir():
953             self.invokeBlack([".", "-c", "0"], exit_code=1)
954
955     def test_broken_symlink(self) -> None:
956         with cache_dir() as workspace:
957             symlink = workspace / "broken_link.py"
958             try:
959                 symlink.symlink_to("nonexistent.py")
960             except (OSError, NotImplementedError) as e:
961                 self.skipTest(f"Can't create symlinks: {e}")
962             self.invokeBlack([str(workspace.resolve())])
963
964     def test_single_file_force_pyi(self) -> None:
965         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
966         contents, expected = read_data("force_pyi")
967         with cache_dir() as workspace:
968             path = (workspace / "file.py").resolve()
969             with open(path, "w") as fh:
970                 fh.write(contents)
971             self.invokeBlack([str(path), "--pyi"])
972             with open(path, "r") as fh:
973                 actual = fh.read()
974             # verify cache with --pyi is separate
975             pyi_cache = black.read_cache(pyi_mode)
976             self.assertIn(str(path), pyi_cache)
977             normal_cache = black.read_cache(DEFAULT_MODE)
978             self.assertNotIn(str(path), normal_cache)
979         self.assertFormatEqual(expected, actual)
980         black.assert_equivalent(contents, actual)
981         black.assert_stable(contents, actual, pyi_mode)
982
983     @event_loop()
984     def test_multi_file_force_pyi(self) -> None:
985         reg_mode = DEFAULT_MODE
986         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
987         contents, expected = read_data("force_pyi")
988         with cache_dir() as workspace:
989             paths = [
990                 (workspace / "file1.py").resolve(),
991                 (workspace / "file2.py").resolve(),
992             ]
993             for path in paths:
994                 with open(path, "w") as fh:
995                     fh.write(contents)
996             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
997             for path in paths:
998                 with open(path, "r") as fh:
999                     actual = fh.read()
1000                 self.assertEqual(actual, expected)
1001             # verify cache with --pyi is separate
1002             pyi_cache = black.read_cache(pyi_mode)
1003             normal_cache = black.read_cache(reg_mode)
1004             for path in paths:
1005                 self.assertIn(str(path), pyi_cache)
1006                 self.assertNotIn(str(path), normal_cache)
1007
1008     def test_pipe_force_pyi(self) -> None:
1009         source, expected = read_data("force_pyi")
1010         result = CliRunner().invoke(
1011             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1012         )
1013         self.assertEqual(result.exit_code, 0)
1014         actual = result.output
1015         self.assertFormatEqual(actual, expected)
1016
1017     def test_single_file_force_py36(self) -> None:
1018         reg_mode = DEFAULT_MODE
1019         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1020         source, expected = read_data("force_py36")
1021         with cache_dir() as workspace:
1022             path = (workspace / "file.py").resolve()
1023             with open(path, "w") as fh:
1024                 fh.write(source)
1025             self.invokeBlack([str(path), *PY36_ARGS])
1026             with open(path, "r") as fh:
1027                 actual = fh.read()
1028             # verify cache with --target-version is separate
1029             py36_cache = black.read_cache(py36_mode)
1030             self.assertIn(str(path), py36_cache)
1031             normal_cache = black.read_cache(reg_mode)
1032             self.assertNotIn(str(path), normal_cache)
1033         self.assertEqual(actual, expected)
1034
1035     @event_loop()
1036     def test_multi_file_force_py36(self) -> None:
1037         reg_mode = DEFAULT_MODE
1038         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1039         source, expected = read_data("force_py36")
1040         with cache_dir() as workspace:
1041             paths = [
1042                 (workspace / "file1.py").resolve(),
1043                 (workspace / "file2.py").resolve(),
1044             ]
1045             for path in paths:
1046                 with open(path, "w") as fh:
1047                     fh.write(source)
1048             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1049             for path in paths:
1050                 with open(path, "r") as fh:
1051                     actual = fh.read()
1052                 self.assertEqual(actual, expected)
1053             # verify cache with --target-version is separate
1054             pyi_cache = black.read_cache(py36_mode)
1055             normal_cache = black.read_cache(reg_mode)
1056             for path in paths:
1057                 self.assertIn(str(path), pyi_cache)
1058                 self.assertNotIn(str(path), normal_cache)
1059
1060     def test_pipe_force_py36(self) -> None:
1061         source, expected = read_data("force_py36")
1062         result = CliRunner().invoke(
1063             black.main,
1064             ["-", "-q", "--target-version=py36"],
1065             input=BytesIO(source.encode("utf8")),
1066         )
1067         self.assertEqual(result.exit_code, 0)
1068         actual = result.output
1069         self.assertFormatEqual(actual, expected)
1070
1071     @pytest.mark.incompatible_with_mypyc
1072     def test_reformat_one_with_stdin(self) -> None:
1073         with patch(
1074             "black.format_stdin_to_stdout",
1075             return_value=lambda *args, **kwargs: black.Changed.YES,
1076         ) as fsts:
1077             report = MagicMock()
1078             path = Path("-")
1079             black.reformat_one(
1080                 path,
1081                 fast=True,
1082                 write_back=black.WriteBack.YES,
1083                 mode=DEFAULT_MODE,
1084                 report=report,
1085             )
1086             fsts.assert_called_once()
1087             report.done.assert_called_with(path, black.Changed.YES)
1088
1089     @pytest.mark.incompatible_with_mypyc
1090     def test_reformat_one_with_stdin_filename(self) -> None:
1091         with patch(
1092             "black.format_stdin_to_stdout",
1093             return_value=lambda *args, **kwargs: black.Changed.YES,
1094         ) as fsts:
1095             report = MagicMock()
1096             p = "foo.py"
1097             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1098             expected = Path(p)
1099             black.reformat_one(
1100                 path,
1101                 fast=True,
1102                 write_back=black.WriteBack.YES,
1103                 mode=DEFAULT_MODE,
1104                 report=report,
1105             )
1106             fsts.assert_called_once_with(
1107                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1108             )
1109             # __BLACK_STDIN_FILENAME__ should have been stripped
1110             report.done.assert_called_with(expected, black.Changed.YES)
1111
1112     @pytest.mark.incompatible_with_mypyc
1113     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1114         with patch(
1115             "black.format_stdin_to_stdout",
1116             return_value=lambda *args, **kwargs: black.Changed.YES,
1117         ) as fsts:
1118             report = MagicMock()
1119             p = "foo.pyi"
1120             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1121             expected = Path(p)
1122             black.reformat_one(
1123                 path,
1124                 fast=True,
1125                 write_back=black.WriteBack.YES,
1126                 mode=DEFAULT_MODE,
1127                 report=report,
1128             )
1129             fsts.assert_called_once_with(
1130                 fast=True,
1131                 write_back=black.WriteBack.YES,
1132                 mode=replace(DEFAULT_MODE, is_pyi=True),
1133             )
1134             # __BLACK_STDIN_FILENAME__ should have been stripped
1135             report.done.assert_called_with(expected, black.Changed.YES)
1136
1137     @pytest.mark.incompatible_with_mypyc
1138     def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
1139         with patch(
1140             "black.format_stdin_to_stdout",
1141             return_value=lambda *args, **kwargs: black.Changed.YES,
1142         ) as fsts:
1143             report = MagicMock()
1144             p = "foo.ipynb"
1145             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1146             expected = Path(p)
1147             black.reformat_one(
1148                 path,
1149                 fast=True,
1150                 write_back=black.WriteBack.YES,
1151                 mode=DEFAULT_MODE,
1152                 report=report,
1153             )
1154             fsts.assert_called_once_with(
1155                 fast=True,
1156                 write_back=black.WriteBack.YES,
1157                 mode=replace(DEFAULT_MODE, is_ipynb=True),
1158             )
1159             # __BLACK_STDIN_FILENAME__ should have been stripped
1160             report.done.assert_called_with(expected, black.Changed.YES)
1161
1162     @pytest.mark.incompatible_with_mypyc
1163     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1164         with patch(
1165             "black.format_stdin_to_stdout",
1166             return_value=lambda *args, **kwargs: black.Changed.YES,
1167         ) as fsts:
1168             report = MagicMock()
1169             # Even with an existing file, since we are forcing stdin, black
1170             # should output to stdout and not modify the file inplace
1171             p = THIS_DIR / "data" / "simple_cases" / "collections.py"
1172             # Make sure is_file actually returns True
1173             self.assertTrue(p.is_file())
1174             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1175             expected = Path(p)
1176             black.reformat_one(
1177                 path,
1178                 fast=True,
1179                 write_back=black.WriteBack.YES,
1180                 mode=DEFAULT_MODE,
1181                 report=report,
1182             )
1183             fsts.assert_called_once()
1184             # __BLACK_STDIN_FILENAME__ should have been stripped
1185             report.done.assert_called_with(expected, black.Changed.YES)
1186
1187     def test_reformat_one_with_stdin_empty(self) -> None:
1188         output = io.StringIO()
1189         with patch("io.TextIOWrapper", lambda *args, **kwargs: output):
1190             try:
1191                 black.format_stdin_to_stdout(
1192                     fast=True,
1193                     content="",
1194                     write_back=black.WriteBack.YES,
1195                     mode=DEFAULT_MODE,
1196                 )
1197             except io.UnsupportedOperation:
1198                 pass  # StringIO does not support detach
1199             assert output.getvalue() == ""
1200
1201     def test_invalid_cli_regex(self) -> None:
1202         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1203             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1204
1205     def test_required_version_matches_version(self) -> None:
1206         self.invokeBlack(
1207             ["--required-version", black.__version__, "-c", "0"],
1208             exit_code=0,
1209             ignore_config=True,
1210         )
1211
1212     def test_required_version_matches_partial_version(self) -> None:
1213         self.invokeBlack(
1214             ["--required-version", black.__version__.split(".")[0], "-c", "0"],
1215             exit_code=0,
1216             ignore_config=True,
1217         )
1218
1219     def test_required_version_does_not_match_on_minor_version(self) -> None:
1220         self.invokeBlack(
1221             ["--required-version", black.__version__.split(".")[0] + ".999", "-c", "0"],
1222             exit_code=1,
1223             ignore_config=True,
1224         )
1225
1226     def test_required_version_does_not_match_version(self) -> None:
1227         result = BlackRunner().invoke(
1228             black.main,
1229             ["--required-version", "20.99b", "-c", "0"],
1230         )
1231         self.assertEqual(result.exit_code, 1)
1232         self.assertIn("required version", result.stderr)
1233
1234     def test_preserves_line_endings(self) -> None:
1235         with TemporaryDirectory() as workspace:
1236             test_file = Path(workspace) / "test.py"
1237             for nl in ["\n", "\r\n"]:
1238                 contents = nl.join(["def f(  ):", "    pass"])
1239                 test_file.write_bytes(contents.encode())
1240                 ff(test_file, write_back=black.WriteBack.YES)
1241                 updated_contents: bytes = test_file.read_bytes()
1242                 self.assertIn(nl.encode(), updated_contents)
1243                 if nl == "\n":
1244                     self.assertNotIn(b"\r\n", updated_contents)
1245
1246     def test_preserves_line_endings_via_stdin(self) -> None:
1247         for nl in ["\n", "\r\n"]:
1248             contents = nl.join(["def f(  ):", "    pass"])
1249             runner = BlackRunner()
1250             result = runner.invoke(
1251                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1252             )
1253             self.assertEqual(result.exit_code, 0)
1254             output = result.stdout_bytes
1255             self.assertIn(nl.encode("utf8"), output)
1256             if nl == "\n":
1257                 self.assertNotIn(b"\r\n", output)
1258
1259     def test_assert_equivalent_different_asts(self) -> None:
1260         with self.assertRaises(AssertionError):
1261             black.assert_equivalent("{}", "None")
1262
1263     def test_shhh_click(self) -> None:
1264         try:
1265             from click import _unicodefun  # type: ignore
1266         except ImportError:
1267             self.skipTest("Incompatible Click version")
1268
1269         if not hasattr(_unicodefun, "_verify_python_env"):
1270             self.skipTest("Incompatible Click version")
1271
1272         # First, let's see if Click is crashing with a preferred ASCII charset.
1273         with patch("locale.getpreferredencoding") as gpe:
1274             gpe.return_value = "ASCII"
1275             with self.assertRaises(RuntimeError):
1276                 _unicodefun._verify_python_env()
1277         # Now, let's silence Click...
1278         black.patch_click()
1279         # ...and confirm it's silent.
1280         with patch("locale.getpreferredencoding") as gpe:
1281             gpe.return_value = "ASCII"
1282             try:
1283                 _unicodefun._verify_python_env()
1284             except RuntimeError as re:
1285                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1286
1287     def test_root_logger_not_used_directly(self) -> None:
1288         def fail(*args: Any, **kwargs: Any) -> None:
1289             self.fail("Record created with root logger")
1290
1291         with patch.multiple(
1292             logging.root,
1293             debug=fail,
1294             info=fail,
1295             warning=fail,
1296             error=fail,
1297             critical=fail,
1298             log=fail,
1299         ):
1300             ff(THIS_DIR / "util.py")
1301
1302     def test_invalid_config_return_code(self) -> None:
1303         tmp_file = Path(black.dump_to_file())
1304         try:
1305             tmp_config = Path(black.dump_to_file())
1306             tmp_config.unlink()
1307             args = ["--config", str(tmp_config), str(tmp_file)]
1308             self.invokeBlack(args, exit_code=2, ignore_config=False)
1309         finally:
1310             tmp_file.unlink()
1311
1312     def test_parse_pyproject_toml(self) -> None:
1313         test_toml_file = THIS_DIR / "test.toml"
1314         config = black.parse_pyproject_toml(str(test_toml_file))
1315         self.assertEqual(config["verbose"], 1)
1316         self.assertEqual(config["check"], "no")
1317         self.assertEqual(config["diff"], "y")
1318         self.assertEqual(config["color"], True)
1319         self.assertEqual(config["line_length"], 79)
1320         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1321         self.assertEqual(config["python_cell_magics"], ["custom1", "custom2"])
1322         self.assertEqual(config["exclude"], r"\.pyi?$")
1323         self.assertEqual(config["include"], r"\.py?$")
1324
1325     def test_read_pyproject_toml(self) -> None:
1326         test_toml_file = THIS_DIR / "test.toml"
1327         fake_ctx = FakeContext()
1328         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1329         config = fake_ctx.default_map
1330         self.assertEqual(config["verbose"], "1")
1331         self.assertEqual(config["check"], "no")
1332         self.assertEqual(config["diff"], "y")
1333         self.assertEqual(config["color"], "True")
1334         self.assertEqual(config["line_length"], "79")
1335         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1336         self.assertEqual(config["exclude"], r"\.pyi?$")
1337         self.assertEqual(config["include"], r"\.py?$")
1338
1339     @pytest.mark.incompatible_with_mypyc
1340     def test_find_project_root(self) -> None:
1341         with TemporaryDirectory() as workspace:
1342             root = Path(workspace)
1343             test_dir = root / "test"
1344             test_dir.mkdir()
1345
1346             src_dir = root / "src"
1347             src_dir.mkdir()
1348
1349             root_pyproject = root / "pyproject.toml"
1350             root_pyproject.touch()
1351             src_pyproject = src_dir / "pyproject.toml"
1352             src_pyproject.touch()
1353             src_python = src_dir / "foo.py"
1354             src_python.touch()
1355
1356             self.assertEqual(
1357                 black.find_project_root((src_dir, test_dir)),
1358                 (root.resolve(), "pyproject.toml"),
1359             )
1360             self.assertEqual(
1361                 black.find_project_root((src_dir,)),
1362                 (src_dir.resolve(), "pyproject.toml"),
1363             )
1364             self.assertEqual(
1365                 black.find_project_root((src_python,)),
1366                 (src_dir.resolve(), "pyproject.toml"),
1367             )
1368
1369     @patch(
1370         "black.files.find_user_pyproject_toml",
1371     )
1372     def test_find_pyproject_toml(self, find_user_pyproject_toml: MagicMock) -> None:
1373         find_user_pyproject_toml.side_effect = RuntimeError()
1374
1375         with redirect_stderr(io.StringIO()) as stderr:
1376             result = black.files.find_pyproject_toml(
1377                 path_search_start=(str(Path.cwd().root),)
1378             )
1379
1380         assert result is None
1381         err = stderr.getvalue()
1382         assert "Ignoring user configuration" in err
1383
1384     @patch(
1385         "black.files.find_user_pyproject_toml",
1386         black.files.find_user_pyproject_toml.__wrapped__,
1387     )
1388     def test_find_user_pyproject_toml_linux(self) -> None:
1389         if system() == "Windows":
1390             return
1391
1392         # Test if XDG_CONFIG_HOME is checked
1393         with TemporaryDirectory() as workspace:
1394             tmp_user_config = Path(workspace) / "black"
1395             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1396                 self.assertEqual(
1397                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
1398                 )
1399
1400         # Test fallback for XDG_CONFIG_HOME
1401         with patch.dict("os.environ"):
1402             os.environ.pop("XDG_CONFIG_HOME", None)
1403             fallback_user_config = Path("~/.config").expanduser() / "black"
1404             self.assertEqual(
1405                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
1406             )
1407
1408     def test_find_user_pyproject_toml_windows(self) -> None:
1409         if system() != "Windows":
1410             return
1411
1412         user_config_path = Path.home() / ".black"
1413         self.assertEqual(
1414             black.files.find_user_pyproject_toml(), user_config_path.resolve()
1415         )
1416
1417     def test_bpo_33660_workaround(self) -> None:
1418         if system() == "Windows":
1419             return
1420
1421         # https://bugs.python.org/issue33660
1422         root = Path("/")
1423         with change_directory(root):
1424             path = Path("workspace") / "project"
1425             report = black.Report(verbose=True)
1426             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
1427             self.assertEqual(normalized_path, "workspace/project")
1428
1429     def test_normalize_path_ignore_windows_junctions_outside_of_root(self) -> None:
1430         if system() != "Windows":
1431             return
1432
1433         with TemporaryDirectory() as workspace:
1434             root = Path(workspace)
1435             junction_dir = root / "junction"
1436             junction_target_outside_of_root = root / ".."
1437             os.system(f"mklink /J {junction_dir} {junction_target_outside_of_root}")
1438
1439             report = black.Report(verbose=True)
1440             normalized_path = black.normalize_path_maybe_ignore(
1441                 junction_dir, root, report
1442             )
1443             # Manually delete for Python < 3.8
1444             os.system(f"rmdir {junction_dir}")
1445
1446             self.assertEqual(normalized_path, None)
1447
1448     def test_newline_comment_interaction(self) -> None:
1449         source = "class A:\\\r\n# type: ignore\n pass\n"
1450         output = black.format_str(source, mode=DEFAULT_MODE)
1451         black.assert_stable(source, output, mode=DEFAULT_MODE)
1452
1453     def test_bpo_2142_workaround(self) -> None:
1454
1455         # https://bugs.python.org/issue2142
1456
1457         source, _ = read_data("missing_final_newline.py")
1458         # read_data adds a trailing newline
1459         source = source.rstrip()
1460         expected, _ = read_data("missing_final_newline.diff")
1461         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
1462         diff_header = re.compile(
1463             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
1464             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
1465         )
1466         try:
1467             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
1468             self.assertEqual(result.exit_code, 0)
1469         finally:
1470             os.unlink(tmp_file)
1471         actual = result.output
1472         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
1473         self.assertEqual(actual, expected)
1474
1475     @staticmethod
1476     def compare_results(
1477         result: click.testing.Result, expected_value: str, expected_exit_code: int
1478     ) -> None:
1479         """Helper method to test the value and exit code of a click Result."""
1480         assert (
1481             result.output == expected_value
1482         ), "The output did not match the expected value."
1483         assert result.exit_code == expected_exit_code, "The exit code is incorrect."
1484
1485     def test_code_option(self) -> None:
1486         """Test the code option with no changes."""
1487         code = 'print("Hello world")\n'
1488         args = ["--code", code]
1489         result = CliRunner().invoke(black.main, args)
1490
1491         self.compare_results(result, code, 0)
1492
1493     def test_code_option_changed(self) -> None:
1494         """Test the code option when changes are required."""
1495         code = "print('hello world')"
1496         formatted = black.format_str(code, mode=DEFAULT_MODE)
1497
1498         args = ["--code", code]
1499         result = CliRunner().invoke(black.main, args)
1500
1501         self.compare_results(result, formatted, 0)
1502
1503     def test_code_option_check(self) -> None:
1504         """Test the code option when check is passed."""
1505         args = ["--check", "--code", 'print("Hello world")\n']
1506         result = CliRunner().invoke(black.main, args)
1507         self.compare_results(result, "", 0)
1508
1509     def test_code_option_check_changed(self) -> None:
1510         """Test the code option when changes are required, and check is passed."""
1511         args = ["--check", "--code", "print('hello world')"]
1512         result = CliRunner().invoke(black.main, args)
1513         self.compare_results(result, "", 1)
1514
1515     def test_code_option_diff(self) -> None:
1516         """Test the code option when diff is passed."""
1517         code = "print('hello world')"
1518         formatted = black.format_str(code, mode=DEFAULT_MODE)
1519         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1520
1521         args = ["--diff", "--code", code]
1522         result = CliRunner().invoke(black.main, args)
1523
1524         # Remove time from diff
1525         output = DIFF_TIME.sub("", result.output)
1526
1527         assert output == result_diff, "The output did not match the expected value."
1528         assert result.exit_code == 0, "The exit code is incorrect."
1529
1530     def test_code_option_color_diff(self) -> None:
1531         """Test the code option when color and diff are passed."""
1532         code = "print('hello world')"
1533         formatted = black.format_str(code, mode=DEFAULT_MODE)
1534
1535         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1536         result_diff = color_diff(result_diff)
1537
1538         args = ["--diff", "--color", "--code", code]
1539         result = CliRunner().invoke(black.main, args)
1540
1541         # Remove time from diff
1542         output = DIFF_TIME.sub("", result.output)
1543
1544         assert output == result_diff, "The output did not match the expected value."
1545         assert result.exit_code == 0, "The exit code is incorrect."
1546
1547     @pytest.mark.incompatible_with_mypyc
1548     def test_code_option_safe(self) -> None:
1549         """Test that the code option throws an error when the sanity checks fail."""
1550         # Patch black.assert_equivalent to ensure the sanity checks fail
1551         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1552             code = 'print("Hello world")'
1553             error_msg = f"{code}\nerror: cannot format <string>: \n"
1554
1555             args = ["--safe", "--code", code]
1556             result = CliRunner().invoke(black.main, args)
1557
1558             self.compare_results(result, error_msg, 123)
1559
1560     def test_code_option_fast(self) -> None:
1561         """Test that the code option ignores errors when the sanity checks fail."""
1562         # Patch black.assert_equivalent to ensure the sanity checks fail
1563         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1564             code = 'print("Hello world")'
1565             formatted = black.format_str(code, mode=DEFAULT_MODE)
1566
1567             args = ["--fast", "--code", code]
1568             result = CliRunner().invoke(black.main, args)
1569
1570             self.compare_results(result, formatted, 0)
1571
1572     @pytest.mark.incompatible_with_mypyc
1573     def test_code_option_config(self) -> None:
1574         """
1575         Test that the code option finds the pyproject.toml in the current directory.
1576         """
1577         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1578             args = ["--code", "print"]
1579             # This is the only directory known to contain a pyproject.toml
1580             with change_directory(PROJECT_ROOT):
1581                 CliRunner().invoke(black.main, args)
1582                 pyproject_path = Path(Path.cwd(), "pyproject.toml").resolve()
1583
1584             assert (
1585                 len(parse.mock_calls) >= 1
1586             ), "Expected config parse to be called with the current directory."
1587
1588             _, call_args, _ = parse.mock_calls[0]
1589             assert (
1590                 call_args[0].lower() == str(pyproject_path).lower()
1591             ), "Incorrect config loaded."
1592
1593     @pytest.mark.incompatible_with_mypyc
1594     def test_code_option_parent_config(self) -> None:
1595         """
1596         Test that the code option finds the pyproject.toml in the parent directory.
1597         """
1598         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1599             with change_directory(THIS_DIR):
1600                 args = ["--code", "print"]
1601                 CliRunner().invoke(black.main, args)
1602
1603                 pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
1604                 assert (
1605                     len(parse.mock_calls) >= 1
1606                 ), "Expected config parse to be called with the current directory."
1607
1608                 _, call_args, _ = parse.mock_calls[0]
1609                 assert (
1610                     call_args[0].lower() == str(pyproject_path).lower()
1611                 ), "Incorrect config loaded."
1612
1613     def test_for_handled_unexpected_eof_error(self) -> None:
1614         """
1615         Test that an unexpected EOF SyntaxError is nicely presented.
1616         """
1617         with pytest.raises(black.parsing.InvalidInput) as exc_info:
1618             black.lib2to3_parse("print(", {})
1619
1620         exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")
1621
1622     def test_equivalency_ast_parse_failure_includes_error(self) -> None:
1623         with pytest.raises(AssertionError) as err:
1624             black.assert_equivalent("a«»a  = 1", "a«»a  = 1")
1625
1626         err.match("--safe")
1627         # Unfortunately the SyntaxError message has changed in newer versions so we
1628         # can't match it directly.
1629         err.match("invalid character")
1630         err.match(r"\(<unknown>, line 1\)")
1631
1632
class TestCaching:
    """Tests for reading, writing and honouring Black's formatting cache."""

    def test_get_cache_dir(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        # Two distinct directories so the origin of the result is unambiguous.
        user_dir = tmp_path / "ws1"
        user_dir.mkdir()
        env_dir = tmp_path / "ws2"
        env_dir.mkdir()

        # BLACK_CACHE_DIR unset: the (patched) user_cache_dir result is used.
        monkeypatch.delenv("BLACK_CACHE_DIR", raising=False)
        with patch(
            target="black.cache.user_cache_dir",
            autospec=True,
            return_value=str(user_dir),
        ):
            assert get_cache_dir() == user_dir

        # BLACK_CACHE_DIR set: the path from the environment variable is used.
        monkeypatch.setenv("BLACK_CACHE_DIR", str(env_dir))
        assert get_cache_dir() == env_dir

    def test_cache_broken_file(self) -> None:
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            # A cache file that isn't a valid pickle reads back as empty...
            get_cache_file(mode).write_text("this is not a pickle")
            assert black.read_cache(mode) == {}
            # ...and a later run records the formatted file again.
            target = (workspace / "test.py").resolve()
            target.write_text("print('hello')")
            invokeBlack([str(target)])
            assert str(target) in black.read_cache(mode)

    def test_cache_single_file_already_cached(self) -> None:
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            target = (workspace / "test.py").resolve()
            target.write_text("print('hello')")
            black.write_cache({}, [target], mode)
            invokeBlack([str(target)])
            # The cached file was skipped, so its single quotes survive.
            assert target.read_text() == "print('hello')"

    @event_loop()
    def test_cache_multiple_files(self) -> None:
        mode = DEFAULT_MODE
        with cache_dir() as workspace, patch(
            "concurrent.futures.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            cached_file = (workspace / "one.py").resolve()
            cached_file.write_text("print('hello')")
            fresh_file = (workspace / "two.py").resolve()
            fresh_file.write_text("print('hello')")
            black.write_cache({}, [cached_file], mode)

            invokeBlack([str(workspace)])

            # Only the file missing from the cache was reformatted...
            assert cached_file.read_text() == "print('hello')"
            assert fresh_file.read_text() == 'print("hello")\n'
            # ...and both are cached afterwards.
            cache = black.read_cache(mode)
            assert str(cached_file) in cache
            assert str(fresh_file) in cache

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    def test_no_cache_when_writeback_diff(self, color: bool) -> None:
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            target = (workspace / "test.py").resolve()
            target.write_text("print('hello')")
            with patch("black.read_cache") as read_cache, patch(
                "black.write_cache"
            ) as write_cache:
                extra_flags = ["--color"] if color else []
                invokeBlack([str(target), "--diff", *extra_flags])
                # --diff neither reads nor writes the cache.
                assert not get_cache_file(mode).exists()
                write_cache.assert_not_called()
                read_cache.assert_not_called()

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    @event_loop()
    def test_output_locking_when_writeback_diff(self, color: bool) -> None:
        with cache_dir() as workspace:
            for index in range(4):
                (workspace / f"test{index}.py").resolve().write_text("print('hello')")
            with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
                cmd = ["--diff", str(workspace)] + (["--color"] if color else [])
                invokeBlack(cmd, exit_code=0)
                # This doesn't prove the lock is used, but if Manager is never
                # called, the lock it provides cannot be in use either.
                mgr.assert_called()

    def test_no_cache_when_stdin(self) -> None:
        mode = DEFAULT_MODE
        with cache_dir():
            outcome = CliRunner().invoke(
                black.main, ["-"], input=BytesIO(b"print('hello')")
            )
            assert not outcome.exit_code
            assert not get_cache_file(mode).exists()

    def test_read_cache_no_cachefile(self) -> None:
        with cache_dir():
            assert black.read_cache(DEFAULT_MODE) == {}

    def test_write_cache_read_cache(self) -> None:
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            target = (workspace / "test.py").resolve()
            target.touch()
            black.write_cache({}, [target], mode)
            cache = black.read_cache(mode)
            key = str(target)
            assert key in cache
            assert cache[key] == black.get_cache_info(target)

    def test_filter_cached(self) -> None:
        with TemporaryDirectory() as workspace:
            base = Path(workspace)
            uncached, cached, changed = (
                (base / name).resolve() for name in ("uncached", "cached", "changed")
            )
            for entry in (uncached, cached, changed):
                entry.touch()
            cache = {
                str(cached): black.get_cache_info(cached),
                # Stale entry whose recorded info doesn't match the file on disk.
                str(changed): (0.0, 0),
            }
            todo, done = black.filter_cached(cache, {uncached, cached, changed})
            assert todo == {uncached, changed}
            assert done == {cached}

    def test_write_cache_creates_directory_if_needed(self) -> None:
        with cache_dir(exists=False) as workspace:
            assert not workspace.exists()
            black.write_cache({}, [], DEFAULT_MODE)
            assert workspace.exists()

    @event_loop()
    def test_failed_formatting_does_not_get_cached(self) -> None:
        mode = DEFAULT_MODE
        with cache_dir() as workspace, patch(
            "concurrent.futures.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            failing = (workspace / "failing.py").resolve()
            failing.write_text("not actually python")
            clean = (workspace / "clean.py").resolve()
            clean.write_text('print("hello")\n')
            invokeBlack([str(workspace)], exit_code=123)
            cache = black.read_cache(mode)
            assert str(failing) not in cache
            assert str(clean) in cache

    def test_write_cache_write_fail(self) -> None:
        # An OSError from Path.open must not propagate out of write_cache.
        with cache_dir(), patch.object(Path, "open", side_effect=OSError):
            black.write_cache({}, [], DEFAULT_MODE)

    def test_read_cache_line_lengths(self) -> None:
        mode = DEFAULT_MODE
        short_mode = replace(DEFAULT_MODE, line_length=1)
        with cache_dir() as workspace:
            target = (workspace / "file.py").resolve()
            target.touch()
            black.write_cache({}, [target], mode)
            # The cache is per mode: an entry written for one line length is
            # absent when reading with another.
            assert str(target) in black.read_cache(mode)
            assert str(target) not in black.read_cache(short_mode)
1826
def assert_collected_sources(
    src: Sequence[Union[str, Path]],
    expected: Sequence[Union[str, Path]],
    *,
    ctx: Optional[FakeContext] = None,
    exclude: Optional[str] = None,
    include: Optional[str] = None,
    extend_exclude: Optional[str] = None,
    force_exclude: Optional[str] = None,
    stdin_filename: Optional[str] = None,
) -> None:
    """Assert that ``black.get_sources`` collects exactly ``expected`` from ``src``.

    The pattern arguments are raw regex strings. ``None`` means "not given" and
    is passed through as ``None``, except for ``include``, which falls back to
    ``DEFAULT_INCLUDE``. The collected paths are compared order-insensitively.
    """

    def compile_optional(pattern: Optional[str]) -> Any:
        return None if pattern is None else compile_pattern(pattern)

    sources = tuple(str(Path(item)) for item in src)
    wanted = [Path(item) for item in expected]
    exclude_re = compile_optional(exclude)
    include_re = DEFAULT_INCLUDE if include is None else compile_pattern(include)
    extend_exclude_re = compile_optional(extend_exclude)
    force_exclude_re = compile_optional(force_exclude)

    collected = black.get_sources(
        ctx=ctx or FakeContext(),
        src=sources,
        quiet=False,
        verbose=False,
        include=include_re,
        exclude=exclude_re,
        extend_exclude=extend_exclude_re,
        force_exclude=force_exclude_re,
        report=black.Report(),
        stdin_filename=stdin_filename,
    )
    assert sorted(collected) == sorted(wanted)
1859
1860
class TestFileCollection:
    """Tests for source discovery: include/exclude patterns, .gitignore, stdin."""

    def test_include_exclude(self) -> None:
        path = THIS_DIR / "data" / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        assert_collected_sources(
            src,
            expected,
            include=r"\.pyi?$",
            exclude=r"/exclude/|/\.definitely_exclude/",
        )

    def test_gitignore_used_as_default(self) -> None:
        base = DATA_DIR / "include_exclude_tests"
        expected = [
            base / "b/.definitely_exclude/a.py",
            base / "b/.definitely_exclude/a.pyi",
        ]
        src = [base / "b/"]
        ctx = FakeContext()
        ctx.obj["root"] = base
        assert_collected_sources(src, expected, ctx=ctx, extend_exclude=r"/exclude/")

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_exclude_for_issue_1572(self) -> None:
        # Exclude shouldn't touch files that were explicitly given to Black through the
        # CLI. Exclude is supposed to only apply to the recursive discovery of files.
        # https://github.com/psf/black/issues/1572
        path = DATA_DIR / "include_exclude_tests"
        src = [path / "b/exclude/a.py"]
        expected = [path / "b/exclude/a.py"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    def test_gitignore_exclude(self) -> None:
        path = THIS_DIR / "data" / "include_exclude_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        report = black.Report()
        gitignore = PathSpec.from_lines(
            "gitwildmatch", ["exclude/", ".definitely_exclude"]
        )
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        this_abs = THIS_DIR.resolve()
        sources: List[Path] = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                gitignore,
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_nested_gitignore(self) -> None:
        path = THIS_DIR / "data" / "nested_gitignore_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        root_gitignore = black.files.get_gitignore(path)
        report = black.Report()
        expected: List[Path] = [
            path / "x.py",
            path / "root/b.py",
            path / "root/c.py",
            path / "root/child/c.py",
        ]
        this_abs = THIS_DIR.resolve()
        sources = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                root_gitignore,
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_invalid_gitignore(self) -> None:
        path = THIS_DIR / "data" / "invalid_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_invalid_nested_gitignore(self) -> None:
        path = THIS_DIR / "data" / "invalid_nested_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / "a" / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_empty_include(self) -> None:
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.pie",
            path / "b/exclude/a.py",
            path / "b/exclude/a.pyi",
            path / "b/dont_exclude/a.pie",
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
            path / "b/.definitely_exclude/a.pie",
            path / "b/.definitely_exclude/a.py",
            path / "b/.definitely_exclude/a.pyi",
            path / ".gitignore",
            path / "pyproject.toml",
        ]
        # Setting exclude explicitly to an empty string to block .gitignore usage.
        assert_collected_sources(src, expected, include="", exclude="")

    def test_extend_exclude(self) -> None:
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.py",
            path / "b/dont_exclude/a.py",
        ]
        assert_collected_sources(
            src, expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude"
        )

    @pytest.mark.incompatible_with_mypyc
    def test_symlink_out_of_root_directory(self) -> None:
        path = MagicMock()
        root = THIS_DIR.resolve()
        child = MagicMock()
        include = re.compile(black.DEFAULT_INCLUDES)
        exclude = re.compile(black.DEFAULT_EXCLUDES)
        report = black.Report()
        gitignore = PathSpec.from_lines("gitwildmatch", [])
        # `child` should behave like a symlink which resolved path is clearly
        # outside of the `root` directory.
        path.iterdir.return_value = [child]
        child.resolve.return_value = Path("/a/b/c")
        child.as_posix.return_value = "/a/b/c"
        try:
            list(
                black.gen_python_files(
                    path.iterdir(),
                    root,
                    include,
                    exclude,
                    None,
                    None,
                    report,
                    gitignore,
                    verbose=False,
                    quiet=False,
                )
            )
        except ValueError as ve:
            # Name the function actually under test in the failure message.
            pytest.fail(f"`gen_python_files()` failed: {ve}")
        path.iterdir.assert_called_once()
        child.resolve.assert_called_once()

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin(self) -> None:
        src = ["-"]
        expected = ["-"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename(self) -> None:
        src = ["-"]
        stdin_filename = str(THIS_DIR / "data/collections.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
        # Exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        path = DATA_DIR / "include_exclude_tests"
        src = ["-"]
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
        # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        src = ["-"]
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            extend_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
        # Force exclude should exclude the file when passing it through
        # stdin_filename
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        assert_collected_sources(
            src=["-"],
            expected=[],
            force_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )
2106
2107
# Black's own source lines, used by tracefunc below to print function headers.
try:
    with open(black.__file__, encoding="utf-8") as _bf:
        black_source_lines = list(_bf)
except UnicodeDecodeError:
    # Tolerated only when black is compiled. There, black.__file__ presumably
    # points at a binary module rather than UTF-8 source. In that case
    # black_source_lines stays undefined and tracefunc cannot be used.
    if not black.COMPILED:
        raise
2114
2115
def tracefunc(
    frame: types.FrameType, event: str, arg: Any
) -> Callable[[types.FrameType, str, Any], Any]:
    """Show function calls from `black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    Only "call" events in frames whose code lives in ``black/__init__.py`` are
    printed, as ``<lineno>:<source line>``. The source line is looked up in the
    module-level ``black_source_lines``, skipping decorator lines. Output is
    indented two spaces per stack frame beyond a fixed offset of 19, which
    presumably approximates the depth of the test harness. The function always
    returns itself so tracing continues into nested scopes.
    """
    if event != "call":
        return tracefunc

    filename = frame.f_code.co_filename
    # Check the file first. black_source_lines only describes black's own
    # source, so indexing it with a line number from another file could raise
    # IndexError or print an unrelated line. This also avoids the costly
    # inspect.stack() call for frames that will never be printed.
    if "black/__init__.py" not in filename:
        return tracefunc

    stack = len(inspect.stack()) - 19
    stack *= 2
    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1
    funcname = black_source_lines[func_sig_lineno].strip()
    # Decorated functions report the decorator's line; advance to the `def`.
    while funcname.startswith("@"):
        func_sig_lineno += 1
        funcname = black_source_lines[func_sig_lineno].strip()
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc