git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Use is_number_token instead of assertion (#3069)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import inspect
5 import io
6 import logging
7 import multiprocessing
8 import os
9 import sys
10 import types
11 import unittest
12 from concurrent.futures import ThreadPoolExecutor
13 from contextlib import contextmanager, redirect_stderr
14 from dataclasses import replace
15 from io import BytesIO
16 from pathlib import Path
17 from platform import system
18 from tempfile import TemporaryDirectory
19 from typing import (
20     Any,
21     Callable,
22     Dict,
23     Iterator,
24     List,
25     Optional,
26     Sequence,
27     TypeVar,
28     Union,
29 )
30 from unittest.mock import MagicMock, patch
31
32 import click
33 import pytest
34 import re
35 from click import unstyle
36 from click.testing import CliRunner
37 from pathspec import PathSpec
38
39 import black
40 import black.files
41 from black import Feature, TargetVersion
42 from black import re_compile_maybe_verbose as compile_pattern
43 from black.cache import get_cache_dir, get_cache_file
44 from black.debug import DebugVisitor
45 from black.output import color_diff, diff
46 from black.report import Report
47
48 # Import other test classes
49 from tests.util import (
50     DATA_DIR,
51     DEFAULT_MODE,
52     DETERMINISTIC_HEADER,
53     PROJECT_ROOT,
54     PY36_VERSIONS,
55     THIS_DIR,
56     BlackBaseTestCase,
57     assert_format,
58     change_directory,
59     dump_to_stderr,
60     ff,
61     fs,
62     read_data,
63     get_case_path,
64     read_data_from_file,
65 )
66
THIS_FILE = Path(__file__)
# Config file passed via --config so CLI tests do not pick up a real project
# configuration (presumably an empty pyproject.toml — see the file name).
EMPTY_CONFIG = THIS_DIR / "data" / "empty_pyproject.toml"
# One "--target-version=<name>" CLI flag per version in PY36_VERSIONS.
PY36_ARGS = [f"--target-version={version.name.lower()}" for version in PY36_VERSIONS]
# Black's default exclude/include patterns, compiled the same way Black does.
DEFAULT_EXCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_EXCLUDES)
DEFAULT_INCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_INCLUDES)
# Generic type variables for typed helpers in this module.
T = TypeVar("T")
R = TypeVar("R")

# Match the time output in a diff, but nothing else
DIFF_TIME = re.compile(r"\t[\d\-:+\. ]+")
77
78
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point ``black.cache.CACHE_DIR`` at a throwaway directory for the block.

    Args:
        exists: when True, the yielded directory already exists (it is the
            temporary directory itself); when False, a not-yet-created
            ``new`` subdirectory of it is used instead.

    Yields:
        The path that ``black.cache.CACHE_DIR`` is patched to.
    """
    with TemporaryDirectory() as workspace_name:
        workspace = Path(workspace_name)
        target = workspace if exists else workspace / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
87
88
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a fresh asyncio event loop as the current loop for the block.

    The loop is created from the active event loop policy and is always
    closed on exit, even if the body raises. On exit the current loop is
    also reset to ``None``, so later code cannot pick up the closed loop
    through ``asyncio.get_event_loop()``.
    """
    policy = asyncio.get_event_loop_policy()
    loop = policy.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        yield
    finally:
        # Detach first: a closed loop left registered as "current" would be
        # handed back by get_event_loop() and fail on first use.
        asyncio.set_event_loop(None)
        loop.close()
99
100
class FakeContext(click.Context):
    """Minimal stand-in for a click Context, for functions that require one.

    ``click.Context.__init__`` is deliberately not called; only the two
    attributes below are provided.
    """

    def __init__(self) -> None:
        # Placeholder root: most tests do not look at it.
        self.obj: Dict[str, Any] = {"root": PROJECT_ROOT}
        self.default_map: Dict[str, Any] = {}
108
109
class FakeParameter(click.Parameter):
    """Minimal stand-in for a click Parameter, for functions that require one."""

    def __init__(self) -> None:
        """Intentionally skip ``click.Parameter`` initialisation."""
115
116
class BlackRunner(CliRunner):
    """CliRunner that keeps Black's STDOUT and STDERR in separate streams."""

    def __init__(self) -> None:
        # mix_stderr=False lets tests inspect stdout and stderr independently.
        CliRunner.__init__(self, mix_stderr=False)
122
123
def invokeBlack(
    args: List[str], exit_code: int = 0, ignore_config: bool = True
) -> None:
    """Run Black's CLI in-process and assert it exits with ``exit_code``.

    Args:
        args: command-line arguments passed to ``black.main``.
        exit_code: the exit status the run is expected to produce.
        ignore_config: when True, prepend ``--verbose`` and point ``--config``
            at ``THIS_DIR / "empty.toml"`` so local configuration is ignored.

    Raises:
        AssertionError: if the exit status differs; the message includes the
            arguments, decoded stdout/stderr and any exception raised.
    """
    runner = BlackRunner()
    if ignore_config:
        config_path = str(THIS_DIR / "empty.toml")
        args = ["--verbose", "--config", config_path, *args]
    outcome = runner.invoke(black.main, args, catch_exceptions=False)
    stdout_raw = outcome.stdout_bytes
    stderr_raw = outcome.stderr_bytes
    assert stdout_raw is not None
    assert stderr_raw is not None
    report = "\n".join(
        [
            f"Failed with args: {args}",
            f"stdout: {stdout_raw.decode()!r}",
            f"stderr: {stderr_raw.decode()!r}",
            f"exception: {outcome.exception}",
        ]
    )
    assert outcome.exit_code == exit_code, report
140
141
142 class BlackTestCase(BlackBaseTestCase):
    # Expose the module-level helper on the test case so tests can call it as
    # self.invokeBlack(...) without it being bound as an instance method.
    invokeBlack = staticmethod(invokeBlack)
144
145     def test_empty_ff(self) -> None:
146         expected = ""
147         tmp_file = Path(black.dump_to_file())
148         try:
149             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
150             with open(tmp_file, encoding="utf8") as f:
151                 actual = f.read()
152         finally:
153             os.unlink(tmp_file)
154         self.assertFormatEqual(expected, actual)
155
156     def test_experimental_string_processing_warns(self) -> None:
157         self.assertWarns(
158             black.mode.Deprecated, black.Mode, experimental_string_processing=True
159         )
160
161     def test_piping(self) -> None:
162         source, expected = read_data_from_file(PROJECT_ROOT / "src/black/__init__.py")
163         result = BlackRunner().invoke(
164             black.main,
165             [
166                 "-",
167                 "--fast",
168                 f"--line-length={black.DEFAULT_LINE_LENGTH}",
169                 f"--config={EMPTY_CONFIG}",
170             ],
171             input=BytesIO(source.encode("utf8")),
172         )
173         self.assertEqual(result.exit_code, 0)
174         self.assertFormatEqual(expected, result.output)
175         if source != result.output:
176             black.assert_equivalent(source, result.output)
177             black.assert_stable(source, result.output, DEFAULT_MODE)
178
179     def test_piping_diff(self) -> None:
180         diff_header = re.compile(
181             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
182             r"\+\d\d\d\d"
183         )
184         source, _ = read_data("simple_cases", "expression.py")
185         expected, _ = read_data("simple_cases", "expression.diff")
186         args = [
187             "-",
188             "--fast",
189             f"--line-length={black.DEFAULT_LINE_LENGTH}",
190             "--diff",
191             f"--config={EMPTY_CONFIG}",
192         ]
193         result = BlackRunner().invoke(
194             black.main, args, input=BytesIO(source.encode("utf8"))
195         )
196         self.assertEqual(result.exit_code, 0)
197         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
198         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
199         self.assertEqual(expected, actual)
200
201     def test_piping_diff_with_color(self) -> None:
202         source, _ = read_data("simple_cases", "expression.py")
203         args = [
204             "-",
205             "--fast",
206             f"--line-length={black.DEFAULT_LINE_LENGTH}",
207             "--diff",
208             "--color",
209             f"--config={EMPTY_CONFIG}",
210         ]
211         result = BlackRunner().invoke(
212             black.main, args, input=BytesIO(source.encode("utf8"))
213         )
214         actual = result.output
215         # Again, the contents are checked in a different test, so only look for colors.
216         self.assertIn("\033[1m", actual)
217         self.assertIn("\033[36m", actual)
218         self.assertIn("\033[32m", actual)
219         self.assertIn("\033[31m", actual)
220         self.assertIn("\033[0m", actual)
221
222     @patch("black.dump_to_file", dump_to_stderr)
223     def _test_wip(self) -> None:
224         source, expected = read_data("miscellaneous", "wip")
225         sys.settrace(tracefunc)
226         mode = replace(
227             DEFAULT_MODE,
228             experimental_string_processing=False,
229             target_versions={black.TargetVersion.PY38},
230         )
231         actual = fs(source, mode=mode)
232         sys.settrace(None)
233         self.assertFormatEqual(expected, actual)
234         black.assert_equivalent(source, actual)
235         black.assert_stable(source, actual, black.FileMode())
236
237     def test_pep_572_version_detection(self) -> None:
238         source, _ = read_data("py_38", "pep_572")
239         root = black.lib2to3_parse(source)
240         features = black.get_features_used(root)
241         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
242         versions = black.detect_target_versions(root)
243         self.assertIn(black.TargetVersion.PY38, versions)
244
245     def test_expression_ff(self) -> None:
246         source, expected = read_data("simple_cases", "expression.py")
247         tmp_file = Path(black.dump_to_file(source))
248         try:
249             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
250             with open(tmp_file, encoding="utf8") as f:
251                 actual = f.read()
252         finally:
253             os.unlink(tmp_file)
254         self.assertFormatEqual(expected, actual)
255         with patch("black.dump_to_file", dump_to_stderr):
256             black.assert_equivalent(source, actual)
257             black.assert_stable(source, actual, DEFAULT_MODE)
258
259     def test_expression_diff(self) -> None:
260         source, _ = read_data("simple_cases", "expression.py")
261         expected, _ = read_data("simple_cases", "expression.diff")
262         tmp_file = Path(black.dump_to_file(source))
263         diff_header = re.compile(
264             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
265             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
266         )
267         try:
268             result = BlackRunner().invoke(
269                 black.main, ["--diff", str(tmp_file), f"--config={EMPTY_CONFIG}"]
270             )
271             self.assertEqual(result.exit_code, 0)
272         finally:
273             os.unlink(tmp_file)
274         actual = result.output
275         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
276         if expected != actual:
277             dump = black.dump_to_file(actual)
278             msg = (
279                 "Expected diff isn't equal to the actual. If you made changes to"
280                 " expression.py and this is an anticipated difference, overwrite"
281                 f" tests/data/expression.diff with {dump}"
282             )
283             self.assertEqual(expected, actual, msg)
284
285     def test_expression_diff_with_color(self) -> None:
286         source, _ = read_data("simple_cases", "expression.py")
287         expected, _ = read_data("simple_cases", "expression.diff")
288         tmp_file = Path(black.dump_to_file(source))
289         try:
290             result = BlackRunner().invoke(
291                 black.main,
292                 ["--diff", "--color", str(tmp_file), f"--config={EMPTY_CONFIG}"],
293             )
294         finally:
295             os.unlink(tmp_file)
296         actual = result.output
297         # We check the contents of the diff in `test_expression_diff`. All
298         # we need to check here is that color codes exist in the result.
299         self.assertIn("\033[1m", actual)
300         self.assertIn("\033[36m", actual)
301         self.assertIn("\033[32m", actual)
302         self.assertIn("\033[31m", actual)
303         self.assertIn("\033[0m", actual)
304
305     def test_detect_pos_only_arguments(self) -> None:
306         source, _ = read_data("py_38", "pep_570")
307         root = black.lib2to3_parse(source)
308         features = black.get_features_used(root)
309         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
310         versions = black.detect_target_versions(root)
311         self.assertIn(black.TargetVersion.PY38, versions)
312
313     @patch("black.dump_to_file", dump_to_stderr)
314     def test_string_quotes(self) -> None:
315         source, expected = read_data("miscellaneous", "string_quotes")
316         mode = black.Mode(preview=True)
317         assert_format(source, expected, mode)
318         mode = replace(mode, string_normalization=False)
319         not_normalized = fs(source, mode=mode)
320         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
321         black.assert_equivalent(source, not_normalized)
322         black.assert_stable(source, not_normalized, mode=mode)
323
324     def test_skip_magic_trailing_comma(self) -> None:
325         source, _ = read_data("simple_cases", "expression")
326         expected, _ = read_data(
327             "miscellaneous", "expression_skip_magic_trailing_comma.diff"
328         )
329         tmp_file = Path(black.dump_to_file(source))
330         diff_header = re.compile(
331             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
332             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
333         )
334         try:
335             result = BlackRunner().invoke(
336                 black.main, ["-C", "--diff", str(tmp_file), f"--config={EMPTY_CONFIG}"]
337             )
338             self.assertEqual(result.exit_code, 0)
339         finally:
340             os.unlink(tmp_file)
341         actual = result.output
342         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
343         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
344         if expected != actual:
345             dump = black.dump_to_file(actual)
346             msg = (
347                 "Expected diff isn't equal to the actual. If you made changes to"
348                 " expression.py and this is an anticipated difference, overwrite"
349                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
350             )
351             self.assertEqual(expected, actual, msg)
352
353     @patch("black.dump_to_file", dump_to_stderr)
354     def test_async_as_identifier(self) -> None:
355         source_path = get_case_path("miscellaneous", "async_as_identifier")
356         source, expected = read_data_from_file(source_path)
357         actual = fs(source)
358         self.assertFormatEqual(expected, actual)
359         major, minor = sys.version_info[:2]
360         if major < 3 or (major <= 3 and minor < 7):
361             black.assert_equivalent(source, actual)
362         black.assert_stable(source, actual, DEFAULT_MODE)
363         # ensure black can parse this when the target is 3.6
364         self.invokeBlack([str(source_path), "--target-version", "py36"])
365         # but not on 3.7, because async/await is no longer an identifier
366         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
367
368     @patch("black.dump_to_file", dump_to_stderr)
369     def test_python37(self) -> None:
370         source_path = get_case_path("py_37", "python37")
371         source, expected = read_data_from_file(source_path)
372         actual = fs(source)
373         self.assertFormatEqual(expected, actual)
374         major, minor = sys.version_info[:2]
375         if major > 3 or (major == 3 and minor >= 7):
376             black.assert_equivalent(source, actual)
377         black.assert_stable(source, actual, DEFAULT_MODE)
378         # ensure black can parse this when the target is 3.7
379         self.invokeBlack([str(source_path), "--target-version", "py37"])
380         # but not on 3.6, because we use async as a reserved keyword
381         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
382
383     def test_tab_comment_indentation(self) -> None:
384         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
385         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
386         self.assertFormatEqual(contents_spc, fs(contents_spc))
387         self.assertFormatEqual(contents_spc, fs(contents_tab))
388
389         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
390         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
391         self.assertFormatEqual(contents_spc, fs(contents_spc))
392         self.assertFormatEqual(contents_spc, fs(contents_tab))
393
394         # mixed tabs and spaces (valid Python 2 code)
395         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
396         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
397         self.assertFormatEqual(contents_spc, fs(contents_spc))
398         self.assertFormatEqual(contents_spc, fs(contents_tab))
399
400         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
401         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
402         self.assertFormatEqual(contents_spc, fs(contents_spc))
403         self.assertFormatEqual(contents_spc, fs(contents_tab))
404
    def test_report_verbose(self) -> None:
        """Check Report messages, summary and return code in verbose mode.

        In verbose mode every outcome (unchanged, reformatted, cached, ignored)
        writes a line through ``out``; failures are written through ``err``.
        """
        report = Report(verbose=True)
        # Messages captured, in order, from the patched output helpers.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        # Record Black's output instead of printing it.
        with patch("black.output._out", out), patch("black.output._err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, a reformatted file makes the return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Any failure goes to err and raises the return code to 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are reported but not counted in the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff modes switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
506
    def test_report_quiet(self) -> None:
        """Check Report messages, summary and return code in quiet mode.

        In quiet mode nothing is written through ``out``; only failures are
        written through ``err``. The summary and return code still track every
        outcome.
        """
        report = Report(quiet=True)
        # Messages captured, in order, from the patched output helpers.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        # Record Black's output instead of printing it.
        with patch("black.output._out", out), patch("black.output._err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, a reformatted file makes the return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures still reach err in quiet mode and set the return code to 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are silent and not counted in the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff modes switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
600
    def test_report_normal(self) -> None:
        """Check Report messages, summary and return code at default verbosity.

        Only reformatted files are announced through ``out``. Unchanged, cached
        and ignored files are silent, and failures are written through ``err``.
        """
        report = black.Report()
        # Messages captured, in order, from the patched output helpers.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        # Record Black's output instead of printing it.
        with patch("black.output._out", out), patch("black.output._err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file is counted as unchanged but produces no new message.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, a reformatted file makes the return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Any failure goes to err and raises the return code to 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are silent and not counted in the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff modes switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
697
698     def test_lib2to3_parse(self) -> None:
699         with self.assertRaises(black.InvalidInput):
700             black.lib2to3_parse("invalid syntax")
701
702         straddling = "x + y"
703         black.lib2to3_parse(straddling)
704         black.lib2to3_parse(straddling, {TargetVersion.PY36})
705
706         py2_only = "print x"
707         with self.assertRaises(black.InvalidInput):
708             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
709
710         py3_only = "exec(x, end=y)"
711         black.lib2to3_parse(py3_only)
712         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
713
714     def test_get_features_used_decorator(self) -> None:
715         # Test the feature detection of new decorator syntax
716         # since this makes some test cases of test_get_features_used()
717         # fails if it fails, this is tested first so that a useful case
718         # is identified
719         simples, relaxed = read_data("miscellaneous", "decorators")
720         # skip explanation comments at the top of the file
721         for simple_test in simples.split("##")[1:]:
722             node = black.lib2to3_parse(simple_test)
723             decorator = str(node.children[0].children[0]).strip()
724             self.assertNotIn(
725                 Feature.RELAXED_DECORATORS,
726                 black.get_features_used(node),
727                 msg=(
728                     f"decorator '{decorator}' follows python<=3.8 syntax"
729                     "but is detected as 3.9+"
730                     # f"The full node is\n{node!r}"
731                 ),
732             )
733         # skip the '# output' comment at the top of the output part
734         for relaxed_test in relaxed.split("##")[1:]:
735             node = black.lib2to3_parse(relaxed_test)
736             decorator = str(node.children[0].children[0]).strip()
737             self.assertIn(
738                 Feature.RELAXED_DECORATORS,
739                 black.get_features_used(node),
740                 msg=(
741                     f"decorator '{decorator}' uses python3.9+ syntax"
742                     "but is detected as python<=3.8"
743                     # f"The full node is\n{node!r}"
744                 ),
745             )
746
747     def test_get_features_used(self) -> None:
748         node = black.lib2to3_parse("def f(*, arg): ...\n")
749         self.assertEqual(black.get_features_used(node), set())
750         node = black.lib2to3_parse("def f(*, arg,): ...\n")
751         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
752         node = black.lib2to3_parse("f(*arg,)\n")
753         self.assertEqual(
754             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
755         )
756         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
757         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
758         node = black.lib2to3_parse("123_456\n")
759         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
760         node = black.lib2to3_parse("123456\n")
761         self.assertEqual(black.get_features_used(node), set())
762         source, expected = read_data("simple_cases", "function")
763         node = black.lib2to3_parse(source)
764         expected_features = {
765             Feature.TRAILING_COMMA_IN_CALL,
766             Feature.TRAILING_COMMA_IN_DEF,
767             Feature.F_STRINGS,
768         }
769         self.assertEqual(black.get_features_used(node), expected_features)
770         node = black.lib2to3_parse(expected)
771         self.assertEqual(black.get_features_used(node), expected_features)
772         source, expected = read_data("simple_cases", "expression")
773         node = black.lib2to3_parse(source)
774         self.assertEqual(black.get_features_used(node), set())
775         node = black.lib2to3_parse(expected)
776         self.assertEqual(black.get_features_used(node), set())
777         node = black.lib2to3_parse("lambda a, /, b: ...")
778         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
779         node = black.lib2to3_parse("def fn(a, /, b): ...")
780         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
781         node = black.lib2to3_parse("def fn(): yield a, b")
782         self.assertEqual(black.get_features_used(node), set())
783         node = black.lib2to3_parse("def fn(): return a, b")
784         self.assertEqual(black.get_features_used(node), set())
785         node = black.lib2to3_parse("def fn(): yield *b, c")
786         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
787         node = black.lib2to3_parse("def fn(): return a, *b, c")
788         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
789         node = black.lib2to3_parse("x = a, *b, c")
790         self.assertEqual(black.get_features_used(node), set())
791         node = black.lib2to3_parse("x: Any = regular")
792         self.assertEqual(black.get_features_used(node), set())
793         node = black.lib2to3_parse("x: Any = (regular, regular)")
794         self.assertEqual(black.get_features_used(node), set())
795         node = black.lib2to3_parse("x: Any = Complex(Type(1))[something]")
796         self.assertEqual(black.get_features_used(node), set())
797         node = black.lib2to3_parse("x: Tuple[int, ...] = a, b, c")
798         self.assertEqual(
799             black.get_features_used(node), {Feature.ANN_ASSIGN_EXTENDED_RHS}
800         )
801         node = black.lib2to3_parse("try: pass\nexcept Something: pass")
802         self.assertEqual(black.get_features_used(node), set())
803         node = black.lib2to3_parse("try: pass\nexcept (*Something,): pass")
804         self.assertEqual(black.get_features_used(node), set())
805         node = black.lib2to3_parse("try: pass\nexcept *Group: pass")
806         self.assertEqual(black.get_features_used(node), {Feature.EXCEPT_STAR})
807         node = black.lib2to3_parse("a[*b]")
808         self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS})
809         node = black.lib2to3_parse("a[x, *y(), z] = t")
810         self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS})
811         node = black.lib2to3_parse("def fn(*args: *T): pass")
812         self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS})
813
814     def test_get_features_used_for_future_flags(self) -> None:
815         for src, features in [
816             ("from __future__ import annotations", {Feature.FUTURE_ANNOTATIONS}),
817             (
818                 "from __future__ import (other, annotations)",
819                 {Feature.FUTURE_ANNOTATIONS},
820             ),
821             ("a = 1 + 2\nfrom something import annotations", set()),
822             ("from __future__ import x, y", set()),
823         ]:
824             with self.subTest(src=src, features=features):
825                 node = black.lib2to3_parse(src)
826                 future_imports = black.get_future_imports(node)
827                 self.assertEqual(
828                     black.get_features_used(node, future_imports=future_imports),
829                     features,
830                 )
831
832     def test_get_future_imports(self) -> None:
833         node = black.lib2to3_parse("\n")
834         self.assertEqual(set(), black.get_future_imports(node))
835         node = black.lib2to3_parse("from __future__ import black\n")
836         self.assertEqual({"black"}, black.get_future_imports(node))
837         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
838         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
839         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
840         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
841         node = black.lib2to3_parse(
842             "from __future__ import multiple\nfrom __future__ import imports\n"
843         )
844         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
845         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
846         self.assertEqual({"black"}, black.get_future_imports(node))
847         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
848         self.assertEqual({"black"}, black.get_future_imports(node))
849         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
850         self.assertEqual(set(), black.get_future_imports(node))
851         node = black.lib2to3_parse("from some.module import black\n")
852         self.assertEqual(set(), black.get_future_imports(node))
853         node = black.lib2to3_parse(
854             "from __future__ import unicode_literals as _unicode_literals"
855         )
856         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
857         node = black.lib2to3_parse(
858             "from __future__ import unicode_literals as _lol, print"
859         )
860         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
861
862     @pytest.mark.incompatible_with_mypyc
863     def test_debug_visitor(self) -> None:
864         source, _ = read_data("miscellaneous", "debug_visitor")
865         expected, _ = read_data("miscellaneous", "debug_visitor.out")
866         out_lines = []
867         err_lines = []
868
869         def out(msg: str, **kwargs: Any) -> None:
870             out_lines.append(msg)
871
872         def err(msg: str, **kwargs: Any) -> None:
873             err_lines.append(msg)
874
875         with patch("black.debug.out", out):
876             DebugVisitor.show(source)
877         actual = "\n".join(out_lines) + "\n"
878         log_name = ""
879         if expected != actual:
880             log_name = black.dump_to_file(*out_lines)
881         self.assertEqual(
882             expected,
883             actual,
884             f"AST print out is different. Actual version dumped to {log_name}",
885         )
886
887     def test_format_file_contents(self) -> None:
888         empty = ""
889         mode = DEFAULT_MODE
890         with self.assertRaises(black.NothingChanged):
891             black.format_file_contents(empty, mode=mode, fast=False)
892         just_nl = "\n"
893         with self.assertRaises(black.NothingChanged):
894             black.format_file_contents(just_nl, mode=mode, fast=False)
895         same = "j = [1, 2, 3]\n"
896         with self.assertRaises(black.NothingChanged):
897             black.format_file_contents(same, mode=mode, fast=False)
898         different = "j = [1,2,3]"
899         expected = same
900         actual = black.format_file_contents(different, mode=mode, fast=False)
901         self.assertEqual(expected, actual)
902         invalid = "return if you can"
903         with self.assertRaises(black.InvalidInput) as e:
904             black.format_file_contents(invalid, mode=mode, fast=False)
905         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
906
907     def test_endmarker(self) -> None:
908         n = black.lib2to3_parse("\n")
909         self.assertEqual(n.type, black.syms.file_input)
910         self.assertEqual(len(n.children), 1)
911         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
912
913     @pytest.mark.incompatible_with_mypyc
914     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
915     def test_assertFormatEqual(self) -> None:
916         out_lines = []
917         err_lines = []
918
919         def out(msg: str, **kwargs: Any) -> None:
920             out_lines.append(msg)
921
922         def err(msg: str, **kwargs: Any) -> None:
923             err_lines.append(msg)
924
925         with patch("black.output._out", out), patch("black.output._err", err):
926             with self.assertRaises(AssertionError):
927                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
928
929         out_str = "".join(out_lines)
930         self.assertIn("Expected tree:", out_str)
931         self.assertIn("Actual tree:", out_str)
932         self.assertEqual("".join(err_lines), "")
933
934     @event_loop()
935     @patch("concurrent.futures.ProcessPoolExecutor", MagicMock(side_effect=OSError))
936     def test_works_in_mono_process_only_environment(self) -> None:
937         with cache_dir() as workspace:
938             for f in [
939                 (workspace / "one.py").resolve(),
940                 (workspace / "two.py").resolve(),
941             ]:
942                 f.write_text('print("hello")\n')
943             self.invokeBlack([str(workspace)])
944
945     @event_loop()
946     def test_check_diff_use_together(self) -> None:
947         with cache_dir():
948             # Files which will be reformatted.
949             src1 = get_case_path("miscellaneous", "string_quotes")
950             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
951             # Files which will not be reformatted.
952             src2 = get_case_path("simple_cases", "composition")
953             self.invokeBlack([str(src2), "--diff", "--check"])
954             # Multi file command.
955             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
956
957     def test_no_src_fails(self) -> None:
958         with cache_dir():
959             self.invokeBlack([], exit_code=1)
960
961     def test_src_and_code_fails(self) -> None:
962         with cache_dir():
963             self.invokeBlack([".", "-c", "0"], exit_code=1)
964
965     def test_broken_symlink(self) -> None:
966         with cache_dir() as workspace:
967             symlink = workspace / "broken_link.py"
968             try:
969                 symlink.symlink_to("nonexistent.py")
970             except (OSError, NotImplementedError) as e:
971                 self.skipTest(f"Can't create symlinks: {e}")
972             self.invokeBlack([str(workspace.resolve())])
973
974     def test_single_file_force_pyi(self) -> None:
975         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
976         contents, expected = read_data("miscellaneous", "force_pyi")
977         with cache_dir() as workspace:
978             path = (workspace / "file.py").resolve()
979             with open(path, "w") as fh:
980                 fh.write(contents)
981             self.invokeBlack([str(path), "--pyi"])
982             with open(path, "r") as fh:
983                 actual = fh.read()
984             # verify cache with --pyi is separate
985             pyi_cache = black.read_cache(pyi_mode)
986             self.assertIn(str(path), pyi_cache)
987             normal_cache = black.read_cache(DEFAULT_MODE)
988             self.assertNotIn(str(path), normal_cache)
989         self.assertFormatEqual(expected, actual)
990         black.assert_equivalent(contents, actual)
991         black.assert_stable(contents, actual, pyi_mode)
992
993     @event_loop()
994     def test_multi_file_force_pyi(self) -> None:
995         reg_mode = DEFAULT_MODE
996         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
997         contents, expected = read_data("miscellaneous", "force_pyi")
998         with cache_dir() as workspace:
999             paths = [
1000                 (workspace / "file1.py").resolve(),
1001                 (workspace / "file2.py").resolve(),
1002             ]
1003             for path in paths:
1004                 with open(path, "w") as fh:
1005                     fh.write(contents)
1006             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1007             for path in paths:
1008                 with open(path, "r") as fh:
1009                     actual = fh.read()
1010                 self.assertEqual(actual, expected)
1011             # verify cache with --pyi is separate
1012             pyi_cache = black.read_cache(pyi_mode)
1013             normal_cache = black.read_cache(reg_mode)
1014             for path in paths:
1015                 self.assertIn(str(path), pyi_cache)
1016                 self.assertNotIn(str(path), normal_cache)
1017
1018     def test_pipe_force_pyi(self) -> None:
1019         source, expected = read_data("miscellaneous", "force_pyi")
1020         result = CliRunner().invoke(
1021             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1022         )
1023         self.assertEqual(result.exit_code, 0)
1024         actual = result.output
1025         self.assertFormatEqual(actual, expected)
1026
1027     def test_single_file_force_py36(self) -> None:
1028         reg_mode = DEFAULT_MODE
1029         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1030         source, expected = read_data("miscellaneous", "force_py36")
1031         with cache_dir() as workspace:
1032             path = (workspace / "file.py").resolve()
1033             with open(path, "w") as fh:
1034                 fh.write(source)
1035             self.invokeBlack([str(path), *PY36_ARGS])
1036             with open(path, "r") as fh:
1037                 actual = fh.read()
1038             # verify cache with --target-version is separate
1039             py36_cache = black.read_cache(py36_mode)
1040             self.assertIn(str(path), py36_cache)
1041             normal_cache = black.read_cache(reg_mode)
1042             self.assertNotIn(str(path), normal_cache)
1043         self.assertEqual(actual, expected)
1044
1045     @event_loop()
1046     def test_multi_file_force_py36(self) -> None:
1047         reg_mode = DEFAULT_MODE
1048         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1049         source, expected = read_data("miscellaneous", "force_py36")
1050         with cache_dir() as workspace:
1051             paths = [
1052                 (workspace / "file1.py").resolve(),
1053                 (workspace / "file2.py").resolve(),
1054             ]
1055             for path in paths:
1056                 with open(path, "w") as fh:
1057                     fh.write(source)
1058             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1059             for path in paths:
1060                 with open(path, "r") as fh:
1061                     actual = fh.read()
1062                 self.assertEqual(actual, expected)
1063             # verify cache with --target-version is separate
1064             pyi_cache = black.read_cache(py36_mode)
1065             normal_cache = black.read_cache(reg_mode)
1066             for path in paths:
1067                 self.assertIn(str(path), pyi_cache)
1068                 self.assertNotIn(str(path), normal_cache)
1069
1070     def test_pipe_force_py36(self) -> None:
1071         source, expected = read_data("miscellaneous", "force_py36")
1072         result = CliRunner().invoke(
1073             black.main,
1074             ["-", "-q", "--target-version=py36"],
1075             input=BytesIO(source.encode("utf8")),
1076         )
1077         self.assertEqual(result.exit_code, 0)
1078         actual = result.output
1079         self.assertFormatEqual(actual, expected)
1080
1081     @pytest.mark.incompatible_with_mypyc
1082     def test_reformat_one_with_stdin(self) -> None:
1083         with patch(
1084             "black.format_stdin_to_stdout",
1085             return_value=lambda *args, **kwargs: black.Changed.YES,
1086         ) as fsts:
1087             report = MagicMock()
1088             path = Path("-")
1089             black.reformat_one(
1090                 path,
1091                 fast=True,
1092                 write_back=black.WriteBack.YES,
1093                 mode=DEFAULT_MODE,
1094                 report=report,
1095             )
1096             fsts.assert_called_once()
1097             report.done.assert_called_with(path, black.Changed.YES)
1098
1099     @pytest.mark.incompatible_with_mypyc
1100     def test_reformat_one_with_stdin_filename(self) -> None:
1101         with patch(
1102             "black.format_stdin_to_stdout",
1103             return_value=lambda *args, **kwargs: black.Changed.YES,
1104         ) as fsts:
1105             report = MagicMock()
1106             p = "foo.py"
1107             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1108             expected = Path(p)
1109             black.reformat_one(
1110                 path,
1111                 fast=True,
1112                 write_back=black.WriteBack.YES,
1113                 mode=DEFAULT_MODE,
1114                 report=report,
1115             )
1116             fsts.assert_called_once_with(
1117                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1118             )
1119             # __BLACK_STDIN_FILENAME__ should have been stripped
1120             report.done.assert_called_with(expected, black.Changed.YES)
1121
1122     @pytest.mark.incompatible_with_mypyc
1123     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1124         with patch(
1125             "black.format_stdin_to_stdout",
1126             return_value=lambda *args, **kwargs: black.Changed.YES,
1127         ) as fsts:
1128             report = MagicMock()
1129             p = "foo.pyi"
1130             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1131             expected = Path(p)
1132             black.reformat_one(
1133                 path,
1134                 fast=True,
1135                 write_back=black.WriteBack.YES,
1136                 mode=DEFAULT_MODE,
1137                 report=report,
1138             )
1139             fsts.assert_called_once_with(
1140                 fast=True,
1141                 write_back=black.WriteBack.YES,
1142                 mode=replace(DEFAULT_MODE, is_pyi=True),
1143             )
1144             # __BLACK_STDIN_FILENAME__ should have been stripped
1145             report.done.assert_called_with(expected, black.Changed.YES)
1146
1147     @pytest.mark.incompatible_with_mypyc
1148     def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
1149         with patch(
1150             "black.format_stdin_to_stdout",
1151             return_value=lambda *args, **kwargs: black.Changed.YES,
1152         ) as fsts:
1153             report = MagicMock()
1154             p = "foo.ipynb"
1155             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1156             expected = Path(p)
1157             black.reformat_one(
1158                 path,
1159                 fast=True,
1160                 write_back=black.WriteBack.YES,
1161                 mode=DEFAULT_MODE,
1162                 report=report,
1163             )
1164             fsts.assert_called_once_with(
1165                 fast=True,
1166                 write_back=black.WriteBack.YES,
1167                 mode=replace(DEFAULT_MODE, is_ipynb=True),
1168             )
1169             # __BLACK_STDIN_FILENAME__ should have been stripped
1170             report.done.assert_called_with(expected, black.Changed.YES)
1171
1172     @pytest.mark.incompatible_with_mypyc
1173     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1174         with patch(
1175             "black.format_stdin_to_stdout",
1176             return_value=lambda *args, **kwargs: black.Changed.YES,
1177         ) as fsts:
1178             report = MagicMock()
1179             # Even with an existing file, since we are forcing stdin, black
1180             # should output to stdout and not modify the file inplace
1181             p = THIS_DIR / "data" / "simple_cases" / "collections.py"
1182             # Make sure is_file actually returns True
1183             self.assertTrue(p.is_file())
1184             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1185             expected = Path(p)
1186             black.reformat_one(
1187                 path,
1188                 fast=True,
1189                 write_back=black.WriteBack.YES,
1190                 mode=DEFAULT_MODE,
1191                 report=report,
1192             )
1193             fsts.assert_called_once()
1194             # __BLACK_STDIN_FILENAME__ should have been stripped
1195             report.done.assert_called_with(expected, black.Changed.YES)
1196
1197     def test_reformat_one_with_stdin_empty(self) -> None:
1198         output = io.StringIO()
1199         with patch("io.TextIOWrapper", lambda *args, **kwargs: output):
1200             try:
1201                 black.format_stdin_to_stdout(
1202                     fast=True,
1203                     content="",
1204                     write_back=black.WriteBack.YES,
1205                     mode=DEFAULT_MODE,
1206                 )
1207             except io.UnsupportedOperation:
1208                 pass  # StringIO does not support detach
1209             assert output.getvalue() == ""
1210
1211     def test_invalid_cli_regex(self) -> None:
1212         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1213             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1214
1215     def test_required_version_matches_version(self) -> None:
1216         self.invokeBlack(
1217             ["--required-version", black.__version__, "-c", "0"],
1218             exit_code=0,
1219             ignore_config=True,
1220         )
1221
1222     def test_required_version_matches_partial_version(self) -> None:
1223         self.invokeBlack(
1224             ["--required-version", black.__version__.split(".")[0], "-c", "0"],
1225             exit_code=0,
1226             ignore_config=True,
1227         )
1228
1229     def test_required_version_does_not_match_on_minor_version(self) -> None:
1230         self.invokeBlack(
1231             ["--required-version", black.__version__.split(".")[0] + ".999", "-c", "0"],
1232             exit_code=1,
1233             ignore_config=True,
1234         )
1235
1236     def test_required_version_does_not_match_version(self) -> None:
1237         result = BlackRunner().invoke(
1238             black.main,
1239             ["--required-version", "20.99b", "-c", "0"],
1240         )
1241         self.assertEqual(result.exit_code, 1)
1242         self.assertIn("required version", result.stderr)
1243
1244     def test_preserves_line_endings(self) -> None:
1245         with TemporaryDirectory() as workspace:
1246             test_file = Path(workspace) / "test.py"
1247             for nl in ["\n", "\r\n"]:
1248                 contents = nl.join(["def f(  ):", "    pass"])
1249                 test_file.write_bytes(contents.encode())
1250                 ff(test_file, write_back=black.WriteBack.YES)
1251                 updated_contents: bytes = test_file.read_bytes()
1252                 self.assertIn(nl.encode(), updated_contents)
1253                 if nl == "\n":
1254                     self.assertNotIn(b"\r\n", updated_contents)
1255
1256     def test_preserves_line_endings_via_stdin(self) -> None:
1257         for nl in ["\n", "\r\n"]:
1258             contents = nl.join(["def f(  ):", "    pass"])
1259             runner = BlackRunner()
1260             result = runner.invoke(
1261                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1262             )
1263             self.assertEqual(result.exit_code, 0)
1264             output = result.stdout_bytes
1265             self.assertIn(nl.encode("utf8"), output)
1266             if nl == "\n":
1267                 self.assertNotIn(b"\r\n", output)
1268
1269     def test_assert_equivalent_different_asts(self) -> None:
1270         with self.assertRaises(AssertionError):
1271             black.assert_equivalent("{}", "None")
1272
1273     def test_shhh_click(self) -> None:
1274         try:
1275             from click import _unicodefun  # type: ignore
1276         except ImportError:
1277             self.skipTest("Incompatible Click version")
1278
1279         if not hasattr(_unicodefun, "_verify_python_env"):
1280             self.skipTest("Incompatible Click version")
1281
1282         # First, let's see if Click is crashing with a preferred ASCII charset.
1283         with patch("locale.getpreferredencoding") as gpe:
1284             gpe.return_value = "ASCII"
1285             with self.assertRaises(RuntimeError):
1286                 _unicodefun._verify_python_env()
1287         # Now, let's silence Click...
1288         black.patch_click()
1289         # ...and confirm it's silent.
1290         with patch("locale.getpreferredencoding") as gpe:
1291             gpe.return_value = "ASCII"
1292             try:
1293                 _unicodefun._verify_python_env()
1294             except RuntimeError as re:
1295                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1296
1297     def test_root_logger_not_used_directly(self) -> None:
1298         def fail(*args: Any, **kwargs: Any) -> None:
1299             self.fail("Record created with root logger")
1300
1301         with patch.multiple(
1302             logging.root,
1303             debug=fail,
1304             info=fail,
1305             warning=fail,
1306             error=fail,
1307             critical=fail,
1308             log=fail,
1309         ):
1310             ff(THIS_DIR / "util.py")
1311
1312     def test_invalid_config_return_code(self) -> None:
1313         tmp_file = Path(black.dump_to_file())
1314         try:
1315             tmp_config = Path(black.dump_to_file())
1316             tmp_config.unlink()
1317             args = ["--config", str(tmp_config), str(tmp_file)]
1318             self.invokeBlack(args, exit_code=2, ignore_config=False)
1319         finally:
1320             tmp_file.unlink()
1321
1322     def test_parse_pyproject_toml(self) -> None:
1323         test_toml_file = THIS_DIR / "test.toml"
1324         config = black.parse_pyproject_toml(str(test_toml_file))
1325         self.assertEqual(config["verbose"], 1)
1326         self.assertEqual(config["check"], "no")
1327         self.assertEqual(config["diff"], "y")
1328         self.assertEqual(config["color"], True)
1329         self.assertEqual(config["line_length"], 79)
1330         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1331         self.assertEqual(config["python_cell_magics"], ["custom1", "custom2"])
1332         self.assertEqual(config["exclude"], r"\.pyi?$")
1333         self.assertEqual(config["include"], r"\.py?$")
1334
1335     def test_read_pyproject_toml(self) -> None:
1336         test_toml_file = THIS_DIR / "test.toml"
1337         fake_ctx = FakeContext()
1338         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1339         config = fake_ctx.default_map
1340         self.assertEqual(config["verbose"], "1")
1341         self.assertEqual(config["check"], "no")
1342         self.assertEqual(config["diff"], "y")
1343         self.assertEqual(config["color"], "True")
1344         self.assertEqual(config["line_length"], "79")
1345         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1346         self.assertEqual(config["exclude"], r"\.pyi?$")
1347         self.assertEqual(config["include"], r"\.py?$")
1348
1349     @pytest.mark.incompatible_with_mypyc
1350     def test_find_project_root(self) -> None:
1351         with TemporaryDirectory() as workspace:
1352             root = Path(workspace)
1353             test_dir = root / "test"
1354             test_dir.mkdir()
1355
1356             src_dir = root / "src"
1357             src_dir.mkdir()
1358
1359             root_pyproject = root / "pyproject.toml"
1360             root_pyproject.touch()
1361             src_pyproject = src_dir / "pyproject.toml"
1362             src_pyproject.touch()
1363             src_python = src_dir / "foo.py"
1364             src_python.touch()
1365
1366             self.assertEqual(
1367                 black.find_project_root((src_dir, test_dir)),
1368                 (root.resolve(), "pyproject.toml"),
1369             )
1370             self.assertEqual(
1371                 black.find_project_root((src_dir,)),
1372                 (src_dir.resolve(), "pyproject.toml"),
1373             )
1374             self.assertEqual(
1375                 black.find_project_root((src_python,)),
1376                 (src_dir.resolve(), "pyproject.toml"),
1377             )
1378
1379     @patch(
1380         "black.files.find_user_pyproject_toml",
1381     )
1382     def test_find_pyproject_toml(self, find_user_pyproject_toml: MagicMock) -> None:
1383         find_user_pyproject_toml.side_effect = RuntimeError()
1384
1385         with redirect_stderr(io.StringIO()) as stderr:
1386             result = black.files.find_pyproject_toml(
1387                 path_search_start=(str(Path.cwd().root),)
1388             )
1389
1390         assert result is None
1391         err = stderr.getvalue()
1392         assert "Ignoring user configuration" in err
1393
1394     @patch(
1395         "black.files.find_user_pyproject_toml",
1396         black.files.find_user_pyproject_toml.__wrapped__,
1397     )
1398     def test_find_user_pyproject_toml_linux(self) -> None:
1399         if system() == "Windows":
1400             return
1401
1402         # Test if XDG_CONFIG_HOME is checked
1403         with TemporaryDirectory() as workspace:
1404             tmp_user_config = Path(workspace) / "black"
1405             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1406                 self.assertEqual(
1407                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
1408                 )
1409
1410         # Test fallback for XDG_CONFIG_HOME
1411         with patch.dict("os.environ"):
1412             os.environ.pop("XDG_CONFIG_HOME", None)
1413             fallback_user_config = Path("~/.config").expanduser() / "black"
1414             self.assertEqual(
1415                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
1416             )
1417
1418     def test_find_user_pyproject_toml_windows(self) -> None:
1419         if system() != "Windows":
1420             return
1421
1422         user_config_path = Path.home() / ".black"
1423         self.assertEqual(
1424             black.files.find_user_pyproject_toml(), user_config_path.resolve()
1425         )
1426
1427     def test_bpo_33660_workaround(self) -> None:
1428         if system() == "Windows":
1429             return
1430
1431         # https://bugs.python.org/issue33660
1432         root = Path("/")
1433         with change_directory(root):
1434             path = Path("workspace") / "project"
1435             report = black.Report(verbose=True)
1436             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
1437             self.assertEqual(normalized_path, "workspace/project")
1438
1439     def test_normalize_path_ignore_windows_junctions_outside_of_root(self) -> None:
1440         if system() != "Windows":
1441             return
1442
1443         with TemporaryDirectory() as workspace:
1444             root = Path(workspace)
1445             junction_dir = root / "junction"
1446             junction_target_outside_of_root = root / ".."
1447             os.system(f"mklink /J {junction_dir} {junction_target_outside_of_root}")
1448
1449             report = black.Report(verbose=True)
1450             normalized_path = black.normalize_path_maybe_ignore(
1451                 junction_dir, root, report
1452             )
1453             # Manually delete for Python < 3.8
1454             os.system(f"rmdir {junction_dir}")
1455
1456             self.assertEqual(normalized_path, None)
1457
1458     def test_newline_comment_interaction(self) -> None:
1459         source = "class A:\\\r\n# type: ignore\n pass\n"
1460         output = black.format_str(source, mode=DEFAULT_MODE)
1461         black.assert_stable(source, output, mode=DEFAULT_MODE)
1462
1463     def test_bpo_2142_workaround(self) -> None:
1464         # https://bugs.python.org/issue2142
1465
1466         source, _ = read_data("miscellaneous", "missing_final_newline")
1467         # read_data adds a trailing newline
1468         source = source.rstrip()
1469         expected, _ = read_data("miscellaneous", "missing_final_newline.diff")
1470         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
1471         diff_header = re.compile(
1472             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
1473             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
1474         )
1475         try:
1476             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
1477             self.assertEqual(result.exit_code, 0)
1478         finally:
1479             os.unlink(tmp_file)
1480         actual = result.output
1481         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
1482         self.assertEqual(actual, expected)
1483
1484     @staticmethod
1485     def compare_results(
1486         result: click.testing.Result, expected_value: str, expected_exit_code: int
1487     ) -> None:
1488         """Helper method to test the value and exit code of a click Result."""
1489         assert (
1490             result.output == expected_value
1491         ), "The output did not match the expected value."
1492         assert result.exit_code == expected_exit_code, "The exit code is incorrect."
1493
1494     def test_code_option(self) -> None:
1495         """Test the code option with no changes."""
1496         code = 'print("Hello world")\n'
1497         args = ["--code", code]
1498         result = CliRunner().invoke(black.main, args)
1499
1500         self.compare_results(result, code, 0)
1501
1502     def test_code_option_changed(self) -> None:
1503         """Test the code option when changes are required."""
1504         code = "print('hello world')"
1505         formatted = black.format_str(code, mode=DEFAULT_MODE)
1506
1507         args = ["--code", code]
1508         result = CliRunner().invoke(black.main, args)
1509
1510         self.compare_results(result, formatted, 0)
1511
1512     def test_code_option_check(self) -> None:
1513         """Test the code option when check is passed."""
1514         args = ["--check", "--code", 'print("Hello world")\n']
1515         result = CliRunner().invoke(black.main, args)
1516         self.compare_results(result, "", 0)
1517
1518     def test_code_option_check_changed(self) -> None:
1519         """Test the code option when changes are required, and check is passed."""
1520         args = ["--check", "--code", "print('hello world')"]
1521         result = CliRunner().invoke(black.main, args)
1522         self.compare_results(result, "", 1)
1523
1524     def test_code_option_diff(self) -> None:
1525         """Test the code option when diff is passed."""
1526         code = "print('hello world')"
1527         formatted = black.format_str(code, mode=DEFAULT_MODE)
1528         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1529
1530         args = ["--diff", "--code", code]
1531         result = CliRunner().invoke(black.main, args)
1532
1533         # Remove time from diff
1534         output = DIFF_TIME.sub("", result.output)
1535
1536         assert output == result_diff, "The output did not match the expected value."
1537         assert result.exit_code == 0, "The exit code is incorrect."
1538
1539     def test_code_option_color_diff(self) -> None:
1540         """Test the code option when color and diff are passed."""
1541         code = "print('hello world')"
1542         formatted = black.format_str(code, mode=DEFAULT_MODE)
1543
1544         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1545         result_diff = color_diff(result_diff)
1546
1547         args = ["--diff", "--color", "--code", code]
1548         result = CliRunner().invoke(black.main, args)
1549
1550         # Remove time from diff
1551         output = DIFF_TIME.sub("", result.output)
1552
1553         assert output == result_diff, "The output did not match the expected value."
1554         assert result.exit_code == 0, "The exit code is incorrect."
1555
1556     @pytest.mark.incompatible_with_mypyc
1557     def test_code_option_safe(self) -> None:
1558         """Test that the code option throws an error when the sanity checks fail."""
1559         # Patch black.assert_equivalent to ensure the sanity checks fail
1560         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1561             code = 'print("Hello world")'
1562             error_msg = f"{code}\nerror: cannot format <string>: \n"
1563
1564             args = ["--safe", "--code", code]
1565             result = CliRunner().invoke(black.main, args)
1566
1567             self.compare_results(result, error_msg, 123)
1568
1569     def test_code_option_fast(self) -> None:
1570         """Test that the code option ignores errors when the sanity checks fail."""
1571         # Patch black.assert_equivalent to ensure the sanity checks fail
1572         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1573             code = 'print("Hello world")'
1574             formatted = black.format_str(code, mode=DEFAULT_MODE)
1575
1576             args = ["--fast", "--code", code]
1577             result = CliRunner().invoke(black.main, args)
1578
1579             self.compare_results(result, formatted, 0)
1580
1581     @pytest.mark.incompatible_with_mypyc
1582     def test_code_option_config(self) -> None:
1583         """
1584         Test that the code option finds the pyproject.toml in the current directory.
1585         """
1586         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1587             args = ["--code", "print"]
1588             # This is the only directory known to contain a pyproject.toml
1589             with change_directory(PROJECT_ROOT):
1590                 CliRunner().invoke(black.main, args)
1591                 pyproject_path = Path(Path.cwd(), "pyproject.toml").resolve()
1592
1593             assert (
1594                 len(parse.mock_calls) >= 1
1595             ), "Expected config parse to be called with the current directory."
1596
1597             _, call_args, _ = parse.mock_calls[0]
1598             assert (
1599                 call_args[0].lower() == str(pyproject_path).lower()
1600             ), "Incorrect config loaded."
1601
1602     @pytest.mark.incompatible_with_mypyc
1603     def test_code_option_parent_config(self) -> None:
1604         """
1605         Test that the code option finds the pyproject.toml in the parent directory.
1606         """
1607         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1608             with change_directory(THIS_DIR):
1609                 args = ["--code", "print"]
1610                 CliRunner().invoke(black.main, args)
1611
1612                 pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
1613                 assert (
1614                     len(parse.mock_calls) >= 1
1615                 ), "Expected config parse to be called with the current directory."
1616
1617                 _, call_args, _ = parse.mock_calls[0]
1618                 assert (
1619                     call_args[0].lower() == str(pyproject_path).lower()
1620                 ), "Incorrect config loaded."
1621
1622     def test_for_handled_unexpected_eof_error(self) -> None:
1623         """
1624         Test that an unexpected EOF SyntaxError is nicely presented.
1625         """
1626         with pytest.raises(black.parsing.InvalidInput) as exc_info:
1627             black.lib2to3_parse("print(", {})
1628
1629         exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")
1630
1631     def test_equivalency_ast_parse_failure_includes_error(self) -> None:
1632         with pytest.raises(AssertionError) as err:
1633             black.assert_equivalent("a«»a  = 1", "a«»a  = 1")
1634
1635         err.match("--safe")
1636         # Unfortunately the SyntaxError message has changed in newer versions so we
1637         # can't match it directly.
1638         err.match("invalid character")
1639         err.match(r"\(<unknown>, line 1\)")
1640
1641
class TestCaching:
    """Tests for black's per-mode formatting cache (read/write/filter)."""

    def test_get_cache_dir(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """get_cache_dir() prefers $BLACK_CACHE_DIR over the user cache dir."""
        # Two distinct candidate cache directories.
        default_dir = tmp_path / "ws1"
        override_dir = tmp_path / "ws2"
        for directory in (default_dir, override_dir):
            directory.mkdir()

        # Point user_cache_dir at a known location so the result can be asserted.
        fake_user_cache_dir = patch(
            target="black.cache.user_cache_dir",
            autospec=True,
            return_value=str(default_dir),
        )

        # Without BLACK_CACHE_DIR the user cache dir is used ...
        monkeypatch.delenv("BLACK_CACHE_DIR", raising=False)
        with fake_user_cache_dir:
            assert get_cache_dir() == default_dir

        # ... and with it set, the environment variable wins.
        monkeypatch.setenv("BLACK_CACHE_DIR", str(override_dir))
        assert get_cache_dir() == override_dir

    def test_cache_broken_file(self) -> None:
        """A corrupt cache file reads as empty and is rewritten on the next run."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            get_cache_file(mode).write_text("this is not a pickle")
            assert black.read_cache(mode) == {}

            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            invokeBlack([str(src)])
            assert str(src) in black.read_cache(mode)

    def test_cache_single_file_already_cached(self) -> None:
        """A file already recorded in the cache is left untouched."""
        mode = DEFAULT_MODE
        unformatted = "print('hello')"
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.write_text(unformatted)
            black.write_cache({}, [src], mode)
            invokeBlack([str(src)])
            assert src.read_text() == unformatted

    @event_loop()
    def test_cache_multiple_files(self) -> None:
        """Only uncached files are reformatted; afterwards both are cached."""
        mode = DEFAULT_MODE
        unformatted = "print('hello')"
        with cache_dir() as workspace, patch(
            "concurrent.futures.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            cached = (workspace / "one.py").resolve()
            cached.write_text(unformatted)
            uncached = (workspace / "two.py").resolve()
            uncached.write_text(unformatted)
            black.write_cache({}, [cached], mode)

            invokeBlack([str(workspace)])

            assert cached.read_text() == unformatted
            assert uncached.read_text() == 'print("hello")\n'
            cache = black.read_cache(mode)
            assert str(cached) in cache
            assert str(uncached) in cache

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    def test_no_cache_when_writeback_diff(self, color: bool) -> None:
        """With --diff (and optionally --color) the cache is neither read nor written."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            cmd = [str(src), "--diff"] + (["--color"] if color else [])
            with patch("black.read_cache") as read_cache, patch(
                "black.write_cache"
            ) as write_cache:
                invokeBlack(cmd)
                assert not get_cache_file(mode).exists()
                write_cache.assert_not_called()
                read_cache.assert_not_called()

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    @event_loop()
    def test_output_locking_when_writeback_diff(self, color: bool) -> None:
        """Diffing several files creates a multiprocessing Manager (for its lock)."""
        with cache_dir() as workspace:
            for tag in range(4):
                (workspace / f"test{tag}.py").resolve().write_text("print('hello')")
            cmd = ["--diff", str(workspace)]
            if color:
                cmd.append("--color")
            with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
                invokeBlack(cmd, exit_code=0)
                # this isn't quite doing what we want, but if it _isn't_
                # called then we cannot be using the lock it provides
                mgr.assert_called()

    def test_no_cache_when_stdin(self) -> None:
        """Formatting from stdin does not create a cache file."""
        mode = DEFAULT_MODE
        with cache_dir():
            stdin = BytesIO(b"print('hello')")
            result = CliRunner().invoke(black.main, ["-"], input=stdin)
            assert not result.exit_code
            assert not get_cache_file(mode).exists()

    def test_read_cache_no_cachefile(self) -> None:
        """Reading a cache that was never written yields an empty dict."""
        with cache_dir():
            assert black.read_cache(DEFAULT_MODE) == {}

    def test_write_cache_read_cache(self) -> None:
        """A written cache entry round-trips through read_cache."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.touch()
            black.write_cache({}, [src], mode)
            loaded = black.read_cache(mode)
            key = str(src)
            assert key in loaded
            assert loaded[key] == black.get_cache_info(src)

    def test_filter_cached(self) -> None:
        """filter_cached splits sources into (todo, done) using the cache info."""
        with TemporaryDirectory() as tmp:
            base = Path(tmp)
            uncached, cached, changed = (
                (base / name).resolve() for name in ("uncached", "cached", "changed")
            )
            for file_path in (uncached, cached, changed):
                file_path.touch()
            cache = {
                str(cached): black.get_cache_info(cached),
                # Deliberately stale info so the file looks modified since caching.
                str(changed): (0.0, 0),
            }
            todo, done = black.filter_cached(cache, {uncached, cached, changed})
            assert todo == {uncached, changed}
            assert done == {cached}

    def test_write_cache_creates_directory_if_needed(self) -> None:
        """write_cache creates the cache directory when it does not exist yet."""
        with cache_dir(exists=False) as workspace:
            assert not workspace.exists()
            black.write_cache({}, [], DEFAULT_MODE)
            assert workspace.exists()

    @event_loop()
    def test_failed_formatting_does_not_get_cached(self) -> None:
        """Files that fail to format stay out of the cache; clean ones go in."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace, patch(
            "concurrent.futures.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            failing = (workspace / "failing.py").resolve()
            failing.write_text("not actually python")
            clean = (workspace / "clean.py").resolve()
            clean.write_text('print("hello")\n')
            invokeBlack([str(workspace)], exit_code=123)
            cache = black.read_cache(mode)
            assert str(failing) not in cache
            assert str(clean) in cache

    def test_write_cache_write_fail(self) -> None:
        """write_cache must not raise when ``Path.open`` fails with OSError."""
        with cache_dir(), patch.object(Path, "open", side_effect=OSError):
            black.write_cache({}, [], DEFAULT_MODE)

    def test_read_cache_line_lengths(self) -> None:
        """Entries written under one mode are invisible under another line length."""
        mode = DEFAULT_MODE
        short_mode = replace(DEFAULT_MODE, line_length=1)
        with cache_dir() as workspace:
            path = (workspace / "file.py").resolve()
            path.touch()
            black.write_cache({}, [path], mode)
            assert str(path) in black.read_cache(mode)
            assert str(path) not in black.read_cache(short_mode)
1834
1835
def assert_collected_sources(
    src: Sequence[Union[str, Path]],
    expected: Sequence[Union[str, Path]],
    *,
    ctx: Optional[FakeContext] = None,
    exclude: Optional[str] = None,
    include: Optional[str] = None,
    extend_exclude: Optional[str] = None,
    force_exclude: Optional[str] = None,
    stdin_filename: Optional[str] = None,
) -> None:
    """Assert that ``black.get_sources`` collects exactly ``expected`` from ``src``.

    Pattern arguments are regex strings compiled with ``compile_pattern``.
    ``None`` means "not given" and is passed through as ``None``, except for
    ``include``, which then falls back to ``DEFAULT_INCLUDE``. A fresh
    ``FakeContext`` is used when ``ctx`` is not supplied. Order is ignored.
    """

    def _compile_or_none(pattern: Optional[str]) -> Any:
        # Compile a user-supplied pattern, keeping "not given" as None.
        return None if pattern is None else compile_pattern(pattern)

    collected = black.get_sources(
        ctx=ctx or FakeContext(),
        src=tuple(str(Path(s)) for s in src),
        quiet=False,
        verbose=False,
        include=DEFAULT_INCLUDE if include is None else compile_pattern(include),
        exclude=_compile_or_none(exclude),
        extend_exclude=_compile_or_none(extend_exclude),
        force_exclude=_compile_or_none(force_exclude),
        report=black.Report(),
        stdin_filename=stdin_filename,
    )
    assert sorted(collected) == sorted(Path(s) for s in expected)
1868
1869
class TestFileCollection:
    """Source discovery via ``black.get_sources`` and ``black.gen_python_files``.

    Covers include / exclude / extend-exclude / force-exclude patterns,
    ``.gitignore`` handling (nested and invalid files), symlinks leaving the
    root, and stdin input with and without ``--stdin-filename``.
    """

    def test_include_exclude(self) -> None:
        """Only files matching include and not matching exclude are collected."""
        path = THIS_DIR / "data" / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        assert_collected_sources(
            src,
            expected,
            include=r"\.pyi?$",
            exclude=r"/exclude/|/\.definitely_exclude/",
        )

    def test_gitignore_used_as_default(self) -> None:
        """Without an explicit exclude, the root's .gitignore is honoured."""
        base = DATA_DIR / "include_exclude_tests"
        expected = [
            base / "b/.definitely_exclude/a.py",
            base / "b/.definitely_exclude/a.pyi",
        ]
        src = [base / "b/"]
        ctx = FakeContext()
        ctx.obj["root"] = base
        assert_collected_sources(src, expected, ctx=ctx, extend_exclude=r"/exclude/")

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_exclude_for_issue_1572(self) -> None:
        # Exclude shouldn't touch files that were explicitly given to Black through the
        # CLI. Exclude is supposed to only apply to the recursive discovery of files.
        # https://github.com/psf/black/issues/1572
        path = DATA_DIR / "include_exclude_tests"
        src = [path / "b/exclude/a.py"]
        expected = [path / "b/exclude/a.py"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    def test_gitignore_exclude(self) -> None:
        """A PathSpec passed to gen_python_files filters out matching entries."""
        path = THIS_DIR / "data" / "include_exclude_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        report = black.Report()
        gitignore = PathSpec.from_lines(
            "gitwildmatch", ["exclude/", ".definitely_exclude"]
        )
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        this_abs = THIS_DIR.resolve()
        sources: List[Path] = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                gitignore,
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_nested_gitignore(self) -> None:
        """.gitignore files in subdirectories are applied in addition to the root's."""
        path = THIS_DIR / "data" / "nested_gitignore_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        root_gitignore = black.files.get_gitignore(path)
        report = black.Report()
        expected: List[Path] = [
            path / "x.py",
            path / "root/b.py",
            path / "root/c.py",
            path / "root/child/c.py",
        ]
        this_abs = THIS_DIR.resolve()
        sources = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                root_gitignore,
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_invalid_gitignore(self) -> None:
        """An unparsable root .gitignore makes black exit with code 1."""
        path = THIS_DIR / "data" / "invalid_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_invalid_nested_gitignore(self) -> None:
        """An unparsable nested .gitignore makes black exit with code 1."""
        path = THIS_DIR / "data" / "invalid_nested_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / "a" / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_empty_include(self) -> None:
        """An empty include pattern collects every file."""
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.pie",
            path / "b/exclude/a.py",
            path / "b/exclude/a.pyi",
            path / "b/dont_exclude/a.pie",
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
            path / "b/.definitely_exclude/a.pie",
            path / "b/.definitely_exclude/a.py",
            path / "b/.definitely_exclude/a.pyi",
            path / ".gitignore",
            path / "pyproject.toml",
        ]
        # Setting exclude explicitly to an empty string to block .gitignore usage.
        assert_collected_sources(src, expected, include="", exclude="")

    def test_extend_exclude(self) -> None:
        """extend_exclude removes files on top of what exclude already removes."""
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.py",
            path / "b/dont_exclude/a.py",
        ]
        assert_collected_sources(
            src, expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude"
        )

    @pytest.mark.incompatible_with_mypyc
    def test_symlink_out_of_root_directory(self) -> None:
        """A child resolving outside of root must not make gen_python_files raise."""
        path = MagicMock()
        root = THIS_DIR.resolve()
        child = MagicMock()
        include = re.compile(black.DEFAULT_INCLUDES)
        exclude = re.compile(black.DEFAULT_EXCLUDES)
        report = black.Report()
        gitignore = PathSpec.from_lines("gitwildmatch", [])
        # `child` should behave like a symlink which resolved path is clearly
        # outside of the `root` directory.
        path.iterdir.return_value = [child]
        child.resolve.return_value = Path("/a/b/c")
        child.as_posix.return_value = "/a/b/c"
        try:
            list(
                black.gen_python_files(
                    path.iterdir(),
                    root,
                    include,
                    exclude,
                    None,
                    None,
                    report,
                    gitignore,
                    verbose=False,
                    quiet=False,
                )
            )
        except ValueError as ve:
            # Name the function actually under test in the failure message.
            pytest.fail(f"`gen_python_files()` failed: {ve}")
        path.iterdir.assert_called_once()
        child.resolve.assert_called_once()

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin(self) -> None:
        """``-`` is always collected, regardless of include/exclude."""
        src = ["-"]
        expected = ["-"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename(self) -> None:
        """With --stdin-filename, ``-`` is collected under a prefixed marker name."""
        src = ["-"]
        stdin_filename = str(THIS_DIR / "data/collections.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
        # Exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        path = DATA_DIR / "include_exclude_tests"
        src = ["-"]
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
        # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        src = ["-"]
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            extend_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
        # Force exclude should exclude the file when passing it through
        # stdin_filename
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        assert_collected_sources(
            src=["-"],
            expected=[],
            force_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )
2115
2116
# Source lines of black/__init__.py, read for tracefunc(). Initialised up front
# so the name always exists: when the file cannot be decoded as UTF-8 on a
# compiled build (black.COMPILED; black.__file__ is then presumably a binary
# extension), the list simply stays empty instead of being left undefined.
black_source_lines: List[str] = []
try:
    with open(black.__file__, "r", encoding="utf-8") as _bf:
        black_source_lines = _bf.readlines()
except UnicodeDecodeError:
    if not black.COMPILED:
        raise
2123
2124
def tracefunc(
    frame: types.FrameType, event: str, arg: Any
) -> Callable[[types.FrameType, str, Any], Any]:
    """Show function calls `from black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    Only ``call`` events whose code lives in ``black/__init__.py`` are printed,
    as ``<lineno>:<def line>`` indented by the current stack depth. Decorator
    lines above the ``def`` are skipped. Always returns itself so tracing
    continues.
    """
    if event != "call":
        return tracefunc

    filename = frame.f_code.co_filename
    # Filter by file first: black_source_lines only describes black/__init__.py,
    # so indexing it with line numbers from other modules is meaningless and
    # could raise IndexError.
    if "black/__init__.py" not in filename:
        return tracefunc

    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1  # f_lineno is 1-based, the list is 0-based
    if not 0 <= func_sig_lineno < len(black_source_lines):
        return tracefunc
    funcname = black_source_lines[func_sig_lineno].strip()
    # Walk past decorator lines to reach the actual `def` line, without running
    # off the end of the file.
    while funcname.startswith("@") and func_sig_lineno + 1 < len(black_source_lines):
        func_sig_lineno += 1
        funcname = black_source_lines[func_sig_lineno].strip()

    # Indent by stack depth; 19 presumably discounts the frames of the test
    # harness above the traced code -- TODO confirm for your runner.
    stack = (len(inspect.stack()) - 19) * 2
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc