git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

02a707e8996aa25beab3137a91281623126f0af5
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import inspect
5 import io
6 import logging
7 import multiprocessing
8 import os
9 import sys
10 import types
11 import unittest
12 from concurrent.futures import ThreadPoolExecutor
13 from contextlib import contextmanager, redirect_stderr
14 from dataclasses import replace
15 from io import BytesIO
16 from pathlib import Path
17 from platform import system
18 from tempfile import TemporaryDirectory
19 from typing import (
20     Any,
21     Callable,
22     Dict,
23     Iterator,
24     List,
25     Optional,
26     Sequence,
27     TypeVar,
28     Union,
29 )
30 from unittest.mock import MagicMock, patch
31
32 import click
33 import pytest
34 import re
35 from click import unstyle
36 from click.testing import CliRunner
37 from pathspec import PathSpec
38
39 import black
40 import black.files
41 from black import Feature, TargetVersion
42 from black import re_compile_maybe_verbose as compile_pattern
43 from black.cache import get_cache_dir, get_cache_file
44 from black.debug import DebugVisitor
45 from black.output import color_diff, diff
46 from black.report import Report
47
48 # Import other test classes
49 from tests.util import (
50     DATA_DIR,
51     DEFAULT_MODE,
52     DETERMINISTIC_HEADER,
53     PROJECT_ROOT,
54     PY36_VERSIONS,
55     THIS_DIR,
56     BlackBaseTestCase,
57     assert_format,
58     change_directory,
59     dump_to_stderr,
60     ff,
61     fs,
62     read_data,
63     get_case_path,
64     read_data_from_file,
65 )
66
THIS_FILE = Path(__file__)
# Passed as ``--config=...`` to the CLI tests below; judging by its name, an
# empty pyproject file. NOTE(review): confirm the file is indeed empty.
EMPTY_CONFIG = THIS_DIR / "data" / "empty_pyproject.toml"
# One ``--target-version=<name>`` flag per entry of PY36_VERSIONS.
PY36_ARGS = [f"--target-version={version.name.lower()}" for version in PY36_VERSIONS]
# Black's default exclude/include patterns, compiled with
# ``re_compile_maybe_verbose``.
DEFAULT_EXCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_EXCLUDES)
DEFAULT_INCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_INCLUDES)
# Generic type variables; presumably used by helpers further down this module.
T = TypeVar("T")
R = TypeVar("R")

# Match the time output in a diff, but nothing else
DIFF_TIME = re.compile(r"\t[\d\-:+\. ]+")
77
78
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point ``black.cache.CACHE_DIR`` at a throwaway directory for the block.

    Args:
        exists: when true the patched path is the temporary directory itself;
            when false it is a not-yet-created ``new`` subdirectory of it.

    Yields:
        The path that ``black.cache.CACHE_DIR`` is patched to. The patch is
        undone and the temporary directory removed when the block exits.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
87
88
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a brand-new asyncio event loop as the current loop for the block.

    The loop is created from the current event loop policy and set as the
    current loop. On exit the current loop is reset to ``None`` before the
    loop is closed, so no closed loop stays installed for later code to pick
    up. ``asyncio.run`` also unsets the current loop when it finishes.

    Yields:
        None; code inside the ``with`` block reaches the loop through asyncio.
    """
    policy = asyncio.get_event_loop_policy()
    loop = policy.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        yield
    finally:
        # Leaving a closed loop as "current" would make later
        # get_event_loop() calls hand back an unusable loop.
        asyncio.set_event_loop(None)
        loop.close()
99
100
class FakeContext(click.Context):
    """Minimal stand-in for ``click.Context`` when a helper requires one.

    ``click.Context.__init__`` is deliberately not called; only ``default_map``
    and ``obj`` are populated.
    """

    def __init__(self) -> None:
        # No option defaults are provided.
        self.default_map: Dict[str, Any] = dict()
        # Placeholder project root; most tests do not look at it.
        self.obj: Dict[str, Any] = dict(root=PROJECT_ROOT)
108
109
class FakeParameter(click.Parameter):
    """Minimal stand-in for ``click.Parameter`` when a helper requires one."""

    def __init__(self) -> None:
        """Intentionally skip ``click.Parameter.__init__``; set no attributes."""
115
116
class BlackRunner(CliRunner):
    """``CliRunner`` that captures Black's STDOUT and STDERR separately.

    ``mix_stderr`` is switched off so tests can inspect ``stdout_bytes`` and
    ``stderr_bytes`` independently.
    """

    def __init__(self) -> None:
        # NOTE(review): newer click releases may no longer accept
        # ``mix_stderr``; confirm against the pinned click version.
        runner_options = {"mix_stderr": False}
        super().__init__(**runner_options)
122
123
def invokeBlack(
    args: List[str], exit_code: int = 0, ignore_config: bool = True
) -> None:
    """Run Black's CLI in-process and assert it exits with ``exit_code``.

    Args:
        args: command-line arguments passed to ``black.main``.
        exit_code: the exit status the run is expected to produce.
        ignore_config: when true, ``--verbose`` and ``--config`` pointing at
            ``THIS_DIR / "empty.toml"`` are prepended to ``args``.

    Raises:
        AssertionError: if the exit status differs. The message lists the
            arguments, captured stdout/stderr and any recorded exception.
    """
    if ignore_config:
        empty_config = str(THIS_DIR / "empty.toml")
        args = ["--verbose", "--config", empty_config] + list(args)
    outcome = BlackRunner().invoke(black.main, args, catch_exceptions=False)
    out_bytes = outcome.stdout_bytes
    err_bytes = outcome.stderr_bytes
    assert out_bytes is not None
    assert err_bytes is not None
    details = "\n".join(
        [
            f"Failed with args: {args}",
            f"stdout: {out_bytes.decode()!r}",
            f"stderr: {err_bytes.decode()!r}",
            f"exception: {outcome.exception}",
        ]
    )
    assert outcome.exit_code == exit_code, details
140
141
142 class BlackTestCase(BlackBaseTestCase):
    # Expose the module-level ``invokeBlack`` helper as ``self.invokeBlack``;
    # ``staticmethod`` stops it from being bound to the test instance.
    invokeBlack = staticmethod(invokeBlack)
144
145     def test_empty_ff(self) -> None:
146         expected = ""
147         tmp_file = Path(black.dump_to_file())
148         try:
149             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
150             with open(tmp_file, encoding="utf8") as f:
151                 actual = f.read()
152         finally:
153             os.unlink(tmp_file)
154         self.assertFormatEqual(expected, actual)
155
156     def test_experimental_string_processing_warns(self) -> None:
157         self.assertWarns(
158             black.mode.Deprecated, black.Mode, experimental_string_processing=True
159         )
160
161     def test_piping(self) -> None:
162         source, expected = read_data_from_file(PROJECT_ROOT / "src/black/__init__.py")
163         result = BlackRunner().invoke(
164             black.main,
165             [
166                 "-",
167                 "--fast",
168                 f"--line-length={black.DEFAULT_LINE_LENGTH}",
169                 f"--config={EMPTY_CONFIG}",
170             ],
171             input=BytesIO(source.encode("utf8")),
172         )
173         self.assertEqual(result.exit_code, 0)
174         self.assertFormatEqual(expected, result.output)
175         if source != result.output:
176             black.assert_equivalent(source, result.output)
177             black.assert_stable(source, result.output, DEFAULT_MODE)
178
179     def test_piping_diff(self) -> None:
180         diff_header = re.compile(
181             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
182             r"\+\d\d\d\d"
183         )
184         source, _ = read_data("simple_cases", "expression.py")
185         expected, _ = read_data("simple_cases", "expression.diff")
186         args = [
187             "-",
188             "--fast",
189             f"--line-length={black.DEFAULT_LINE_LENGTH}",
190             "--diff",
191             f"--config={EMPTY_CONFIG}",
192         ]
193         result = BlackRunner().invoke(
194             black.main, args, input=BytesIO(source.encode("utf8"))
195         )
196         self.assertEqual(result.exit_code, 0)
197         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
198         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
199         self.assertEqual(expected, actual)
200
201     def test_piping_diff_with_color(self) -> None:
202         source, _ = read_data("simple_cases", "expression.py")
203         args = [
204             "-",
205             "--fast",
206             f"--line-length={black.DEFAULT_LINE_LENGTH}",
207             "--diff",
208             "--color",
209             f"--config={EMPTY_CONFIG}",
210         ]
211         result = BlackRunner().invoke(
212             black.main, args, input=BytesIO(source.encode("utf8"))
213         )
214         actual = result.output
215         # Again, the contents are checked in a different test, so only look for colors.
216         self.assertIn("\033[1m", actual)
217         self.assertIn("\033[36m", actual)
218         self.assertIn("\033[32m", actual)
219         self.assertIn("\033[31m", actual)
220         self.assertIn("\033[0m", actual)
221
222     @patch("black.dump_to_file", dump_to_stderr)
223     def _test_wip(self) -> None:
224         source, expected = read_data("miscellaneous", "wip")
225         sys.settrace(tracefunc)
226         mode = replace(
227             DEFAULT_MODE,
228             experimental_string_processing=False,
229             target_versions={black.TargetVersion.PY38},
230         )
231         actual = fs(source, mode=mode)
232         sys.settrace(None)
233         self.assertFormatEqual(expected, actual)
234         black.assert_equivalent(source, actual)
235         black.assert_stable(source, actual, black.FileMode())
236
237     def test_pep_572_version_detection(self) -> None:
238         source, _ = read_data("py_38", "pep_572")
239         root = black.lib2to3_parse(source)
240         features = black.get_features_used(root)
241         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
242         versions = black.detect_target_versions(root)
243         self.assertIn(black.TargetVersion.PY38, versions)
244
245     def test_expression_ff(self) -> None:
246         source, expected = read_data("simple_cases", "expression.py")
247         tmp_file = Path(black.dump_to_file(source))
248         try:
249             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
250             with open(tmp_file, encoding="utf8") as f:
251                 actual = f.read()
252         finally:
253             os.unlink(tmp_file)
254         self.assertFormatEqual(expected, actual)
255         with patch("black.dump_to_file", dump_to_stderr):
256             black.assert_equivalent(source, actual)
257             black.assert_stable(source, actual, DEFAULT_MODE)
258
259     def test_expression_diff(self) -> None:
260         source, _ = read_data("simple_cases", "expression.py")
261         expected, _ = read_data("simple_cases", "expression.diff")
262         tmp_file = Path(black.dump_to_file(source))
263         diff_header = re.compile(
264             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
265             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
266         )
267         try:
268             result = BlackRunner().invoke(
269                 black.main, ["--diff", str(tmp_file), f"--config={EMPTY_CONFIG}"]
270             )
271             self.assertEqual(result.exit_code, 0)
272         finally:
273             os.unlink(tmp_file)
274         actual = result.output
275         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
276         if expected != actual:
277             dump = black.dump_to_file(actual)
278             msg = (
279                 "Expected diff isn't equal to the actual. If you made changes to"
280                 " expression.py and this is an anticipated difference, overwrite"
281                 f" tests/data/expression.diff with {dump}"
282             )
283             self.assertEqual(expected, actual, msg)
284
285     def test_expression_diff_with_color(self) -> None:
286         source, _ = read_data("simple_cases", "expression.py")
287         expected, _ = read_data("simple_cases", "expression.diff")
288         tmp_file = Path(black.dump_to_file(source))
289         try:
290             result = BlackRunner().invoke(
291                 black.main,
292                 ["--diff", "--color", str(tmp_file), f"--config={EMPTY_CONFIG}"],
293             )
294         finally:
295             os.unlink(tmp_file)
296         actual = result.output
297         # We check the contents of the diff in `test_expression_diff`. All
298         # we need to check here is that color codes exist in the result.
299         self.assertIn("\033[1m", actual)
300         self.assertIn("\033[36m", actual)
301         self.assertIn("\033[32m", actual)
302         self.assertIn("\033[31m", actual)
303         self.assertIn("\033[0m", actual)
304
305     def test_detect_pos_only_arguments(self) -> None:
306         source, _ = read_data("py_38", "pep_570")
307         root = black.lib2to3_parse(source)
308         features = black.get_features_used(root)
309         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
310         versions = black.detect_target_versions(root)
311         self.assertIn(black.TargetVersion.PY38, versions)
312
313     @patch("black.dump_to_file", dump_to_stderr)
314     def test_string_quotes(self) -> None:
315         source, expected = read_data("miscellaneous", "string_quotes")
316         mode = black.Mode(preview=True)
317         assert_format(source, expected, mode)
318         mode = replace(mode, string_normalization=False)
319         not_normalized = fs(source, mode=mode)
320         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
321         black.assert_equivalent(source, not_normalized)
322         black.assert_stable(source, not_normalized, mode=mode)
323
324     def test_skip_magic_trailing_comma(self) -> None:
325         source, _ = read_data("simple_cases", "expression")
326         expected, _ = read_data(
327             "miscellaneous", "expression_skip_magic_trailing_comma.diff"
328         )
329         tmp_file = Path(black.dump_to_file(source))
330         diff_header = re.compile(
331             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
332             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
333         )
334         try:
335             result = BlackRunner().invoke(
336                 black.main, ["-C", "--diff", str(tmp_file), f"--config={EMPTY_CONFIG}"]
337             )
338             self.assertEqual(result.exit_code, 0)
339         finally:
340             os.unlink(tmp_file)
341         actual = result.output
342         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
343         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
344         if expected != actual:
345             dump = black.dump_to_file(actual)
346             msg = (
347                 "Expected diff isn't equal to the actual. If you made changes to"
348                 " expression.py and this is an anticipated difference, overwrite"
349                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
350             )
351             self.assertEqual(expected, actual, msg)
352
353     @patch("black.dump_to_file", dump_to_stderr)
354     def test_async_as_identifier(self) -> None:
355         source_path = get_case_path("miscellaneous", "async_as_identifier")
356         source, expected = read_data_from_file(source_path)
357         actual = fs(source)
358         self.assertFormatEqual(expected, actual)
359         major, minor = sys.version_info[:2]
360         if major < 3 or (major <= 3 and minor < 7):
361             black.assert_equivalent(source, actual)
362         black.assert_stable(source, actual, DEFAULT_MODE)
363         # ensure black can parse this when the target is 3.6
364         self.invokeBlack([str(source_path), "--target-version", "py36"])
365         # but not on 3.7, because async/await is no longer an identifier
366         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
367
368     @patch("black.dump_to_file", dump_to_stderr)
369     def test_python37(self) -> None:
370         source_path = get_case_path("py_37", "python37")
371         source, expected = read_data_from_file(source_path)
372         actual = fs(source)
373         self.assertFormatEqual(expected, actual)
374         major, minor = sys.version_info[:2]
375         if major > 3 or (major == 3 and minor >= 7):
376             black.assert_equivalent(source, actual)
377         black.assert_stable(source, actual, DEFAULT_MODE)
378         # ensure black can parse this when the target is 3.7
379         self.invokeBlack([str(source_path), "--target-version", "py37"])
380         # but not on 3.6, because we use async as a reserved keyword
381         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
382
383     def test_tab_comment_indentation(self) -> None:
384         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
385         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
386         self.assertFormatEqual(contents_spc, fs(contents_spc))
387         self.assertFormatEqual(contents_spc, fs(contents_tab))
388
389         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
390         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
391         self.assertFormatEqual(contents_spc, fs(contents_spc))
392         self.assertFormatEqual(contents_spc, fs(contents_tab))
393
394         # mixed tabs and spaces (valid Python 2 code)
395         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
396         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
397         self.assertFormatEqual(contents_spc, fs(contents_spc))
398         self.assertFormatEqual(contents_spc, fs(contents_tab))
399
400         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
401         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
402         self.assertFormatEqual(contents_spc, fs(contents_spc))
403         self.assertFormatEqual(contents_spc, fs(contents_tab))
404
405     def test_report_verbose(self) -> None:
406         report = Report(verbose=True)
407         out_lines = []
408         err_lines = []
409
410         def out(msg: str, **kwargs: Any) -> None:
411             out_lines.append(msg)
412
413         def err(msg: str, **kwargs: Any) -> None:
414             err_lines.append(msg)
415
416         with patch("black.output._out", out), patch("black.output._err", err):
417             report.done(Path("f1"), black.Changed.NO)
418             self.assertEqual(len(out_lines), 1)
419             self.assertEqual(len(err_lines), 0)
420             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
421             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
422             self.assertEqual(report.return_code, 0)
423             report.done(Path("f2"), black.Changed.YES)
424             self.assertEqual(len(out_lines), 2)
425             self.assertEqual(len(err_lines), 0)
426             self.assertEqual(out_lines[-1], "reformatted f2")
427             self.assertEqual(
428                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
429             )
430             report.done(Path("f3"), black.Changed.CACHED)
431             self.assertEqual(len(out_lines), 3)
432             self.assertEqual(len(err_lines), 0)
433             self.assertEqual(
434                 out_lines[-1], "f3 wasn't modified on disk since last run."
435             )
436             self.assertEqual(
437                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
438             )
439             self.assertEqual(report.return_code, 0)
440             report.check = True
441             self.assertEqual(report.return_code, 1)
442             report.check = False
443             report.failed(Path("e1"), "boom")
444             self.assertEqual(len(out_lines), 3)
445             self.assertEqual(len(err_lines), 1)
446             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
447             self.assertEqual(
448                 unstyle(str(report)),
449                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
450                 " reformat.",
451             )
452             self.assertEqual(report.return_code, 123)
453             report.done(Path("f3"), black.Changed.YES)
454             self.assertEqual(len(out_lines), 4)
455             self.assertEqual(len(err_lines), 1)
456             self.assertEqual(out_lines[-1], "reformatted f3")
457             self.assertEqual(
458                 unstyle(str(report)),
459                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
460                 " reformat.",
461             )
462             self.assertEqual(report.return_code, 123)
463             report.failed(Path("e2"), "boom")
464             self.assertEqual(len(out_lines), 4)
465             self.assertEqual(len(err_lines), 2)
466             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
467             self.assertEqual(
468                 unstyle(str(report)),
469                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
470                 " reformat.",
471             )
472             self.assertEqual(report.return_code, 123)
473             report.path_ignored(Path("wat"), "no match")
474             self.assertEqual(len(out_lines), 5)
475             self.assertEqual(len(err_lines), 2)
476             self.assertEqual(out_lines[-1], "wat ignored: no match")
477             self.assertEqual(
478                 unstyle(str(report)),
479                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
480                 " reformat.",
481             )
482             self.assertEqual(report.return_code, 123)
483             report.done(Path("f4"), black.Changed.NO)
484             self.assertEqual(len(out_lines), 6)
485             self.assertEqual(len(err_lines), 2)
486             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
487             self.assertEqual(
488                 unstyle(str(report)),
489                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
490                 " reformat.",
491             )
492             self.assertEqual(report.return_code, 123)
493             report.check = True
494             self.assertEqual(
495                 unstyle(str(report)),
496                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
497                 " would fail to reformat.",
498             )
499             report.check = False
500             report.diff = True
501             self.assertEqual(
502                 unstyle(str(report)),
503                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
504                 " would fail to reformat.",
505             )
506
507     def test_report_quiet(self) -> None:
508         report = Report(quiet=True)
509         out_lines = []
510         err_lines = []
511
512         def out(msg: str, **kwargs: Any) -> None:
513             out_lines.append(msg)
514
515         def err(msg: str, **kwargs: Any) -> None:
516             err_lines.append(msg)
517
518         with patch("black.output._out", out), patch("black.output._err", err):
519             report.done(Path("f1"), black.Changed.NO)
520             self.assertEqual(len(out_lines), 0)
521             self.assertEqual(len(err_lines), 0)
522             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
523             self.assertEqual(report.return_code, 0)
524             report.done(Path("f2"), black.Changed.YES)
525             self.assertEqual(len(out_lines), 0)
526             self.assertEqual(len(err_lines), 0)
527             self.assertEqual(
528                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
529             )
530             report.done(Path("f3"), black.Changed.CACHED)
531             self.assertEqual(len(out_lines), 0)
532             self.assertEqual(len(err_lines), 0)
533             self.assertEqual(
534                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
535             )
536             self.assertEqual(report.return_code, 0)
537             report.check = True
538             self.assertEqual(report.return_code, 1)
539             report.check = False
540             report.failed(Path("e1"), "boom")
541             self.assertEqual(len(out_lines), 0)
542             self.assertEqual(len(err_lines), 1)
543             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
544             self.assertEqual(
545                 unstyle(str(report)),
546                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
547                 " reformat.",
548             )
549             self.assertEqual(report.return_code, 123)
550             report.done(Path("f3"), black.Changed.YES)
551             self.assertEqual(len(out_lines), 0)
552             self.assertEqual(len(err_lines), 1)
553             self.assertEqual(
554                 unstyle(str(report)),
555                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
556                 " reformat.",
557             )
558             self.assertEqual(report.return_code, 123)
559             report.failed(Path("e2"), "boom")
560             self.assertEqual(len(out_lines), 0)
561             self.assertEqual(len(err_lines), 2)
562             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
563             self.assertEqual(
564                 unstyle(str(report)),
565                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
566                 " reformat.",
567             )
568             self.assertEqual(report.return_code, 123)
569             report.path_ignored(Path("wat"), "no match")
570             self.assertEqual(len(out_lines), 0)
571             self.assertEqual(len(err_lines), 2)
572             self.assertEqual(
573                 unstyle(str(report)),
574                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
575                 " reformat.",
576             )
577             self.assertEqual(report.return_code, 123)
578             report.done(Path("f4"), black.Changed.NO)
579             self.assertEqual(len(out_lines), 0)
580             self.assertEqual(len(err_lines), 2)
581             self.assertEqual(
582                 unstyle(str(report)),
583                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
584                 " reformat.",
585             )
586             self.assertEqual(report.return_code, 123)
587             report.check = True
588             self.assertEqual(
589                 unstyle(str(report)),
590                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
591                 " would fail to reformat.",
592             )
593             report.check = False
594             report.diff = True
595             self.assertEqual(
596                 unstyle(str(report)),
597                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
598                 " would fail to reformat.",
599             )
600
601     def test_report_normal(self) -> None:
602         report = black.Report()
603         out_lines = []
604         err_lines = []
605
606         def out(msg: str, **kwargs: Any) -> None:
607             out_lines.append(msg)
608
609         def err(msg: str, **kwargs: Any) -> None:
610             err_lines.append(msg)
611
612         with patch("black.output._out", out), patch("black.output._err", err):
613             report.done(Path("f1"), black.Changed.NO)
614             self.assertEqual(len(out_lines), 0)
615             self.assertEqual(len(err_lines), 0)
616             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
617             self.assertEqual(report.return_code, 0)
618             report.done(Path("f2"), black.Changed.YES)
619             self.assertEqual(len(out_lines), 1)
620             self.assertEqual(len(err_lines), 0)
621             self.assertEqual(out_lines[-1], "reformatted f2")
622             self.assertEqual(
623                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
624             )
625             report.done(Path("f3"), black.Changed.CACHED)
626             self.assertEqual(len(out_lines), 1)
627             self.assertEqual(len(err_lines), 0)
628             self.assertEqual(out_lines[-1], "reformatted f2")
629             self.assertEqual(
630                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
631             )
632             self.assertEqual(report.return_code, 0)
633             report.check = True
634             self.assertEqual(report.return_code, 1)
635             report.check = False
636             report.failed(Path("e1"), "boom")
637             self.assertEqual(len(out_lines), 1)
638             self.assertEqual(len(err_lines), 1)
639             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
640             self.assertEqual(
641                 unstyle(str(report)),
642                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
643                 " reformat.",
644             )
645             self.assertEqual(report.return_code, 123)
646             report.done(Path("f3"), black.Changed.YES)
647             self.assertEqual(len(out_lines), 2)
648             self.assertEqual(len(err_lines), 1)
649             self.assertEqual(out_lines[-1], "reformatted f3")
650             self.assertEqual(
651                 unstyle(str(report)),
652                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
653                 " reformat.",
654             )
655             self.assertEqual(report.return_code, 123)
656             report.failed(Path("e2"), "boom")
657             self.assertEqual(len(out_lines), 2)
658             self.assertEqual(len(err_lines), 2)
659             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
660             self.assertEqual(
661                 unstyle(str(report)),
662                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
663                 " reformat.",
664             )
665             self.assertEqual(report.return_code, 123)
666             report.path_ignored(Path("wat"), "no match")
667             self.assertEqual(len(out_lines), 2)
668             self.assertEqual(len(err_lines), 2)
669             self.assertEqual(
670                 unstyle(str(report)),
671                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
672                 " reformat.",
673             )
674             self.assertEqual(report.return_code, 123)
675             report.done(Path("f4"), black.Changed.NO)
676             self.assertEqual(len(out_lines), 2)
677             self.assertEqual(len(err_lines), 2)
678             self.assertEqual(
679                 unstyle(str(report)),
680                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
681                 " reformat.",
682             )
683             self.assertEqual(report.return_code, 123)
684             report.check = True
685             self.assertEqual(
686                 unstyle(str(report)),
687                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
688                 " would fail to reformat.",
689             )
690             report.check = False
691             report.diff = True
692             self.assertEqual(
693                 unstyle(str(report)),
694                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
695                 " would fail to reformat.",
696             )
697
698     def test_lib2to3_parse(self) -> None:
699         with self.assertRaises(black.InvalidInput):
700             black.lib2to3_parse("invalid syntax")
701
702         straddling = "x + y"
703         black.lib2to3_parse(straddling)
704         black.lib2to3_parse(straddling, {TargetVersion.PY36})
705
706         py2_only = "print x"
707         with self.assertRaises(black.InvalidInput):
708             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
709
710         py3_only = "exec(x, end=y)"
711         black.lib2to3_parse(py3_only)
712         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
713
714     def test_get_features_used_decorator(self) -> None:
715         # Test the feature detection of new decorator syntax
716         # since this makes some test cases of test_get_features_used()
717         # fails if it fails, this is tested first so that a useful case
718         # is identified
719         simples, relaxed = read_data("miscellaneous", "decorators")
720         # skip explanation comments at the top of the file
721         for simple_test in simples.split("##")[1:]:
722             node = black.lib2to3_parse(simple_test)
723             decorator = str(node.children[0].children[0]).strip()
724             self.assertNotIn(
725                 Feature.RELAXED_DECORATORS,
726                 black.get_features_used(node),
727                 msg=(
728                     f"decorator '{decorator}' follows python<=3.8 syntax"
729                     "but is detected as 3.9+"
730                     # f"The full node is\n{node!r}"
731                 ),
732             )
733         # skip the '# output' comment at the top of the output part
734         for relaxed_test in relaxed.split("##")[1:]:
735             node = black.lib2to3_parse(relaxed_test)
736             decorator = str(node.children[0].children[0]).strip()
737             self.assertIn(
738                 Feature.RELAXED_DECORATORS,
739                 black.get_features_used(node),
740                 msg=(
741                     f"decorator '{decorator}' uses python3.9+ syntax"
742                     "but is detected as python<=3.8"
743                     # f"The full node is\n{node!r}"
744                 ),
745             )
746
747     def test_get_features_used(self) -> None:
748         node = black.lib2to3_parse("def f(*, arg): ...\n")
749         self.assertEqual(black.get_features_used(node), set())
750         node = black.lib2to3_parse("def f(*, arg,): ...\n")
751         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
752         node = black.lib2to3_parse("f(*arg,)\n")
753         self.assertEqual(
754             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
755         )
756         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
757         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
758         node = black.lib2to3_parse("123_456\n")
759         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
760         node = black.lib2to3_parse("123456\n")
761         self.assertEqual(black.get_features_used(node), set())
762         source, expected = read_data("simple_cases", "function")
763         node = black.lib2to3_parse(source)
764         expected_features = {
765             Feature.TRAILING_COMMA_IN_CALL,
766             Feature.TRAILING_COMMA_IN_DEF,
767             Feature.F_STRINGS,
768         }
769         self.assertEqual(black.get_features_used(node), expected_features)
770         node = black.lib2to3_parse(expected)
771         self.assertEqual(black.get_features_used(node), expected_features)
772         source, expected = read_data("simple_cases", "expression")
773         node = black.lib2to3_parse(source)
774         self.assertEqual(black.get_features_used(node), set())
775         node = black.lib2to3_parse(expected)
776         self.assertEqual(black.get_features_used(node), set())
777         node = black.lib2to3_parse("lambda a, /, b: ...")
778         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
779         node = black.lib2to3_parse("def fn(a, /, b): ...")
780         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
781         node = black.lib2to3_parse("def fn(): yield a, b")
782         self.assertEqual(black.get_features_used(node), set())
783         node = black.lib2to3_parse("def fn(): return a, b")
784         self.assertEqual(black.get_features_used(node), set())
785         node = black.lib2to3_parse("def fn(): yield *b, c")
786         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
787         node = black.lib2to3_parse("def fn(): return a, *b, c")
788         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
789         node = black.lib2to3_parse("x = a, *b, c")
790         self.assertEqual(black.get_features_used(node), set())
791         node = black.lib2to3_parse("x: Any = regular")
792         self.assertEqual(black.get_features_used(node), set())
793         node = black.lib2to3_parse("x: Any = (regular, regular)")
794         self.assertEqual(black.get_features_used(node), set())
795         node = black.lib2to3_parse("x: Any = Complex(Type(1))[something]")
796         self.assertEqual(black.get_features_used(node), set())
797         node = black.lib2to3_parse("x: Tuple[int, ...] = a, b, c")
798         self.assertEqual(
799             black.get_features_used(node), {Feature.ANN_ASSIGN_EXTENDED_RHS}
800         )
801         node = black.lib2to3_parse("try: pass\nexcept Something: pass")
802         self.assertEqual(black.get_features_used(node), set())
803         node = black.lib2to3_parse("try: pass\nexcept (*Something,): pass")
804         self.assertEqual(black.get_features_used(node), set())
805         node = black.lib2to3_parse("try: pass\nexcept *Group: pass")
806         self.assertEqual(black.get_features_used(node), {Feature.EXCEPT_STAR})
807         node = black.lib2to3_parse("a[*b]")
808         self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS})
809         node = black.lib2to3_parse("a[x, *y(), z] = t")
810         self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS})
811         node = black.lib2to3_parse("def fn(*args: *T): pass")
812         self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS})
813
814     def test_get_features_used_for_future_flags(self) -> None:
815         for src, features in [
816             ("from __future__ import annotations", {Feature.FUTURE_ANNOTATIONS}),
817             (
818                 "from __future__ import (other, annotations)",
819                 {Feature.FUTURE_ANNOTATIONS},
820             ),
821             ("a = 1 + 2\nfrom something import annotations", set()),
822             ("from __future__ import x, y", set()),
823         ]:
824             with self.subTest(src=src, features=features):
825                 node = black.lib2to3_parse(src)
826                 future_imports = black.get_future_imports(node)
827                 self.assertEqual(
828                     black.get_features_used(node, future_imports=future_imports),
829                     features,
830                 )
831
832     def test_get_future_imports(self) -> None:
833         node = black.lib2to3_parse("\n")
834         self.assertEqual(set(), black.get_future_imports(node))
835         node = black.lib2to3_parse("from __future__ import black\n")
836         self.assertEqual({"black"}, black.get_future_imports(node))
837         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
838         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
839         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
840         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
841         node = black.lib2to3_parse(
842             "from __future__ import multiple\nfrom __future__ import imports\n"
843         )
844         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
845         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
846         self.assertEqual({"black"}, black.get_future_imports(node))
847         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
848         self.assertEqual({"black"}, black.get_future_imports(node))
849         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
850         self.assertEqual(set(), black.get_future_imports(node))
851         node = black.lib2to3_parse("from some.module import black\n")
852         self.assertEqual(set(), black.get_future_imports(node))
853         node = black.lib2to3_parse(
854             "from __future__ import unicode_literals as _unicode_literals"
855         )
856         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
857         node = black.lib2to3_parse(
858             "from __future__ import unicode_literals as _lol, print"
859         )
860         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
861
862     @pytest.mark.incompatible_with_mypyc
863     def test_debug_visitor(self) -> None:
864         source, _ = read_data("miscellaneous", "debug_visitor")
865         expected, _ = read_data("miscellaneous", "debug_visitor.out")
866         out_lines = []
867         err_lines = []
868
869         def out(msg: str, **kwargs: Any) -> None:
870             out_lines.append(msg)
871
872         def err(msg: str, **kwargs: Any) -> None:
873             err_lines.append(msg)
874
875         with patch("black.debug.out", out):
876             DebugVisitor.show(source)
877         actual = "\n".join(out_lines) + "\n"
878         log_name = ""
879         if expected != actual:
880             log_name = black.dump_to_file(*out_lines)
881         self.assertEqual(
882             expected,
883             actual,
884             f"AST print out is different. Actual version dumped to {log_name}",
885         )
886
887     def test_format_file_contents(self) -> None:
888         empty = ""
889         mode = DEFAULT_MODE
890         with self.assertRaises(black.NothingChanged):
891             black.format_file_contents(empty, mode=mode, fast=False)
892         just_nl = "\n"
893         with self.assertRaises(black.NothingChanged):
894             black.format_file_contents(just_nl, mode=mode, fast=False)
895         same = "j = [1, 2, 3]\n"
896         with self.assertRaises(black.NothingChanged):
897             black.format_file_contents(same, mode=mode, fast=False)
898         different = "j = [1,2,3]"
899         expected = same
900         actual = black.format_file_contents(different, mode=mode, fast=False)
901         self.assertEqual(expected, actual)
902         invalid = "return if you can"
903         with self.assertRaises(black.InvalidInput) as e:
904             black.format_file_contents(invalid, mode=mode, fast=False)
905         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
906
907     def test_endmarker(self) -> None:
908         n = black.lib2to3_parse("\n")
909         self.assertEqual(n.type, black.syms.file_input)
910         self.assertEqual(len(n.children), 1)
911         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
912
913     @pytest.mark.incompatible_with_mypyc
914     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
915     def test_assertFormatEqual(self) -> None:
916         out_lines = []
917         err_lines = []
918
919         def out(msg: str, **kwargs: Any) -> None:
920             out_lines.append(msg)
921
922         def err(msg: str, **kwargs: Any) -> None:
923             err_lines.append(msg)
924
925         with patch("black.output._out", out), patch("black.output._err", err):
926             with self.assertRaises(AssertionError):
927                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
928
929         out_str = "".join(out_lines)
930         self.assertIn("Expected tree:", out_str)
931         self.assertIn("Actual tree:", out_str)
932         self.assertEqual("".join(err_lines), "")
933
934     @event_loop()
935     @patch("concurrent.futures.ProcessPoolExecutor", MagicMock(side_effect=OSError))
936     def test_works_in_mono_process_only_environment(self) -> None:
937         with cache_dir() as workspace:
938             for f in [
939                 (workspace / "one.py").resolve(),
940                 (workspace / "two.py").resolve(),
941             ]:
942                 f.write_text('print("hello")\n')
943             self.invokeBlack([str(workspace)])
944
945     @event_loop()
946     def test_check_diff_use_together(self) -> None:
947         with cache_dir():
948             # Files which will be reformatted.
949             src1 = get_case_path("miscellaneous", "string_quotes")
950             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
951             # Files which will not be reformatted.
952             src2 = get_case_path("simple_cases", "composition")
953             self.invokeBlack([str(src2), "--diff", "--check"])
954             # Multi file command.
955             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
956
957     def test_no_src_fails(self) -> None:
958         with cache_dir():
959             self.invokeBlack([], exit_code=1)
960
961     def test_src_and_code_fails(self) -> None:
962         with cache_dir():
963             self.invokeBlack([".", "-c", "0"], exit_code=1)
964
965     def test_broken_symlink(self) -> None:
966         with cache_dir() as workspace:
967             symlink = workspace / "broken_link.py"
968             try:
969                 symlink.symlink_to("nonexistent.py")
970             except (OSError, NotImplementedError) as e:
971                 self.skipTest(f"Can't create symlinks: {e}")
972             self.invokeBlack([str(workspace.resolve())])
973
974     def test_single_file_force_pyi(self) -> None:
975         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
976         contents, expected = read_data("miscellaneous", "force_pyi")
977         with cache_dir() as workspace:
978             path = (workspace / "file.py").resolve()
979             with open(path, "w") as fh:
980                 fh.write(contents)
981             self.invokeBlack([str(path), "--pyi"])
982             with open(path, "r") as fh:
983                 actual = fh.read()
984             # verify cache with --pyi is separate
985             pyi_cache = black.read_cache(pyi_mode)
986             self.assertIn(str(path), pyi_cache)
987             normal_cache = black.read_cache(DEFAULT_MODE)
988             self.assertNotIn(str(path), normal_cache)
989         self.assertFormatEqual(expected, actual)
990         black.assert_equivalent(contents, actual)
991         black.assert_stable(contents, actual, pyi_mode)
992
993     @event_loop()
994     def test_multi_file_force_pyi(self) -> None:
995         reg_mode = DEFAULT_MODE
996         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
997         contents, expected = read_data("miscellaneous", "force_pyi")
998         with cache_dir() as workspace:
999             paths = [
1000                 (workspace / "file1.py").resolve(),
1001                 (workspace / "file2.py").resolve(),
1002             ]
1003             for path in paths:
1004                 with open(path, "w") as fh:
1005                     fh.write(contents)
1006             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1007             for path in paths:
1008                 with open(path, "r") as fh:
1009                     actual = fh.read()
1010                 self.assertEqual(actual, expected)
1011             # verify cache with --pyi is separate
1012             pyi_cache = black.read_cache(pyi_mode)
1013             normal_cache = black.read_cache(reg_mode)
1014             for path in paths:
1015                 self.assertIn(str(path), pyi_cache)
1016                 self.assertNotIn(str(path), normal_cache)
1017
1018     def test_pipe_force_pyi(self) -> None:
1019         source, expected = read_data("miscellaneous", "force_pyi")
1020         result = CliRunner().invoke(
1021             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1022         )
1023         self.assertEqual(result.exit_code, 0)
1024         actual = result.output
1025         self.assertFormatEqual(actual, expected)
1026
1027     def test_single_file_force_py36(self) -> None:
1028         reg_mode = DEFAULT_MODE
1029         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1030         source, expected = read_data("miscellaneous", "force_py36")
1031         with cache_dir() as workspace:
1032             path = (workspace / "file.py").resolve()
1033             with open(path, "w") as fh:
1034                 fh.write(source)
1035             self.invokeBlack([str(path), *PY36_ARGS])
1036             with open(path, "r") as fh:
1037                 actual = fh.read()
1038             # verify cache with --target-version is separate
1039             py36_cache = black.read_cache(py36_mode)
1040             self.assertIn(str(path), py36_cache)
1041             normal_cache = black.read_cache(reg_mode)
1042             self.assertNotIn(str(path), normal_cache)
1043         self.assertEqual(actual, expected)
1044
1045     @event_loop()
1046     def test_multi_file_force_py36(self) -> None:
1047         reg_mode = DEFAULT_MODE
1048         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1049         source, expected = read_data("miscellaneous", "force_py36")
1050         with cache_dir() as workspace:
1051             paths = [
1052                 (workspace / "file1.py").resolve(),
1053                 (workspace / "file2.py").resolve(),
1054             ]
1055             for path in paths:
1056                 with open(path, "w") as fh:
1057                     fh.write(source)
1058             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1059             for path in paths:
1060                 with open(path, "r") as fh:
1061                     actual = fh.read()
1062                 self.assertEqual(actual, expected)
1063             # verify cache with --target-version is separate
1064             pyi_cache = black.read_cache(py36_mode)
1065             normal_cache = black.read_cache(reg_mode)
1066             for path in paths:
1067                 self.assertIn(str(path), pyi_cache)
1068                 self.assertNotIn(str(path), normal_cache)
1069
1070     def test_pipe_force_py36(self) -> None:
1071         source, expected = read_data("miscellaneous", "force_py36")
1072         result = CliRunner().invoke(
1073             black.main,
1074             ["-", "-q", "--target-version=py36"],
1075             input=BytesIO(source.encode("utf8")),
1076         )
1077         self.assertEqual(result.exit_code, 0)
1078         actual = result.output
1079         self.assertFormatEqual(actual, expected)
1080
1081     @pytest.mark.incompatible_with_mypyc
1082     def test_reformat_one_with_stdin(self) -> None:
1083         with patch(
1084             "black.format_stdin_to_stdout",
1085             return_value=lambda *args, **kwargs: black.Changed.YES,
1086         ) as fsts:
1087             report = MagicMock()
1088             path = Path("-")
1089             black.reformat_one(
1090                 path,
1091                 fast=True,
1092                 write_back=black.WriteBack.YES,
1093                 mode=DEFAULT_MODE,
1094                 report=report,
1095             )
1096             fsts.assert_called_once()
1097             report.done.assert_called_with(path, black.Changed.YES)
1098
1099     @pytest.mark.incompatible_with_mypyc
1100     def test_reformat_one_with_stdin_filename(self) -> None:
1101         with patch(
1102             "black.format_stdin_to_stdout",
1103             return_value=lambda *args, **kwargs: black.Changed.YES,
1104         ) as fsts:
1105             report = MagicMock()
1106             p = "foo.py"
1107             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1108             expected = Path(p)
1109             black.reformat_one(
1110                 path,
1111                 fast=True,
1112                 write_back=black.WriteBack.YES,
1113                 mode=DEFAULT_MODE,
1114                 report=report,
1115             )
1116             fsts.assert_called_once_with(
1117                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1118             )
1119             # __BLACK_STDIN_FILENAME__ should have been stripped
1120             report.done.assert_called_with(expected, black.Changed.YES)
1121
1122     @pytest.mark.incompatible_with_mypyc
1123     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1124         with patch(
1125             "black.format_stdin_to_stdout",
1126             return_value=lambda *args, **kwargs: black.Changed.YES,
1127         ) as fsts:
1128             report = MagicMock()
1129             p = "foo.pyi"
1130             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1131             expected = Path(p)
1132             black.reformat_one(
1133                 path,
1134                 fast=True,
1135                 write_back=black.WriteBack.YES,
1136                 mode=DEFAULT_MODE,
1137                 report=report,
1138             )
1139             fsts.assert_called_once_with(
1140                 fast=True,
1141                 write_back=black.WriteBack.YES,
1142                 mode=replace(DEFAULT_MODE, is_pyi=True),
1143             )
1144             # __BLACK_STDIN_FILENAME__ should have been stripped
1145             report.done.assert_called_with(expected, black.Changed.YES)
1146
1147     @pytest.mark.incompatible_with_mypyc
1148     def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
1149         with patch(
1150             "black.format_stdin_to_stdout",
1151             return_value=lambda *args, **kwargs: black.Changed.YES,
1152         ) as fsts:
1153             report = MagicMock()
1154             p = "foo.ipynb"
1155             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1156             expected = Path(p)
1157             black.reformat_one(
1158                 path,
1159                 fast=True,
1160                 write_back=black.WriteBack.YES,
1161                 mode=DEFAULT_MODE,
1162                 report=report,
1163             )
1164             fsts.assert_called_once_with(
1165                 fast=True,
1166                 write_back=black.WriteBack.YES,
1167                 mode=replace(DEFAULT_MODE, is_ipynb=True),
1168             )
1169             # __BLACK_STDIN_FILENAME__ should have been stripped
1170             report.done.assert_called_with(expected, black.Changed.YES)
1171
1172     @pytest.mark.incompatible_with_mypyc
1173     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1174         with patch(
1175             "black.format_stdin_to_stdout",
1176             return_value=lambda *args, **kwargs: black.Changed.YES,
1177         ) as fsts:
1178             report = MagicMock()
1179             # Even with an existing file, since we are forcing stdin, black
1180             # should output to stdout and not modify the file inplace
1181             p = THIS_DIR / "data" / "simple_cases" / "collections.py"
1182             # Make sure is_file actually returns True
1183             self.assertTrue(p.is_file())
1184             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1185             expected = Path(p)
1186             black.reformat_one(
1187                 path,
1188                 fast=True,
1189                 write_back=black.WriteBack.YES,
1190                 mode=DEFAULT_MODE,
1191                 report=report,
1192             )
1193             fsts.assert_called_once()
1194             # __BLACK_STDIN_FILENAME__ should have been stripped
1195             report.done.assert_called_with(expected, black.Changed.YES)
1196
1197     def test_reformat_one_with_stdin_empty(self) -> None:
1198         output = io.StringIO()
1199         with patch("io.TextIOWrapper", lambda *args, **kwargs: output):
1200             try:
1201                 black.format_stdin_to_stdout(
1202                     fast=True,
1203                     content="",
1204                     write_back=black.WriteBack.YES,
1205                     mode=DEFAULT_MODE,
1206                 )
1207             except io.UnsupportedOperation:
1208                 pass  # StringIO does not support detach
1209             assert output.getvalue() == ""
1210
1211     def test_invalid_cli_regex(self) -> None:
1212         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1213             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1214
1215     def test_required_version_matches_version(self) -> None:
1216         self.invokeBlack(
1217             ["--required-version", black.__version__, "-c", "0"],
1218             exit_code=0,
1219             ignore_config=True,
1220         )
1221
1222     def test_required_version_matches_partial_version(self) -> None:
1223         self.invokeBlack(
1224             ["--required-version", black.__version__.split(".")[0], "-c", "0"],
1225             exit_code=0,
1226             ignore_config=True,
1227         )
1228
1229     def test_required_version_does_not_match_on_minor_version(self) -> None:
1230         self.invokeBlack(
1231             ["--required-version", black.__version__.split(".")[0] + ".999", "-c", "0"],
1232             exit_code=1,
1233             ignore_config=True,
1234         )
1235
1236     def test_required_version_does_not_match_version(self) -> None:
1237         result = BlackRunner().invoke(
1238             black.main,
1239             ["--required-version", "20.99b", "-c", "0"],
1240         )
1241         self.assertEqual(result.exit_code, 1)
1242         self.assertIn("required version", result.stderr)
1243
1244     def test_preserves_line_endings(self) -> None:
1245         with TemporaryDirectory() as workspace:
1246             test_file = Path(workspace) / "test.py"
1247             for nl in ["\n", "\r\n"]:
1248                 contents = nl.join(["def f(  ):", "    pass"])
1249                 test_file.write_bytes(contents.encode())
1250                 ff(test_file, write_back=black.WriteBack.YES)
1251                 updated_contents: bytes = test_file.read_bytes()
1252                 self.assertIn(nl.encode(), updated_contents)
1253                 if nl == "\n":
1254                     self.assertNotIn(b"\r\n", updated_contents)
1255
1256     def test_preserves_line_endings_via_stdin(self) -> None:
1257         for nl in ["\n", "\r\n"]:
1258             contents = nl.join(["def f(  ):", "    pass"])
1259             runner = BlackRunner()
1260             result = runner.invoke(
1261                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1262             )
1263             self.assertEqual(result.exit_code, 0)
1264             output = result.stdout_bytes
1265             self.assertIn(nl.encode("utf8"), output)
1266             if nl == "\n":
1267                 self.assertNotIn(b"\r\n", output)
1268
1269     def test_assert_equivalent_different_asts(self) -> None:
1270         with self.assertRaises(AssertionError):
1271             black.assert_equivalent("{}", "None")
1272
1273     def test_shhh_click(self) -> None:
1274         try:
1275             from click import _unicodefun  # type: ignore
1276         except ImportError:
1277             self.skipTest("Incompatible Click version")
1278
1279         if not hasattr(_unicodefun, "_verify_python_env"):
1280             self.skipTest("Incompatible Click version")
1281
1282         # First, let's see if Click is crashing with a preferred ASCII charset.
1283         with patch("locale.getpreferredencoding") as gpe:
1284             gpe.return_value = "ASCII"
1285             with self.assertRaises(RuntimeError):
1286                 _unicodefun._verify_python_env()
1287         # Now, let's silence Click...
1288         black.patch_click()
1289         # ...and confirm it's silent.
1290         with patch("locale.getpreferredencoding") as gpe:
1291             gpe.return_value = "ASCII"
1292             try:
1293                 _unicodefun._verify_python_env()
1294             except RuntimeError as re:
1295                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1296
1297     def test_root_logger_not_used_directly(self) -> None:
1298         def fail(*args: Any, **kwargs: Any) -> None:
1299             self.fail("Record created with root logger")
1300
1301         with patch.multiple(
1302             logging.root,
1303             debug=fail,
1304             info=fail,
1305             warning=fail,
1306             error=fail,
1307             critical=fail,
1308             log=fail,
1309         ):
1310             ff(THIS_DIR / "util.py")
1311
1312     def test_invalid_config_return_code(self) -> None:
1313         tmp_file = Path(black.dump_to_file())
1314         try:
1315             tmp_config = Path(black.dump_to_file())
1316             tmp_config.unlink()
1317             args = ["--config", str(tmp_config), str(tmp_file)]
1318             self.invokeBlack(args, exit_code=2, ignore_config=False)
1319         finally:
1320             tmp_file.unlink()
1321
1322     def test_parse_pyproject_toml(self) -> None:
1323         test_toml_file = THIS_DIR / "test.toml"
1324         config = black.parse_pyproject_toml(str(test_toml_file))
1325         self.assertEqual(config["verbose"], 1)
1326         self.assertEqual(config["check"], "no")
1327         self.assertEqual(config["diff"], "y")
1328         self.assertEqual(config["color"], True)
1329         self.assertEqual(config["line_length"], 79)
1330         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1331         self.assertEqual(config["python_cell_magics"], ["custom1", "custom2"])
1332         self.assertEqual(config["exclude"], r"\.pyi?$")
1333         self.assertEqual(config["include"], r"\.py?$")
1334
1335     def test_read_pyproject_toml(self) -> None:
1336         test_toml_file = THIS_DIR / "test.toml"
1337         fake_ctx = FakeContext()
1338         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1339         config = fake_ctx.default_map
1340         self.assertEqual(config["verbose"], "1")
1341         self.assertEqual(config["check"], "no")
1342         self.assertEqual(config["diff"], "y")
1343         self.assertEqual(config["color"], "True")
1344         self.assertEqual(config["line_length"], "79")
1345         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1346         self.assertEqual(config["exclude"], r"\.pyi?$")
1347         self.assertEqual(config["include"], r"\.py?$")
1348
1349     @pytest.mark.incompatible_with_mypyc
1350     def test_find_project_root(self) -> None:
1351         with TemporaryDirectory() as workspace:
1352             root = Path(workspace)
1353             test_dir = root / "test"
1354             test_dir.mkdir()
1355
1356             src_dir = root / "src"
1357             src_dir.mkdir()
1358
1359             root_pyproject = root / "pyproject.toml"
1360             root_pyproject.touch()
1361             src_pyproject = src_dir / "pyproject.toml"
1362             src_pyproject.touch()
1363             src_python = src_dir / "foo.py"
1364             src_python.touch()
1365
1366             self.assertEqual(
1367                 black.find_project_root((src_dir, test_dir)),
1368                 (root.resolve(), "pyproject.toml"),
1369             )
1370             self.assertEqual(
1371                 black.find_project_root((src_dir,)),
1372                 (src_dir.resolve(), "pyproject.toml"),
1373             )
1374             self.assertEqual(
1375                 black.find_project_root((src_python,)),
1376                 (src_dir.resolve(), "pyproject.toml"),
1377             )
1378
1379     @patch(
1380         "black.files.find_user_pyproject_toml",
1381     )
1382     def test_find_pyproject_toml(self, find_user_pyproject_toml: MagicMock) -> None:
1383         find_user_pyproject_toml.side_effect = RuntimeError()
1384
1385         with redirect_stderr(io.StringIO()) as stderr:
1386             result = black.files.find_pyproject_toml(
1387                 path_search_start=(str(Path.cwd().root),)
1388             )
1389
1390         assert result is None
1391         err = stderr.getvalue()
1392         assert "Ignoring user configuration" in err
1393
1394     @patch(
1395         "black.files.find_user_pyproject_toml",
1396         black.files.find_user_pyproject_toml.__wrapped__,
1397     )
1398     def test_find_user_pyproject_toml_linux(self) -> None:
1399         if system() == "Windows":
1400             return
1401
1402         # Test if XDG_CONFIG_HOME is checked
1403         with TemporaryDirectory() as workspace:
1404             tmp_user_config = Path(workspace) / "black"
1405             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1406                 self.assertEqual(
1407                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
1408                 )
1409
1410         # Test fallback for XDG_CONFIG_HOME
1411         with patch.dict("os.environ"):
1412             os.environ.pop("XDG_CONFIG_HOME", None)
1413             fallback_user_config = Path("~/.config").expanduser() / "black"
1414             self.assertEqual(
1415                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
1416             )
1417
1418     def test_find_user_pyproject_toml_windows(self) -> None:
1419         if system() != "Windows":
1420             return
1421
1422         user_config_path = Path.home() / ".black"
1423         self.assertEqual(
1424             black.files.find_user_pyproject_toml(), user_config_path.resolve()
1425         )
1426
1427     def test_bpo_33660_workaround(self) -> None:
1428         if system() == "Windows":
1429             return
1430
1431         # https://bugs.python.org/issue33660
1432         root = Path("/")
1433         with change_directory(root):
1434             path = Path("workspace") / "project"
1435             report = black.Report(verbose=True)
1436             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
1437             self.assertEqual(normalized_path, "workspace/project")
1438
1439     def test_normalize_path_ignore_windows_junctions_outside_of_root(self) -> None:
1440         if system() != "Windows":
1441             return
1442
1443         with TemporaryDirectory() as workspace:
1444             root = Path(workspace)
1445             junction_dir = root / "junction"
1446             junction_target_outside_of_root = root / ".."
1447             os.system(f"mklink /J {junction_dir} {junction_target_outside_of_root}")
1448
1449             report = black.Report(verbose=True)
1450             normalized_path = black.normalize_path_maybe_ignore(
1451                 junction_dir, root, report
1452             )
1453             # Manually delete for Python < 3.8
1454             os.system(f"rmdir {junction_dir}")
1455
1456             self.assertEqual(normalized_path, None)
1457
1458     def test_newline_comment_interaction(self) -> None:
1459         source = "class A:\\\r\n# type: ignore\n pass\n"
1460         output = black.format_str(source, mode=DEFAULT_MODE)
1461         black.assert_stable(source, output, mode=DEFAULT_MODE)
1462
1463     def test_bpo_2142_workaround(self) -> None:
1464
1465         # https://bugs.python.org/issue2142
1466
1467         source, _ = read_data("miscellaneous", "missing_final_newline")
1468         # read_data adds a trailing newline
1469         source = source.rstrip()
1470         expected, _ = read_data("miscellaneous", "missing_final_newline.diff")
1471         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
1472         diff_header = re.compile(
1473             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
1474             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
1475         )
1476         try:
1477             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
1478             self.assertEqual(result.exit_code, 0)
1479         finally:
1480             os.unlink(tmp_file)
1481         actual = result.output
1482         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
1483         self.assertEqual(actual, expected)
1484
1485     @staticmethod
1486     def compare_results(
1487         result: click.testing.Result, expected_value: str, expected_exit_code: int
1488     ) -> None:
1489         """Helper method to test the value and exit code of a click Result."""
1490         assert (
1491             result.output == expected_value
1492         ), "The output did not match the expected value."
1493         assert result.exit_code == expected_exit_code, "The exit code is incorrect."
1494
1495     def test_code_option(self) -> None:
1496         """Test the code option with no changes."""
1497         code = 'print("Hello world")\n'
1498         args = ["--code", code]
1499         result = CliRunner().invoke(black.main, args)
1500
1501         self.compare_results(result, code, 0)
1502
1503     def test_code_option_changed(self) -> None:
1504         """Test the code option when changes are required."""
1505         code = "print('hello world')"
1506         formatted = black.format_str(code, mode=DEFAULT_MODE)
1507
1508         args = ["--code", code]
1509         result = CliRunner().invoke(black.main, args)
1510
1511         self.compare_results(result, formatted, 0)
1512
1513     def test_code_option_check(self) -> None:
1514         """Test the code option when check is passed."""
1515         args = ["--check", "--code", 'print("Hello world")\n']
1516         result = CliRunner().invoke(black.main, args)
1517         self.compare_results(result, "", 0)
1518
1519     def test_code_option_check_changed(self) -> None:
1520         """Test the code option when changes are required, and check is passed."""
1521         args = ["--check", "--code", "print('hello world')"]
1522         result = CliRunner().invoke(black.main, args)
1523         self.compare_results(result, "", 1)
1524
1525     def test_code_option_diff(self) -> None:
1526         """Test the code option when diff is passed."""
1527         code = "print('hello world')"
1528         formatted = black.format_str(code, mode=DEFAULT_MODE)
1529         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1530
1531         args = ["--diff", "--code", code]
1532         result = CliRunner().invoke(black.main, args)
1533
1534         # Remove time from diff
1535         output = DIFF_TIME.sub("", result.output)
1536
1537         assert output == result_diff, "The output did not match the expected value."
1538         assert result.exit_code == 0, "The exit code is incorrect."
1539
1540     def test_code_option_color_diff(self) -> None:
1541         """Test the code option when color and diff are passed."""
1542         code = "print('hello world')"
1543         formatted = black.format_str(code, mode=DEFAULT_MODE)
1544
1545         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1546         result_diff = color_diff(result_diff)
1547
1548         args = ["--diff", "--color", "--code", code]
1549         result = CliRunner().invoke(black.main, args)
1550
1551         # Remove time from diff
1552         output = DIFF_TIME.sub("", result.output)
1553
1554         assert output == result_diff, "The output did not match the expected value."
1555         assert result.exit_code == 0, "The exit code is incorrect."
1556
1557     @pytest.mark.incompatible_with_mypyc
1558     def test_code_option_safe(self) -> None:
1559         """Test that the code option throws an error when the sanity checks fail."""
1560         # Patch black.assert_equivalent to ensure the sanity checks fail
1561         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1562             code = 'print("Hello world")'
1563             error_msg = f"{code}\nerror: cannot format <string>: \n"
1564
1565             args = ["--safe", "--code", code]
1566             result = CliRunner().invoke(black.main, args)
1567
1568             self.compare_results(result, error_msg, 123)
1569
1570     def test_code_option_fast(self) -> None:
1571         """Test that the code option ignores errors when the sanity checks fail."""
1572         # Patch black.assert_equivalent to ensure the sanity checks fail
1573         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1574             code = 'print("Hello world")'
1575             formatted = black.format_str(code, mode=DEFAULT_MODE)
1576
1577             args = ["--fast", "--code", code]
1578             result = CliRunner().invoke(black.main, args)
1579
1580             self.compare_results(result, formatted, 0)
1581
1582     @pytest.mark.incompatible_with_mypyc
1583     def test_code_option_config(self) -> None:
1584         """
1585         Test that the code option finds the pyproject.toml in the current directory.
1586         """
1587         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1588             args = ["--code", "print"]
1589             # This is the only directory known to contain a pyproject.toml
1590             with change_directory(PROJECT_ROOT):
1591                 CliRunner().invoke(black.main, args)
1592                 pyproject_path = Path(Path.cwd(), "pyproject.toml").resolve()
1593
1594             assert (
1595                 len(parse.mock_calls) >= 1
1596             ), "Expected config parse to be called with the current directory."
1597
1598             _, call_args, _ = parse.mock_calls[0]
1599             assert (
1600                 call_args[0].lower() == str(pyproject_path).lower()
1601             ), "Incorrect config loaded."
1602
1603     @pytest.mark.incompatible_with_mypyc
1604     def test_code_option_parent_config(self) -> None:
1605         """
1606         Test that the code option finds the pyproject.toml in the parent directory.
1607         """
1608         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1609             with change_directory(THIS_DIR):
1610                 args = ["--code", "print"]
1611                 CliRunner().invoke(black.main, args)
1612
1613                 pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
1614                 assert (
1615                     len(parse.mock_calls) >= 1
1616                 ), "Expected config parse to be called with the current directory."
1617
1618                 _, call_args, _ = parse.mock_calls[0]
1619                 assert (
1620                     call_args[0].lower() == str(pyproject_path).lower()
1621                 ), "Incorrect config loaded."
1622
1623     def test_for_handled_unexpected_eof_error(self) -> None:
1624         """
1625         Test that an unexpected EOF SyntaxError is nicely presented.
1626         """
1627         with pytest.raises(black.parsing.InvalidInput) as exc_info:
1628             black.lib2to3_parse("print(", {})
1629
1630         exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")
1631
1632     def test_equivalency_ast_parse_failure_includes_error(self) -> None:
1633         with pytest.raises(AssertionError) as err:
1634             black.assert_equivalent("a«»a  = 1", "a«»a  = 1")
1635
1636         err.match("--safe")
1637         # Unfortunately the SyntaxError message has changed in newer versions so we
1638         # can't match it directly.
1639         err.match("invalid character")
1640         err.match(r"\(<unknown>, line 1\)")
1641
1642
class TestCaching:
    """Tests for black's per-mode formatting cache (location, read, write)."""

    def test_get_cache_dir(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """``BLACK_CACHE_DIR`` takes precedence over ``user_cache_dir``."""
        # Create multiple cache directories
        workspace1 = tmp_path / "ws1"
        workspace1.mkdir()
        workspace2 = tmp_path / "ws2"
        workspace2.mkdir()

        # Force user_cache_dir to use the temporary directory for easier assertions
        patch_user_cache_dir = patch(
            target="black.cache.user_cache_dir",
            autospec=True,
            return_value=str(workspace1),
        )

        # If BLACK_CACHE_DIR is not set, use user_cache_dir
        monkeypatch.delenv("BLACK_CACHE_DIR", raising=False)
        with patch_user_cache_dir:
            assert get_cache_dir() == workspace1

        # If it is set, use the path provided in the env var.
        monkeypatch.setenv("BLACK_CACHE_DIR", str(workspace2))
        assert get_cache_dir() == workspace2

    def test_cache_broken_file(self) -> None:
        """A corrupt cache reads as empty and the next run caches the file."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            cache_file = get_cache_file(mode)
            cache_file.write_text("this is not a pickle")
            assert black.read_cache(mode) == {}
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            invokeBlack([str(src)])
            cache = black.read_cache(mode)
            assert str(src) in cache

    def test_cache_single_file_already_cached(self) -> None:
        """A file already recorded in the cache is left untouched."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            black.write_cache({}, [src], mode)
            invokeBlack([str(src)])
            assert src.read_text() == "print('hello')"

    @event_loop()
    def test_cache_multiple_files(self) -> None:
        """Only the uncached file is reformatted; afterwards both are cached."""
        mode = DEFAULT_MODE
        # Threads stand in for worker processes, presumably so the patched
        # cache location is visible to the formatting workers.
        with cache_dir() as workspace, patch(
            "concurrent.futures.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            one = (workspace / "one.py").resolve()
            one.write_text("print('hello')")
            two = (workspace / "two.py").resolve()
            two.write_text("print('hello')")
            black.write_cache({}, [one], mode)
            invokeBlack([str(workspace)])
            assert one.read_text() == "print('hello')"
            assert two.read_text() == 'print("hello")\n'
            cache = black.read_cache(mode)
            assert str(one) in cache
            assert str(two) in cache

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    def test_no_cache_when_writeback_diff(self, color: bool) -> None:
        """``--diff`` (with or without ``--color``) never touches the cache."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            with patch("black.read_cache") as read_cache, patch(
                "black.write_cache"
            ) as write_cache:
                cmd = [str(src), "--diff"]
                if color:
                    cmd.append("--color")
                invokeBlack(cmd)
                cache_file = get_cache_file(mode)
                assert not cache_file.exists()
                write_cache.assert_not_called()
                read_cache.assert_not_called()

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    @event_loop()
    def test_output_locking_when_writeback_diff(self, color: bool) -> None:
        """Diffing a directory creates a ``Manager`` (whose lock guards output)."""
        with cache_dir() as workspace:
            for tag in range(4):
                src = (workspace / f"test{tag}.py").resolve()
                src.write_text("print('hello')")
            with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
                cmd = ["--diff", str(workspace)]
                if color:
                    cmd.append("--color")
                invokeBlack(cmd, exit_code=0)
                # this isn't quite doing what we want, but if it _isn't_
                # called then we cannot be using the lock it provides
                mgr.assert_called()

    def test_no_cache_when_stdin(self) -> None:
        """Formatting from stdin does not create a cache file."""
        mode = DEFAULT_MODE
        with cache_dir():
            result = CliRunner().invoke(
                black.main, ["-"], input=BytesIO(b"print('hello')")
            )
            assert not result.exit_code
            cache_file = get_cache_file(mode)
            assert not cache_file.exists()

    def test_read_cache_no_cachefile(self) -> None:
        """A missing cache file reads as an empty cache."""
        mode = DEFAULT_MODE
        with cache_dir():
            assert black.read_cache(mode) == {}

    def test_write_cache_read_cache(self) -> None:
        """A written entry round-trips with the file's current cache info."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.touch()
            black.write_cache({}, [src], mode)
            cache = black.read_cache(mode)
            assert str(src) in cache
            assert cache[str(src)] == black.get_cache_info(src)

    def test_filter_cached(self) -> None:
        """Uncached and stale files are still to do; fresh cached ones are done."""
        with TemporaryDirectory() as workspace:
            path = Path(workspace)
            uncached = (path / "uncached").resolve()
            cached = (path / "cached").resolve()
            cached_but_changed = (path / "changed").resolve()
            uncached.touch()
            cached.touch()
            cached_but_changed.touch()
            cache = {
                str(cached): black.get_cache_info(cached),
                # Deliberately outdated info, so the file counts as changed.
                str(cached_but_changed): (0.0, 0),
            }
            todo, done = black.filter_cached(
                cache, {uncached, cached, cached_but_changed}
            )
            assert todo == {uncached, cached_but_changed}
            assert done == {cached}

    def test_write_cache_creates_directory_if_needed(self) -> None:
        """Writing the cache creates a missing cache directory."""
        mode = DEFAULT_MODE
        with cache_dir(exists=False) as workspace:
            assert not workspace.exists()
            black.write_cache({}, [], mode)
            assert workspace.exists()

    @event_loop()
    def test_failed_formatting_does_not_get_cached(self) -> None:
        """Files that fail to format are left out of the cache."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace, patch(
            "concurrent.futures.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            failing = (workspace / "failing.py").resolve()
            failing.write_text("not actually python")
            clean = (workspace / "clean.py").resolve()
            clean.write_text('print("hello")\n')
            invokeBlack([str(workspace)], exit_code=123)
            cache = black.read_cache(mode)
            assert str(failing) not in cache
            assert str(clean) in cache

    def test_write_cache_write_fail(self) -> None:
        """An ``OSError`` while writing the cache must not propagate."""
        mode = DEFAULT_MODE
        with cache_dir(), patch.object(Path, "open") as mock:
            mock.side_effect = OSError
            # No assertion: the test passes as long as nothing is raised here.
            black.write_cache({}, [], mode)

    def test_read_cache_line_lengths(self) -> None:
        """Entries are kept per mode: another line length sees its own cache."""
        mode = DEFAULT_MODE
        short_mode = replace(DEFAULT_MODE, line_length=1)
        with cache_dir() as workspace:
            path = (workspace / "file.py").resolve()
            path.touch()
            black.write_cache({}, [path], mode)
            one = black.read_cache(mode)
            assert str(path) in one
            two = black.read_cache(short_mode)
            assert str(path) not in two
1835
1836
def assert_collected_sources(
    src: Sequence[Union[str, Path]],
    expected: Sequence[Union[str, Path]],
    *,
    ctx: Optional[FakeContext] = None,
    exclude: Optional[str] = None,
    include: Optional[str] = None,
    extend_exclude: Optional[str] = None,
    force_exclude: Optional[str] = None,
    stdin_filename: Optional[str] = None,
) -> None:
    """Assert that ``black.get_sources`` collects exactly ``expected`` from ``src``.

    Pattern arguments are regex strings compiled with ``compile_pattern``.
    ``None`` leaves a filter unset, except ``include``, which then falls back
    to ``DEFAULT_INCLUDE``. The comparison ignores ordering.
    """

    def _compiled_or_none(pattern: Optional[str]) -> "Optional[re.Pattern[str]]":
        # Only compile patterns that were actually supplied.
        return None if pattern is None else compile_pattern(pattern)

    collected = black.get_sources(
        ctx=ctx or FakeContext(),
        src=tuple(str(Path(item)) for item in src),
        quiet=False,
        verbose=False,
        include=DEFAULT_INCLUDE if include is None else compile_pattern(include),
        exclude=_compiled_or_none(exclude),
        extend_exclude=_compiled_or_none(extend_exclude),
        force_exclude=_compiled_or_none(force_exclude),
        report=black.Report(),
        stdin_filename=stdin_filename,
    )
    wanted = [Path(item) for item in expected]
    assert sorted(collected) == sorted(wanted)
1869
1870
class TestFileCollection:
    """Tests for source discovery: include/exclude patterns, .gitignore, stdin."""

    def test_include_exclude(self) -> None:
        path = THIS_DIR / "data" / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        assert_collected_sources(
            src,
            expected,
            include=r"\.pyi?$",
            exclude=r"/exclude/|/\.definitely_exclude/",
        )

    def test_gitignore_used_as_default(self) -> None:
        base = DATA_DIR / "include_exclude_tests"
        expected = [
            base / "b/.definitely_exclude/a.py",
            base / "b/.definitely_exclude/a.pyi",
        ]
        src = [base / "b/"]
        ctx = FakeContext()
        ctx.obj["root"] = base
        # No `exclude` is passed, so the fixture's .gitignore should apply.
        assert_collected_sources(src, expected, ctx=ctx, extend_exclude=r"/exclude/")

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_exclude_for_issue_1572(self) -> None:
        # Exclude shouldn't touch files that were explicitly given to Black through the
        # CLI. Exclude is supposed to only apply to the recursive discovery of files.
        # https://github.com/psf/black/issues/1572
        path = DATA_DIR / "include_exclude_tests"
        src = [path / "b/exclude/a.py"]
        expected = [path / "b/exclude/a.py"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    def test_gitignore_exclude(self) -> None:
        path = THIS_DIR / "data" / "include_exclude_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        report = black.Report()
        gitignore = PathSpec.from_lines(
            "gitwildmatch", ["exclude/", ".definitely_exclude"]
        )
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        this_abs = THIS_DIR.resolve()
        sources = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,  # presumably extend_exclude
                None,  # presumably force_exclude
                report,
                gitignore,
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_nested_gitignore(self) -> None:
        path = THIS_DIR / "data" / "nested_gitignore_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        root_gitignore = black.files.get_gitignore(path)
        report = black.Report()
        expected: List[Path] = [
            path / "x.py",
            path / "root/b.py",
            path / "root/c.py",
            path / "root/child/c.py",
        ]
        this_abs = THIS_DIR.resolve()
        sources = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,  # presumably extend_exclude
                None,  # presumably force_exclude
                report,
                root_gitignore,
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_invalid_gitignore(self) -> None:
        path = THIS_DIR / "data" / "invalid_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_invalid_nested_gitignore(self) -> None:
        path = THIS_DIR / "data" / "invalid_nested_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / "a" / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_empty_include(self) -> None:
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.pie",
            path / "b/exclude/a.py",
            path / "b/exclude/a.pyi",
            path / "b/dont_exclude/a.pie",
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
            path / "b/.definitely_exclude/a.pie",
            path / "b/.definitely_exclude/a.py",
            path / "b/.definitely_exclude/a.pyi",
            path / ".gitignore",
            path / "pyproject.toml",
        ]
        # Setting exclude explicitly to an empty string to block .gitignore usage.
        assert_collected_sources(src, expected, include="", exclude="")

    def test_extend_exclude(self) -> None:
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.py",
            path / "b/dont_exclude/a.py",
        ]
        assert_collected_sources(
            src, expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude"
        )

    @pytest.mark.incompatible_with_mypyc
    def test_symlink_out_of_root_directory(self) -> None:
        path = MagicMock()
        root = THIS_DIR.resolve()
        child = MagicMock()
        include = re.compile(black.DEFAULT_INCLUDES)
        exclude = re.compile(black.DEFAULT_EXCLUDES)
        report = black.Report()
        gitignore = PathSpec.from_lines("gitwildmatch", [])
        # `child` should behave like a symlink which resolved path is clearly
        # outside of the `root` directory.
        path.iterdir.return_value = [child]
        child.resolve.return_value = Path("/a/b/c")
        child.as_posix.return_value = "/a/b/c"
        try:
            list(
                black.gen_python_files(
                    path.iterdir(),
                    root,
                    include,
                    exclude,
                    None,  # presumably extend_exclude
                    None,  # presumably force_exclude
                    report,
                    gitignore,
                    verbose=False,
                    quiet=False,
                )
            )
        except ValueError as ve:
            # Name the function actually under test in the failure message.
            pytest.fail(f"`gen_python_files()` failed: {ve}")
        path.iterdir.assert_called_once()
        child.resolve.assert_called_once()

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin(self) -> None:
        src = ["-"]
        expected = ["-"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename(self) -> None:
        src = ["-"]
        stdin_filename = str(THIS_DIR / "data/collections.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
        # Exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        path = DATA_DIR / "include_exclude_tests"
        src = ["-"]
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
        # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        src = ["-"]
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            extend_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
        # Force exclude should exclude the file when passing it through
        # stdin_filename
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        assert_collected_sources(
            src=["-"],
            expected=[],
            force_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )
2116
2117
# Lines of black's package `__init__` source, used by `tracefunc` below to show
# the line of each traced call. When black is compiled, reading that file may
# fail to decode (presumably because it is then a binary extension); that case
# is tolerated, and the list is pre-initialized so the name is always defined.
black_source_lines: List[str] = []
try:
    with open(black.__file__, "r", encoding="utf-8") as _bf:
        black_source_lines = _bf.readlines()
except UnicodeDecodeError:
    if not black.COMPILED:
        raise
2124
2125
def tracefunc(
    frame: types.FrameType, event: str, arg: Any
) -> Callable[[types.FrameType, str, Any], Any]:
    """Show function calls from `black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    For each "call" event whose code comes from ``black/__init__.py``, print
    ``<indent><lineno>:<line>``. Here ``<line>`` is the first non-decorator
    line of ``black_source_lines`` at or after the reported line.

    All other events and files are ignored *before* any line lookup.
    ``black_source_lines`` only holds that one file's lines, so indexing it
    with another file's line numbers could raise ``IndexError`` mid-trace.

    Always returns itself, which ``sys.settrace`` then uses as the local
    trace function.
    """
    if event != "call":
        return tracefunc

    filename = frame.f_code.co_filename
    # Compare with forward slashes so Windows paths match as well.
    if "black/__init__.py" not in filename.replace(os.sep, "/"):
        return tracefunc

    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1
    if not 0 <= func_sig_lineno < len(black_source_lines):
        # No source line available (e.g. the source could not be read).
        return tracefunc

    # Indent by call depth. The 19 presumably discounts the test-runner frames;
    # adjust if the indentation looks off. `stack(0)` skips reading source
    # context for every frame, which `stack()` would otherwise do.
    stack = (len(inspect.stack(0)) - 19) * 2
    funcname = black_source_lines[func_sig_lineno].strip()
    # Skip decorator lines to reach the `def` line, staying within the file.
    while funcname.startswith("@") and func_sig_lineno + 1 < len(black_source_lines):
        func_sig_lineno += 1
        funcname = black_source_lines[func_sig_lineno].strip()
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc