git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Remove hard coded test cases (#3062)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import inspect
5 import io
6 import logging
7 import multiprocessing
8 import os
9 import sys
10 import types
11 import unittest
12 from concurrent.futures import ThreadPoolExecutor
13 from contextlib import contextmanager, redirect_stderr
14 from dataclasses import replace
15 from io import BytesIO
16 from pathlib import Path
17 from platform import system
18 from tempfile import TemporaryDirectory
19 from typing import (
20     Any,
21     Callable,
22     Dict,
23     Iterator,
24     List,
25     Optional,
26     Sequence,
27     TypeVar,
28     Union,
29 )
30 from unittest.mock import MagicMock, patch
31
32 import click
33 import pytest
34 import re
35 from click import unstyle
36 from click.testing import CliRunner
37 from pathspec import PathSpec
38
39 import black
40 import black.files
41 from black import Feature, TargetVersion
42 from black import re_compile_maybe_verbose as compile_pattern
43 from black.cache import get_cache_dir, get_cache_file
44 from black.debug import DebugVisitor
45 from black.output import color_diff, diff
46 from black.report import Report
47
48 # Import other test classes
49 from tests.util import (
50     DATA_DIR,
51     DEFAULT_MODE,
52     DETERMINISTIC_HEADER,
53     PROJECT_ROOT,
54     PY36_VERSIONS,
55     THIS_DIR,
56     BlackBaseTestCase,
57     assert_format,
58     change_directory,
59     dump_to_stderr,
60     ff,
61     fs,
62     read_data,
63     get_case_path,
64     read_data_from_file,
65 )
66
# Path of this test module.
THIS_FILE = Path(__file__)
# Config file handed to Black via --config in the CLI tests below
# (presumably empty, judging by its name -- verify against tests/data).
EMPTY_CONFIG = THIS_DIR / "data" / "empty_pyproject.toml"
# One "--target-version=<name>" flag per entry of PY36_VERSIONS.
PY36_ARGS = [f"--target-version={version.name.lower()}" for version in PY36_VERSIONS]
# Black's default exclude / include patterns, compiled once for reuse.
DEFAULT_EXCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_EXCLUDES)
DEFAULT_INCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_INCLUDES)
# Generic type variables (their users are not visible in this part of the file).
T = TypeVar("T")
R = TypeVar("R")

# Match the time output in a diff, but nothing else
DIFF_TIME = re.compile(r"\t[\d\-:+\. ]+")
77
78
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch ``black.cache.CACHE_DIR`` to a path inside a temporary directory.

    Yields the patched path. With ``exists=True`` it is the temporary directory
    itself; otherwise it is a ``new`` child path that this helper does not create.
    The patch is undone and the temporary directory removed when the context exits.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
87
88
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a freshly created asyncio event loop for the duration of the block.

    The loop is created through the current event loop policy, set as the current
    loop, and closed when the context exits (also when the body raises).
    """
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        fresh_loop.close()
99
100
class FakeContext(click.Context):
    """Stand-in click Context for calling functions that require one."""

    def __init__(self) -> None:
        # click.Context.__init__ is not invoked; only the two attributes below
        # are set on the instance.
        # Dummy root, since most of the tests don't care about it
        self.obj: Dict[str, Any] = {"root": PROJECT_ROOT}
        self.default_map: Dict[str, Any] = {}
108
109
class FakeParameter(click.Parameter):
    """Stand-in click Parameter for calling functions that require one."""

    def __init__(self) -> None:
        # No state is set up and click.Parameter.__init__ is not invoked.
        return None
115
116
class BlackRunner(CliRunner):
    """CLI test runner that keeps STDOUT and STDERR apart when driving Black."""

    def __init__(self) -> None:
        # mix_stderr=False asks click's runner not to merge stderr into stdout.
        super(BlackRunner, self).__init__(mix_stderr=False)
122
123
def invokeBlack(
    args: List[str], exit_code: int = 0, ignore_config: bool = True
) -> None:
    """Run Black's CLI entry point with ``args`` and assert on its exit code.

    When ``ignore_config`` is true, ``--verbose --config <THIS_DIR>/empty.toml``
    is put in front of ``args``. Exceptions raised by the CLI propagate because
    ``catch_exceptions=False`` is passed. On an exit code mismatch the assertion
    message carries the arguments, captured stdout/stderr and the exception.
    """
    if ignore_config:
        args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
    result = BlackRunner().invoke(black.main, args, catch_exceptions=False)
    assert result.stdout_bytes is not None
    assert result.stderr_bytes is not None
    stdout_text = result.stdout_bytes.decode()
    stderr_text = result.stderr_bytes.decode()
    msg = (
        f"Failed with args: {args}\n"
        f"stdout: {stdout_text!r}\n"
        f"stderr: {stderr_text!r}\n"
        f"exception: {result.exception}"
    )
    assert result.exit_code == exit_code, msg
140
141
142 class BlackTestCase(BlackBaseTestCase):
143     invokeBlack = staticmethod(invokeBlack)
144
145     def test_empty_ff(self) -> None:
146         expected = ""
147         tmp_file = Path(black.dump_to_file())
148         try:
149             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
150             with open(tmp_file, encoding="utf8") as f:
151                 actual = f.read()
152         finally:
153             os.unlink(tmp_file)
154         self.assertFormatEqual(expected, actual)
155
156     def test_experimental_string_processing_warns(self) -> None:
157         self.assertWarns(
158             black.mode.Deprecated, black.Mode, experimental_string_processing=True
159         )
160
161     def test_piping(self) -> None:
162         source, expected = read_data_from_file(PROJECT_ROOT / "src/black/__init__.py")
163         result = BlackRunner().invoke(
164             black.main,
165             [
166                 "-",
167                 "--fast",
168                 f"--line-length={black.DEFAULT_LINE_LENGTH}",
169                 f"--config={EMPTY_CONFIG}",
170             ],
171             input=BytesIO(source.encode("utf8")),
172         )
173         self.assertEqual(result.exit_code, 0)
174         self.assertFormatEqual(expected, result.output)
175         if source != result.output:
176             black.assert_equivalent(source, result.output)
177             black.assert_stable(source, result.output, DEFAULT_MODE)
178
179     def test_piping_diff(self) -> None:
180         diff_header = re.compile(
181             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
182             r"\+\d\d\d\d"
183         )
184         source, _ = read_data("simple_cases", "expression.py")
185         expected, _ = read_data("simple_cases", "expression.diff")
186         args = [
187             "-",
188             "--fast",
189             f"--line-length={black.DEFAULT_LINE_LENGTH}",
190             "--diff",
191             f"--config={EMPTY_CONFIG}",
192         ]
193         result = BlackRunner().invoke(
194             black.main, args, input=BytesIO(source.encode("utf8"))
195         )
196         self.assertEqual(result.exit_code, 0)
197         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
198         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
199         self.assertEqual(expected, actual)
200
201     def test_piping_diff_with_color(self) -> None:
202         source, _ = read_data("simple_cases", "expression.py")
203         args = [
204             "-",
205             "--fast",
206             f"--line-length={black.DEFAULT_LINE_LENGTH}",
207             "--diff",
208             "--color",
209             f"--config={EMPTY_CONFIG}",
210         ]
211         result = BlackRunner().invoke(
212             black.main, args, input=BytesIO(source.encode("utf8"))
213         )
214         actual = result.output
215         # Again, the contents are checked in a different test, so only look for colors.
216         self.assertIn("\033[1m", actual)
217         self.assertIn("\033[36m", actual)
218         self.assertIn("\033[32m", actual)
219         self.assertIn("\033[31m", actual)
220         self.assertIn("\033[0m", actual)
221
222     @patch("black.dump_to_file", dump_to_stderr)
223     def _test_wip(self) -> None:
224         source, expected = read_data("miscellaneous", "wip")
225         sys.settrace(tracefunc)
226         mode = replace(
227             DEFAULT_MODE,
228             experimental_string_processing=False,
229             target_versions={black.TargetVersion.PY38},
230         )
231         actual = fs(source, mode=mode)
232         sys.settrace(None)
233         self.assertFormatEqual(expected, actual)
234         black.assert_equivalent(source, actual)
235         black.assert_stable(source, actual, black.FileMode())
236
237     def test_pep_572_version_detection(self) -> None:
238         source, _ = read_data("py_38", "pep_572")
239         root = black.lib2to3_parse(source)
240         features = black.get_features_used(root)
241         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
242         versions = black.detect_target_versions(root)
243         self.assertIn(black.TargetVersion.PY38, versions)
244
245     def test_expression_ff(self) -> None:
246         source, expected = read_data("simple_cases", "expression.py")
247         tmp_file = Path(black.dump_to_file(source))
248         try:
249             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
250             with open(tmp_file, encoding="utf8") as f:
251                 actual = f.read()
252         finally:
253             os.unlink(tmp_file)
254         self.assertFormatEqual(expected, actual)
255         with patch("black.dump_to_file", dump_to_stderr):
256             black.assert_equivalent(source, actual)
257             black.assert_stable(source, actual, DEFAULT_MODE)
258
259     def test_expression_diff(self) -> None:
260         source, _ = read_data("simple_cases", "expression.py")
261         expected, _ = read_data("simple_cases", "expression.diff")
262         tmp_file = Path(black.dump_to_file(source))
263         diff_header = re.compile(
264             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
265             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
266         )
267         try:
268             result = BlackRunner().invoke(
269                 black.main, ["--diff", str(tmp_file), f"--config={EMPTY_CONFIG}"]
270             )
271             self.assertEqual(result.exit_code, 0)
272         finally:
273             os.unlink(tmp_file)
274         actual = result.output
275         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
276         if expected != actual:
277             dump = black.dump_to_file(actual)
278             msg = (
279                 "Expected diff isn't equal to the actual. If you made changes to"
280                 " expression.py and this is an anticipated difference, overwrite"
281                 f" tests/data/expression.diff with {dump}"
282             )
283             self.assertEqual(expected, actual, msg)
284
285     def test_expression_diff_with_color(self) -> None:
286         source, _ = read_data("simple_cases", "expression.py")
287         expected, _ = read_data("simple_cases", "expression.diff")
288         tmp_file = Path(black.dump_to_file(source))
289         try:
290             result = BlackRunner().invoke(
291                 black.main,
292                 ["--diff", "--color", str(tmp_file), f"--config={EMPTY_CONFIG}"],
293             )
294         finally:
295             os.unlink(tmp_file)
296         actual = result.output
297         # We check the contents of the diff in `test_expression_diff`. All
298         # we need to check here is that color codes exist in the result.
299         self.assertIn("\033[1m", actual)
300         self.assertIn("\033[36m", actual)
301         self.assertIn("\033[32m", actual)
302         self.assertIn("\033[31m", actual)
303         self.assertIn("\033[0m", actual)
304
305     def test_detect_pos_only_arguments(self) -> None:
306         source, _ = read_data("py_38", "pep_570")
307         root = black.lib2to3_parse(source)
308         features = black.get_features_used(root)
309         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
310         versions = black.detect_target_versions(root)
311         self.assertIn(black.TargetVersion.PY38, versions)
312
313     @patch("black.dump_to_file", dump_to_stderr)
314     def test_string_quotes(self) -> None:
315         source, expected = read_data("miscellaneous", "string_quotes")
316         mode = black.Mode(preview=True)
317         assert_format(source, expected, mode)
318         mode = replace(mode, string_normalization=False)
319         not_normalized = fs(source, mode=mode)
320         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
321         black.assert_equivalent(source, not_normalized)
322         black.assert_stable(source, not_normalized, mode=mode)
323
324     def test_skip_magic_trailing_comma(self) -> None:
325         source, _ = read_data("simple_cases", "expression")
326         expected, _ = read_data(
327             "miscellaneous", "expression_skip_magic_trailing_comma.diff"
328         )
329         tmp_file = Path(black.dump_to_file(source))
330         diff_header = re.compile(
331             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
332             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
333         )
334         try:
335             result = BlackRunner().invoke(
336                 black.main, ["-C", "--diff", str(tmp_file), f"--config={EMPTY_CONFIG}"]
337             )
338             self.assertEqual(result.exit_code, 0)
339         finally:
340             os.unlink(tmp_file)
341         actual = result.output
342         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
343         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
344         if expected != actual:
345             dump = black.dump_to_file(actual)
346             msg = (
347                 "Expected diff isn't equal to the actual. If you made changes to"
348                 " expression.py and this is an anticipated difference, overwrite"
349                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
350             )
351             self.assertEqual(expected, actual, msg)
352
353     @patch("black.dump_to_file", dump_to_stderr)
354     def test_async_as_identifier(self) -> None:
355         source_path = get_case_path("miscellaneous", "async_as_identifier")
356         source, expected = read_data_from_file(source_path)
357         actual = fs(source)
358         self.assertFormatEqual(expected, actual)
359         major, minor = sys.version_info[:2]
360         if major < 3 or (major <= 3 and minor < 7):
361             black.assert_equivalent(source, actual)
362         black.assert_stable(source, actual, DEFAULT_MODE)
363         # ensure black can parse this when the target is 3.6
364         self.invokeBlack([str(source_path), "--target-version", "py36"])
365         # but not on 3.7, because async/await is no longer an identifier
366         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
367
368     @patch("black.dump_to_file", dump_to_stderr)
369     def test_python37(self) -> None:
370         source_path = get_case_path("py_37", "python37")
371         source, expected = read_data_from_file(source_path)
372         actual = fs(source)
373         self.assertFormatEqual(expected, actual)
374         major, minor = sys.version_info[:2]
375         if major > 3 or (major == 3 and minor >= 7):
376             black.assert_equivalent(source, actual)
377         black.assert_stable(source, actual, DEFAULT_MODE)
378         # ensure black can parse this when the target is 3.7
379         self.invokeBlack([str(source_path), "--target-version", "py37"])
380         # but not on 3.6, because we use async as a reserved keyword
381         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
382
383     def test_tab_comment_indentation(self) -> None:
384         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
385         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
386         self.assertFormatEqual(contents_spc, fs(contents_spc))
387         self.assertFormatEqual(contents_spc, fs(contents_tab))
388
389         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
390         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
391         self.assertFormatEqual(contents_spc, fs(contents_spc))
392         self.assertFormatEqual(contents_spc, fs(contents_tab))
393
394         # mixed tabs and spaces (valid Python 2 code)
395         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
396         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
397         self.assertFormatEqual(contents_spc, fs(contents_spc))
398         self.assertFormatEqual(contents_spc, fs(contents_tab))
399
400         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
401         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
402         self.assertFormatEqual(contents_spc, fs(contents_spc))
403         self.assertFormatEqual(contents_spc, fs(contents_tab))
404
    def test_report_verbose(self) -> None:
        # Walk a verbose Report through a fixed sequence of events and check,
        # after each one, the captured out/err messages, the summary string and
        # the return code.
        report = Report(verbose=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        # Replacements for black.output._out / _err that record the messages.
        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            # Unchanged file: announced on out in verbose mode.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # Reformatted file.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # Cached file: counted as unchanged, with its own verbose message.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled, the reformatted file turns the return code to 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure goes to err and makes the return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # An ignored path is announced on out but leaves the summary as is.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # check and diff both switch the summary to the "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
506
    def test_report_quiet(self) -> None:
        # Walk a quiet Report through a fixed sequence of events: nothing may be
        # written to out at any point, while failures still show up on err and
        # the summary string and return code keep being updated.
        report = Report(quiet=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        # Replacements for black.output._out / _err that record the messages.
        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled, the reformatted file turns the return code to 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure is still written to err and makes the return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # An ignored path changes neither the captured output nor the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # check and diff both switch the summary to the "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
600
    def test_report_normal(self) -> None:
        # Walk a default Report (neither verbose nor quiet) through a fixed
        # sequence of events: only reformatted files are written to out,
        # failures go to err, and the summary and return code are checked
        # after each step.
        report = black.Report()
        out_lines: List[str] = []
        err_lines: List[str] = []

        # Replacements for black.output._out / _err that record the messages.
        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            # Unchanged file: nothing is written to out.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # Cached file: no new message, so the last one is still about f2.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled, the reformatted file turns the return code to 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure goes to err and makes the return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # An ignored path changes neither the captured output nor the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # check and diff both switch the summary to the "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
697
698     def test_lib2to3_parse(self) -> None:
699         with self.assertRaises(black.InvalidInput):
700             black.lib2to3_parse("invalid syntax")
701
702         straddling = "x + y"
703         black.lib2to3_parse(straddling)
704         black.lib2to3_parse(straddling, {TargetVersion.PY36})
705
706         py2_only = "print x"
707         with self.assertRaises(black.InvalidInput):
708             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
709
710         py3_only = "exec(x, end=y)"
711         black.lib2to3_parse(py3_only)
712         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
713
714     def test_get_features_used_decorator(self) -> None:
715         # Test the feature detection of new decorator syntax
716         # since this makes some test cases of test_get_features_used()
717         # fails if it fails, this is tested first so that a useful case
718         # is identified
719         simples, relaxed = read_data("miscellaneous", "decorators")
720         # skip explanation comments at the top of the file
721         for simple_test in simples.split("##")[1:]:
722             node = black.lib2to3_parse(simple_test)
723             decorator = str(node.children[0].children[0]).strip()
724             self.assertNotIn(
725                 Feature.RELAXED_DECORATORS,
726                 black.get_features_used(node),
727                 msg=(
728                     f"decorator '{decorator}' follows python<=3.8 syntax"
729                     "but is detected as 3.9+"
730                     # f"The full node is\n{node!r}"
731                 ),
732             )
733         # skip the '# output' comment at the top of the output part
734         for relaxed_test in relaxed.split("##")[1:]:
735             node = black.lib2to3_parse(relaxed_test)
736             decorator = str(node.children[0].children[0]).strip()
737             self.assertIn(
738                 Feature.RELAXED_DECORATORS,
739                 black.get_features_used(node),
740                 msg=(
741                     f"decorator '{decorator}' uses python3.9+ syntax"
742                     "but is detected as python<=3.8"
743                     # f"The full node is\n{node!r}"
744                 ),
745             )
746
747     def test_get_features_used(self) -> None:
748         node = black.lib2to3_parse("def f(*, arg): ...\n")
749         self.assertEqual(black.get_features_used(node), set())
750         node = black.lib2to3_parse("def f(*, arg,): ...\n")
751         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
752         node = black.lib2to3_parse("f(*arg,)\n")
753         self.assertEqual(
754             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
755         )
756         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
757         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
758         node = black.lib2to3_parse("123_456\n")
759         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
760         node = black.lib2to3_parse("123456\n")
761         self.assertEqual(black.get_features_used(node), set())
762         source, expected = read_data("simple_cases", "function")
763         node = black.lib2to3_parse(source)
764         expected_features = {
765             Feature.TRAILING_COMMA_IN_CALL,
766             Feature.TRAILING_COMMA_IN_DEF,
767             Feature.F_STRINGS,
768         }
769         self.assertEqual(black.get_features_used(node), expected_features)
770         node = black.lib2to3_parse(expected)
771         self.assertEqual(black.get_features_used(node), expected_features)
772         source, expected = read_data("simple_cases", "expression")
773         node = black.lib2to3_parse(source)
774         self.assertEqual(black.get_features_used(node), set())
775         node = black.lib2to3_parse(expected)
776         self.assertEqual(black.get_features_used(node), set())
777         node = black.lib2to3_parse("lambda a, /, b: ...")
778         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
779         node = black.lib2to3_parse("def fn(a, /, b): ...")
780         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
781         node = black.lib2to3_parse("def fn(): yield a, b")
782         self.assertEqual(black.get_features_used(node), set())
783         node = black.lib2to3_parse("def fn(): return a, b")
784         self.assertEqual(black.get_features_used(node), set())
785         node = black.lib2to3_parse("def fn(): yield *b, c")
786         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
787         node = black.lib2to3_parse("def fn(): return a, *b, c")
788         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
789         node = black.lib2to3_parse("x = a, *b, c")
790         self.assertEqual(black.get_features_used(node), set())
791         node = black.lib2to3_parse("x: Any = regular")
792         self.assertEqual(black.get_features_used(node), set())
793         node = black.lib2to3_parse("x: Any = (regular, regular)")
794         self.assertEqual(black.get_features_used(node), set())
795         node = black.lib2to3_parse("x: Any = Complex(Type(1))[something]")
796         self.assertEqual(black.get_features_used(node), set())
797         node = black.lib2to3_parse("x: Tuple[int, ...] = a, b, c")
798         self.assertEqual(
799             black.get_features_used(node), {Feature.ANN_ASSIGN_EXTENDED_RHS}
800         )
801         node = black.lib2to3_parse("try: pass\nexcept Something: pass")
802         self.assertEqual(black.get_features_used(node), set())
803         node = black.lib2to3_parse("try: pass\nexcept (*Something,): pass")
804         self.assertEqual(black.get_features_used(node), set())
805         node = black.lib2to3_parse("try: pass\nexcept *Group: pass")
806         self.assertEqual(black.get_features_used(node), {Feature.EXCEPT_STAR})
807
808     def test_get_features_used_for_future_flags(self) -> None:
809         for src, features in [
810             ("from __future__ import annotations", {Feature.FUTURE_ANNOTATIONS}),
811             (
812                 "from __future__ import (other, annotations)",
813                 {Feature.FUTURE_ANNOTATIONS},
814             ),
815             ("a = 1 + 2\nfrom something import annotations", set()),
816             ("from __future__ import x, y", set()),
817         ]:
818             with self.subTest(src=src, features=features):
819                 node = black.lib2to3_parse(src)
820                 future_imports = black.get_future_imports(node)
821                 self.assertEqual(
822                     black.get_features_used(node, future_imports=future_imports),
823                     features,
824                 )
825
826     def test_get_future_imports(self) -> None:
827         node = black.lib2to3_parse("\n")
828         self.assertEqual(set(), black.get_future_imports(node))
829         node = black.lib2to3_parse("from __future__ import black\n")
830         self.assertEqual({"black"}, black.get_future_imports(node))
831         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
832         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
833         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
834         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
835         node = black.lib2to3_parse(
836             "from __future__ import multiple\nfrom __future__ import imports\n"
837         )
838         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
839         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
840         self.assertEqual({"black"}, black.get_future_imports(node))
841         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
842         self.assertEqual({"black"}, black.get_future_imports(node))
843         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
844         self.assertEqual(set(), black.get_future_imports(node))
845         node = black.lib2to3_parse("from some.module import black\n")
846         self.assertEqual(set(), black.get_future_imports(node))
847         node = black.lib2to3_parse(
848             "from __future__ import unicode_literals as _unicode_literals"
849         )
850         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
851         node = black.lib2to3_parse(
852             "from __future__ import unicode_literals as _lol, print"
853         )
854         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
855
856     @pytest.mark.incompatible_with_mypyc
857     def test_debug_visitor(self) -> None:
858         source, _ = read_data("miscellaneous", "debug_visitor")
859         expected, _ = read_data("miscellaneous", "debug_visitor.out")
860         out_lines = []
861         err_lines = []
862
863         def out(msg: str, **kwargs: Any) -> None:
864             out_lines.append(msg)
865
866         def err(msg: str, **kwargs: Any) -> None:
867             err_lines.append(msg)
868
869         with patch("black.debug.out", out):
870             DebugVisitor.show(source)
871         actual = "\n".join(out_lines) + "\n"
872         log_name = ""
873         if expected != actual:
874             log_name = black.dump_to_file(*out_lines)
875         self.assertEqual(
876             expected,
877             actual,
878             f"AST print out is different. Actual version dumped to {log_name}",
879         )
880
881     def test_format_file_contents(self) -> None:
882         empty = ""
883         mode = DEFAULT_MODE
884         with self.assertRaises(black.NothingChanged):
885             black.format_file_contents(empty, mode=mode, fast=False)
886         just_nl = "\n"
887         with self.assertRaises(black.NothingChanged):
888             black.format_file_contents(just_nl, mode=mode, fast=False)
889         same = "j = [1, 2, 3]\n"
890         with self.assertRaises(black.NothingChanged):
891             black.format_file_contents(same, mode=mode, fast=False)
892         different = "j = [1,2,3]"
893         expected = same
894         actual = black.format_file_contents(different, mode=mode, fast=False)
895         self.assertEqual(expected, actual)
896         invalid = "return if you can"
897         with self.assertRaises(black.InvalidInput) as e:
898             black.format_file_contents(invalid, mode=mode, fast=False)
899         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
900
901     def test_endmarker(self) -> None:
902         n = black.lib2to3_parse("\n")
903         self.assertEqual(n.type, black.syms.file_input)
904         self.assertEqual(len(n.children), 1)
905         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
906
907     @pytest.mark.incompatible_with_mypyc
908     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
909     def test_assertFormatEqual(self) -> None:
910         out_lines = []
911         err_lines = []
912
913         def out(msg: str, **kwargs: Any) -> None:
914             out_lines.append(msg)
915
916         def err(msg: str, **kwargs: Any) -> None:
917             err_lines.append(msg)
918
919         with patch("black.output._out", out), patch("black.output._err", err):
920             with self.assertRaises(AssertionError):
921                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
922
923         out_str = "".join(out_lines)
924         self.assertIn("Expected tree:", out_str)
925         self.assertIn("Actual tree:", out_str)
926         self.assertEqual("".join(err_lines), "")
927
928     @event_loop()
929     @patch("concurrent.futures.ProcessPoolExecutor", MagicMock(side_effect=OSError))
930     def test_works_in_mono_process_only_environment(self) -> None:
931         with cache_dir() as workspace:
932             for f in [
933                 (workspace / "one.py").resolve(),
934                 (workspace / "two.py").resolve(),
935             ]:
936                 f.write_text('print("hello")\n')
937             self.invokeBlack([str(workspace)])
938
939     @event_loop()
940     def test_check_diff_use_together(self) -> None:
941         with cache_dir():
942             # Files which will be reformatted.
943             src1 = get_case_path("miscellaneous", "string_quotes")
944             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
945             # Files which will not be reformatted.
946             src2 = get_case_path("simple_cases", "composition")
947             self.invokeBlack([str(src2), "--diff", "--check"])
948             # Multi file command.
949             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
950
951     def test_no_src_fails(self) -> None:
952         with cache_dir():
953             self.invokeBlack([], exit_code=1)
954
955     def test_src_and_code_fails(self) -> None:
956         with cache_dir():
957             self.invokeBlack([".", "-c", "0"], exit_code=1)
958
959     def test_broken_symlink(self) -> None:
960         with cache_dir() as workspace:
961             symlink = workspace / "broken_link.py"
962             try:
963                 symlink.symlink_to("nonexistent.py")
964             except (OSError, NotImplementedError) as e:
965                 self.skipTest(f"Can't create symlinks: {e}")
966             self.invokeBlack([str(workspace.resolve())])
967
968     def test_single_file_force_pyi(self) -> None:
969         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
970         contents, expected = read_data("miscellaneous", "force_pyi")
971         with cache_dir() as workspace:
972             path = (workspace / "file.py").resolve()
973             with open(path, "w") as fh:
974                 fh.write(contents)
975             self.invokeBlack([str(path), "--pyi"])
976             with open(path, "r") as fh:
977                 actual = fh.read()
978             # verify cache with --pyi is separate
979             pyi_cache = black.read_cache(pyi_mode)
980             self.assertIn(str(path), pyi_cache)
981             normal_cache = black.read_cache(DEFAULT_MODE)
982             self.assertNotIn(str(path), normal_cache)
983         self.assertFormatEqual(expected, actual)
984         black.assert_equivalent(contents, actual)
985         black.assert_stable(contents, actual, pyi_mode)
986
987     @event_loop()
988     def test_multi_file_force_pyi(self) -> None:
989         reg_mode = DEFAULT_MODE
990         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
991         contents, expected = read_data("miscellaneous", "force_pyi")
992         with cache_dir() as workspace:
993             paths = [
994                 (workspace / "file1.py").resolve(),
995                 (workspace / "file2.py").resolve(),
996             ]
997             for path in paths:
998                 with open(path, "w") as fh:
999                     fh.write(contents)
1000             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1001             for path in paths:
1002                 with open(path, "r") as fh:
1003                     actual = fh.read()
1004                 self.assertEqual(actual, expected)
1005             # verify cache with --pyi is separate
1006             pyi_cache = black.read_cache(pyi_mode)
1007             normal_cache = black.read_cache(reg_mode)
1008             for path in paths:
1009                 self.assertIn(str(path), pyi_cache)
1010                 self.assertNotIn(str(path), normal_cache)
1011
1012     def test_pipe_force_pyi(self) -> None:
1013         source, expected = read_data("miscellaneous", "force_pyi")
1014         result = CliRunner().invoke(
1015             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1016         )
1017         self.assertEqual(result.exit_code, 0)
1018         actual = result.output
1019         self.assertFormatEqual(actual, expected)
1020
1021     def test_single_file_force_py36(self) -> None:
1022         reg_mode = DEFAULT_MODE
1023         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1024         source, expected = read_data("miscellaneous", "force_py36")
1025         with cache_dir() as workspace:
1026             path = (workspace / "file.py").resolve()
1027             with open(path, "w") as fh:
1028                 fh.write(source)
1029             self.invokeBlack([str(path), *PY36_ARGS])
1030             with open(path, "r") as fh:
1031                 actual = fh.read()
1032             # verify cache with --target-version is separate
1033             py36_cache = black.read_cache(py36_mode)
1034             self.assertIn(str(path), py36_cache)
1035             normal_cache = black.read_cache(reg_mode)
1036             self.assertNotIn(str(path), normal_cache)
1037         self.assertEqual(actual, expected)
1038
1039     @event_loop()
1040     def test_multi_file_force_py36(self) -> None:
1041         reg_mode = DEFAULT_MODE
1042         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1043         source, expected = read_data("miscellaneous", "force_py36")
1044         with cache_dir() as workspace:
1045             paths = [
1046                 (workspace / "file1.py").resolve(),
1047                 (workspace / "file2.py").resolve(),
1048             ]
1049             for path in paths:
1050                 with open(path, "w") as fh:
1051                     fh.write(source)
1052             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1053             for path in paths:
1054                 with open(path, "r") as fh:
1055                     actual = fh.read()
1056                 self.assertEqual(actual, expected)
1057             # verify cache with --target-version is separate
1058             pyi_cache = black.read_cache(py36_mode)
1059             normal_cache = black.read_cache(reg_mode)
1060             for path in paths:
1061                 self.assertIn(str(path), pyi_cache)
1062                 self.assertNotIn(str(path), normal_cache)
1063
1064     def test_pipe_force_py36(self) -> None:
1065         source, expected = read_data("miscellaneous", "force_py36")
1066         result = CliRunner().invoke(
1067             black.main,
1068             ["-", "-q", "--target-version=py36"],
1069             input=BytesIO(source.encode("utf8")),
1070         )
1071         self.assertEqual(result.exit_code, 0)
1072         actual = result.output
1073         self.assertFormatEqual(actual, expected)
1074
1075     @pytest.mark.incompatible_with_mypyc
1076     def test_reformat_one_with_stdin(self) -> None:
1077         with patch(
1078             "black.format_stdin_to_stdout",
1079             return_value=lambda *args, **kwargs: black.Changed.YES,
1080         ) as fsts:
1081             report = MagicMock()
1082             path = Path("-")
1083             black.reformat_one(
1084                 path,
1085                 fast=True,
1086                 write_back=black.WriteBack.YES,
1087                 mode=DEFAULT_MODE,
1088                 report=report,
1089             )
1090             fsts.assert_called_once()
1091             report.done.assert_called_with(path, black.Changed.YES)
1092
1093     @pytest.mark.incompatible_with_mypyc
1094     def test_reformat_one_with_stdin_filename(self) -> None:
1095         with patch(
1096             "black.format_stdin_to_stdout",
1097             return_value=lambda *args, **kwargs: black.Changed.YES,
1098         ) as fsts:
1099             report = MagicMock()
1100             p = "foo.py"
1101             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1102             expected = Path(p)
1103             black.reformat_one(
1104                 path,
1105                 fast=True,
1106                 write_back=black.WriteBack.YES,
1107                 mode=DEFAULT_MODE,
1108                 report=report,
1109             )
1110             fsts.assert_called_once_with(
1111                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1112             )
1113             # __BLACK_STDIN_FILENAME__ should have been stripped
1114             report.done.assert_called_with(expected, black.Changed.YES)
1115
1116     @pytest.mark.incompatible_with_mypyc
1117     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1118         with patch(
1119             "black.format_stdin_to_stdout",
1120             return_value=lambda *args, **kwargs: black.Changed.YES,
1121         ) as fsts:
1122             report = MagicMock()
1123             p = "foo.pyi"
1124             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1125             expected = Path(p)
1126             black.reformat_one(
1127                 path,
1128                 fast=True,
1129                 write_back=black.WriteBack.YES,
1130                 mode=DEFAULT_MODE,
1131                 report=report,
1132             )
1133             fsts.assert_called_once_with(
1134                 fast=True,
1135                 write_back=black.WriteBack.YES,
1136                 mode=replace(DEFAULT_MODE, is_pyi=True),
1137             )
1138             # __BLACK_STDIN_FILENAME__ should have been stripped
1139             report.done.assert_called_with(expected, black.Changed.YES)
1140
1141     @pytest.mark.incompatible_with_mypyc
1142     def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
1143         with patch(
1144             "black.format_stdin_to_stdout",
1145             return_value=lambda *args, **kwargs: black.Changed.YES,
1146         ) as fsts:
1147             report = MagicMock()
1148             p = "foo.ipynb"
1149             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1150             expected = Path(p)
1151             black.reformat_one(
1152                 path,
1153                 fast=True,
1154                 write_back=black.WriteBack.YES,
1155                 mode=DEFAULT_MODE,
1156                 report=report,
1157             )
1158             fsts.assert_called_once_with(
1159                 fast=True,
1160                 write_back=black.WriteBack.YES,
1161                 mode=replace(DEFAULT_MODE, is_ipynb=True),
1162             )
1163             # __BLACK_STDIN_FILENAME__ should have been stripped
1164             report.done.assert_called_with(expected, black.Changed.YES)
1165
1166     @pytest.mark.incompatible_with_mypyc
1167     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1168         with patch(
1169             "black.format_stdin_to_stdout",
1170             return_value=lambda *args, **kwargs: black.Changed.YES,
1171         ) as fsts:
1172             report = MagicMock()
1173             # Even with an existing file, since we are forcing stdin, black
1174             # should output to stdout and not modify the file inplace
1175             p = THIS_DIR / "data" / "simple_cases" / "collections.py"
1176             # Make sure is_file actually returns True
1177             self.assertTrue(p.is_file())
1178             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1179             expected = Path(p)
1180             black.reformat_one(
1181                 path,
1182                 fast=True,
1183                 write_back=black.WriteBack.YES,
1184                 mode=DEFAULT_MODE,
1185                 report=report,
1186             )
1187             fsts.assert_called_once()
1188             # __BLACK_STDIN_FILENAME__ should have been stripped
1189             report.done.assert_called_with(expected, black.Changed.YES)
1190
1191     def test_reformat_one_with_stdin_empty(self) -> None:
1192         output = io.StringIO()
1193         with patch("io.TextIOWrapper", lambda *args, **kwargs: output):
1194             try:
1195                 black.format_stdin_to_stdout(
1196                     fast=True,
1197                     content="",
1198                     write_back=black.WriteBack.YES,
1199                     mode=DEFAULT_MODE,
1200                 )
1201             except io.UnsupportedOperation:
1202                 pass  # StringIO does not support detach
1203             assert output.getvalue() == ""
1204
1205     def test_invalid_cli_regex(self) -> None:
1206         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1207             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1208
1209     def test_required_version_matches_version(self) -> None:
1210         self.invokeBlack(
1211             ["--required-version", black.__version__, "-c", "0"],
1212             exit_code=0,
1213             ignore_config=True,
1214         )
1215
1216     def test_required_version_matches_partial_version(self) -> None:
1217         self.invokeBlack(
1218             ["--required-version", black.__version__.split(".")[0], "-c", "0"],
1219             exit_code=0,
1220             ignore_config=True,
1221         )
1222
1223     def test_required_version_does_not_match_on_minor_version(self) -> None:
1224         self.invokeBlack(
1225             ["--required-version", black.__version__.split(".")[0] + ".999", "-c", "0"],
1226             exit_code=1,
1227             ignore_config=True,
1228         )
1229
1230     def test_required_version_does_not_match_version(self) -> None:
1231         result = BlackRunner().invoke(
1232             black.main,
1233             ["--required-version", "20.99b", "-c", "0"],
1234         )
1235         self.assertEqual(result.exit_code, 1)
1236         self.assertIn("required version", result.stderr)
1237
1238     def test_preserves_line_endings(self) -> None:
1239         with TemporaryDirectory() as workspace:
1240             test_file = Path(workspace) / "test.py"
1241             for nl in ["\n", "\r\n"]:
1242                 contents = nl.join(["def f(  ):", "    pass"])
1243                 test_file.write_bytes(contents.encode())
1244                 ff(test_file, write_back=black.WriteBack.YES)
1245                 updated_contents: bytes = test_file.read_bytes()
1246                 self.assertIn(nl.encode(), updated_contents)
1247                 if nl == "\n":
1248                     self.assertNotIn(b"\r\n", updated_contents)
1249
1250     def test_preserves_line_endings_via_stdin(self) -> None:
1251         for nl in ["\n", "\r\n"]:
1252             contents = nl.join(["def f(  ):", "    pass"])
1253             runner = BlackRunner()
1254             result = runner.invoke(
1255                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1256             )
1257             self.assertEqual(result.exit_code, 0)
1258             output = result.stdout_bytes
1259             self.assertIn(nl.encode("utf8"), output)
1260             if nl == "\n":
1261                 self.assertNotIn(b"\r\n", output)
1262
1263     def test_assert_equivalent_different_asts(self) -> None:
1264         with self.assertRaises(AssertionError):
1265             black.assert_equivalent("{}", "None")
1266
1267     def test_shhh_click(self) -> None:
1268         try:
1269             from click import _unicodefun  # type: ignore
1270         except ImportError:
1271             self.skipTest("Incompatible Click version")
1272
1273         if not hasattr(_unicodefun, "_verify_python_env"):
1274             self.skipTest("Incompatible Click version")
1275
1276         # First, let's see if Click is crashing with a preferred ASCII charset.
1277         with patch("locale.getpreferredencoding") as gpe:
1278             gpe.return_value = "ASCII"
1279             with self.assertRaises(RuntimeError):
1280                 _unicodefun._verify_python_env()
1281         # Now, let's silence Click...
1282         black.patch_click()
1283         # ...and confirm it's silent.
1284         with patch("locale.getpreferredencoding") as gpe:
1285             gpe.return_value = "ASCII"
1286             try:
1287                 _unicodefun._verify_python_env()
1288             except RuntimeError as re:
1289                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1290
1291     def test_root_logger_not_used_directly(self) -> None:
1292         def fail(*args: Any, **kwargs: Any) -> None:
1293             self.fail("Record created with root logger")
1294
1295         with patch.multiple(
1296             logging.root,
1297             debug=fail,
1298             info=fail,
1299             warning=fail,
1300             error=fail,
1301             critical=fail,
1302             log=fail,
1303         ):
1304             ff(THIS_DIR / "util.py")
1305
1306     def test_invalid_config_return_code(self) -> None:
1307         tmp_file = Path(black.dump_to_file())
1308         try:
1309             tmp_config = Path(black.dump_to_file())
1310             tmp_config.unlink()
1311             args = ["--config", str(tmp_config), str(tmp_file)]
1312             self.invokeBlack(args, exit_code=2, ignore_config=False)
1313         finally:
1314             tmp_file.unlink()
1315
1316     def test_parse_pyproject_toml(self) -> None:
1317         test_toml_file = THIS_DIR / "test.toml"
1318         config = black.parse_pyproject_toml(str(test_toml_file))
1319         self.assertEqual(config["verbose"], 1)
1320         self.assertEqual(config["check"], "no")
1321         self.assertEqual(config["diff"], "y")
1322         self.assertEqual(config["color"], True)
1323         self.assertEqual(config["line_length"], 79)
1324         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1325         self.assertEqual(config["python_cell_magics"], ["custom1", "custom2"])
1326         self.assertEqual(config["exclude"], r"\.pyi?$")
1327         self.assertEqual(config["include"], r"\.py?$")
1328
1329     def test_read_pyproject_toml(self) -> None:
1330         test_toml_file = THIS_DIR / "test.toml"
1331         fake_ctx = FakeContext()
1332         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1333         config = fake_ctx.default_map
1334         self.assertEqual(config["verbose"], "1")
1335         self.assertEqual(config["check"], "no")
1336         self.assertEqual(config["diff"], "y")
1337         self.assertEqual(config["color"], "True")
1338         self.assertEqual(config["line_length"], "79")
1339         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1340         self.assertEqual(config["exclude"], r"\.pyi?$")
1341         self.assertEqual(config["include"], r"\.py?$")
1342
1343     @pytest.mark.incompatible_with_mypyc
1344     def test_find_project_root(self) -> None:
1345         with TemporaryDirectory() as workspace:
1346             root = Path(workspace)
1347             test_dir = root / "test"
1348             test_dir.mkdir()
1349
1350             src_dir = root / "src"
1351             src_dir.mkdir()
1352
1353             root_pyproject = root / "pyproject.toml"
1354             root_pyproject.touch()
1355             src_pyproject = src_dir / "pyproject.toml"
1356             src_pyproject.touch()
1357             src_python = src_dir / "foo.py"
1358             src_python.touch()
1359
1360             self.assertEqual(
1361                 black.find_project_root((src_dir, test_dir)),
1362                 (root.resolve(), "pyproject.toml"),
1363             )
1364             self.assertEqual(
1365                 black.find_project_root((src_dir,)),
1366                 (src_dir.resolve(), "pyproject.toml"),
1367             )
1368             self.assertEqual(
1369                 black.find_project_root((src_python,)),
1370                 (src_dir.resolve(), "pyproject.toml"),
1371             )
1372
1373     @patch(
1374         "black.files.find_user_pyproject_toml",
1375     )
1376     def test_find_pyproject_toml(self, find_user_pyproject_toml: MagicMock) -> None:
1377         find_user_pyproject_toml.side_effect = RuntimeError()
1378
1379         with redirect_stderr(io.StringIO()) as stderr:
1380             result = black.files.find_pyproject_toml(
1381                 path_search_start=(str(Path.cwd().root),)
1382             )
1383
1384         assert result is None
1385         err = stderr.getvalue()
1386         assert "Ignoring user configuration" in err
1387
1388     @patch(
1389         "black.files.find_user_pyproject_toml",
1390         black.files.find_user_pyproject_toml.__wrapped__,
1391     )
1392     def test_find_user_pyproject_toml_linux(self) -> None:
1393         if system() == "Windows":
1394             return
1395
1396         # Test if XDG_CONFIG_HOME is checked
1397         with TemporaryDirectory() as workspace:
1398             tmp_user_config = Path(workspace) / "black"
1399             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1400                 self.assertEqual(
1401                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
1402                 )
1403
1404         # Test fallback for XDG_CONFIG_HOME
1405         with patch.dict("os.environ"):
1406             os.environ.pop("XDG_CONFIG_HOME", None)
1407             fallback_user_config = Path("~/.config").expanduser() / "black"
1408             self.assertEqual(
1409                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
1410             )
1411
1412     def test_find_user_pyproject_toml_windows(self) -> None:
1413         if system() != "Windows":
1414             return
1415
1416         user_config_path = Path.home() / ".black"
1417         self.assertEqual(
1418             black.files.find_user_pyproject_toml(), user_config_path.resolve()
1419         )
1420
1421     def test_bpo_33660_workaround(self) -> None:
1422         if system() == "Windows":
1423             return
1424
1425         # https://bugs.python.org/issue33660
1426         root = Path("/")
1427         with change_directory(root):
1428             path = Path("workspace") / "project"
1429             report = black.Report(verbose=True)
1430             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
1431             self.assertEqual(normalized_path, "workspace/project")
1432
1433     def test_normalize_path_ignore_windows_junctions_outside_of_root(self) -> None:
1434         if system() != "Windows":
1435             return
1436
1437         with TemporaryDirectory() as workspace:
1438             root = Path(workspace)
1439             junction_dir = root / "junction"
1440             junction_target_outside_of_root = root / ".."
1441             os.system(f"mklink /J {junction_dir} {junction_target_outside_of_root}")
1442
1443             report = black.Report(verbose=True)
1444             normalized_path = black.normalize_path_maybe_ignore(
1445                 junction_dir, root, report
1446             )
1447             # Manually delete for Python < 3.8
1448             os.system(f"rmdir {junction_dir}")
1449
1450             self.assertEqual(normalized_path, None)
1451
1452     def test_newline_comment_interaction(self) -> None:
1453         source = "class A:\\\r\n# type: ignore\n pass\n"
1454         output = black.format_str(source, mode=DEFAULT_MODE)
1455         black.assert_stable(source, output, mode=DEFAULT_MODE)
1456
1457     def test_bpo_2142_workaround(self) -> None:
1458
1459         # https://bugs.python.org/issue2142
1460
1461         source, _ = read_data("miscellaneous", "missing_final_newline")
1462         # read_data adds a trailing newline
1463         source = source.rstrip()
1464         expected, _ = read_data("miscellaneous", "missing_final_newline.diff")
1465         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
1466         diff_header = re.compile(
1467             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
1468             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
1469         )
1470         try:
1471             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
1472             self.assertEqual(result.exit_code, 0)
1473         finally:
1474             os.unlink(tmp_file)
1475         actual = result.output
1476         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
1477         self.assertEqual(actual, expected)
1478
1479     @staticmethod
1480     def compare_results(
1481         result: click.testing.Result, expected_value: str, expected_exit_code: int
1482     ) -> None:
1483         """Helper method to test the value and exit code of a click Result."""
1484         assert (
1485             result.output == expected_value
1486         ), "The output did not match the expected value."
1487         assert result.exit_code == expected_exit_code, "The exit code is incorrect."
1488
1489     def test_code_option(self) -> None:
1490         """Test the code option with no changes."""
1491         code = 'print("Hello world")\n'
1492         args = ["--code", code]
1493         result = CliRunner().invoke(black.main, args)
1494
1495         self.compare_results(result, code, 0)
1496
1497     def test_code_option_changed(self) -> None:
1498         """Test the code option when changes are required."""
1499         code = "print('hello world')"
1500         formatted = black.format_str(code, mode=DEFAULT_MODE)
1501
1502         args = ["--code", code]
1503         result = CliRunner().invoke(black.main, args)
1504
1505         self.compare_results(result, formatted, 0)
1506
1507     def test_code_option_check(self) -> None:
1508         """Test the code option when check is passed."""
1509         args = ["--check", "--code", 'print("Hello world")\n']
1510         result = CliRunner().invoke(black.main, args)
1511         self.compare_results(result, "", 0)
1512
1513     def test_code_option_check_changed(self) -> None:
1514         """Test the code option when changes are required, and check is passed."""
1515         args = ["--check", "--code", "print('hello world')"]
1516         result = CliRunner().invoke(black.main, args)
1517         self.compare_results(result, "", 1)
1518
1519     def test_code_option_diff(self) -> None:
1520         """Test the code option when diff is passed."""
1521         code = "print('hello world')"
1522         formatted = black.format_str(code, mode=DEFAULT_MODE)
1523         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1524
1525         args = ["--diff", "--code", code]
1526         result = CliRunner().invoke(black.main, args)
1527
1528         # Remove time from diff
1529         output = DIFF_TIME.sub("", result.output)
1530
1531         assert output == result_diff, "The output did not match the expected value."
1532         assert result.exit_code == 0, "The exit code is incorrect."
1533
1534     def test_code_option_color_diff(self) -> None:
1535         """Test the code option when color and diff are passed."""
1536         code = "print('hello world')"
1537         formatted = black.format_str(code, mode=DEFAULT_MODE)
1538
1539         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1540         result_diff = color_diff(result_diff)
1541
1542         args = ["--diff", "--color", "--code", code]
1543         result = CliRunner().invoke(black.main, args)
1544
1545         # Remove time from diff
1546         output = DIFF_TIME.sub("", result.output)
1547
1548         assert output == result_diff, "The output did not match the expected value."
1549         assert result.exit_code == 0, "The exit code is incorrect."
1550
1551     @pytest.mark.incompatible_with_mypyc
1552     def test_code_option_safe(self) -> None:
1553         """Test that the code option throws an error when the sanity checks fail."""
1554         # Patch black.assert_equivalent to ensure the sanity checks fail
1555         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1556             code = 'print("Hello world")'
1557             error_msg = f"{code}\nerror: cannot format <string>: \n"
1558
1559             args = ["--safe", "--code", code]
1560             result = CliRunner().invoke(black.main, args)
1561
1562             self.compare_results(result, error_msg, 123)
1563
1564     def test_code_option_fast(self) -> None:
1565         """Test that the code option ignores errors when the sanity checks fail."""
1566         # Patch black.assert_equivalent to ensure the sanity checks fail
1567         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1568             code = 'print("Hello world")'
1569             formatted = black.format_str(code, mode=DEFAULT_MODE)
1570
1571             args = ["--fast", "--code", code]
1572             result = CliRunner().invoke(black.main, args)
1573
1574             self.compare_results(result, formatted, 0)
1575
1576     @pytest.mark.incompatible_with_mypyc
1577     def test_code_option_config(self) -> None:
1578         """
1579         Test that the code option finds the pyproject.toml in the current directory.
1580         """
1581         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1582             args = ["--code", "print"]
1583             # This is the only directory known to contain a pyproject.toml
1584             with change_directory(PROJECT_ROOT):
1585                 CliRunner().invoke(black.main, args)
1586                 pyproject_path = Path(Path.cwd(), "pyproject.toml").resolve()
1587
1588             assert (
1589                 len(parse.mock_calls) >= 1
1590             ), "Expected config parse to be called with the current directory."
1591
1592             _, call_args, _ = parse.mock_calls[0]
1593             assert (
1594                 call_args[0].lower() == str(pyproject_path).lower()
1595             ), "Incorrect config loaded."
1596
1597     @pytest.mark.incompatible_with_mypyc
1598     def test_code_option_parent_config(self) -> None:
1599         """
1600         Test that the code option finds the pyproject.toml in the parent directory.
1601         """
1602         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1603             with change_directory(THIS_DIR):
1604                 args = ["--code", "print"]
1605                 CliRunner().invoke(black.main, args)
1606
1607                 pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
1608                 assert (
1609                     len(parse.mock_calls) >= 1
1610                 ), "Expected config parse to be called with the current directory."
1611
1612                 _, call_args, _ = parse.mock_calls[0]
1613                 assert (
1614                     call_args[0].lower() == str(pyproject_path).lower()
1615                 ), "Incorrect config loaded."
1616
1617     def test_for_handled_unexpected_eof_error(self) -> None:
1618         """
1619         Test that an unexpected EOF SyntaxError is nicely presented.
1620         """
1621         with pytest.raises(black.parsing.InvalidInput) as exc_info:
1622             black.lib2to3_parse("print(", {})
1623
1624         exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")
1625
1626     def test_equivalency_ast_parse_failure_includes_error(self) -> None:
1627         with pytest.raises(AssertionError) as err:
1628             black.assert_equivalent("a«»a  = 1", "a«»a  = 1")
1629
1630         err.match("--safe")
1631         # Unfortunately the SyntaxError message has changed in newer versions so we
1632         # can't match it directly.
1633         err.match("invalid character")
1634         err.match(r"\(<unknown>, line 1\)")
1635
1636
class TestCaching:
    """Tests for locating, reading, writing and bypassing black's cache."""

    def test_get_cache_dir(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """BLACK_CACHE_DIR wins when set; otherwise user_cache_dir is used."""
        # Create multiple cache directories
        default_dir = tmp_path / "ws1"
        default_dir.mkdir()
        override_dir = tmp_path / "ws2"
        override_dir.mkdir()

        # If BLACK_CACHE_DIR is not set, use user_cache_dir
        monkeypatch.delenv("BLACK_CACHE_DIR", raising=False)
        # Force user_cache_dir to use the temporary directory for easier assertions
        with patch(
            target="black.cache.user_cache_dir",
            autospec=True,
            return_value=str(default_dir),
        ):
            assert get_cache_dir() == default_dir

        # If it is set, use the path provided in the env var.
        monkeypatch.setenv("BLACK_CACHE_DIR", str(override_dir))
        assert get_cache_dir() == override_dir

    def test_cache_broken_file(self) -> None:
        """An unreadable cache file reads as empty and is replaced by a run."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            get_cache_file(mode).write_text("this is not a pickle")
            assert black.read_cache(mode) == {}

            target = (workspace / "test.py").resolve()
            target.write_text("print('hello')")
            invokeBlack([str(target)])

            assert str(target) in black.read_cache(mode)

    def test_cache_single_file_already_cached(self) -> None:
        """A file already recorded in the cache is left untouched."""
        mode = DEFAULT_MODE
        original_text = "print('hello')"
        with cache_dir() as workspace:
            target = (workspace / "test.py").resolve()
            target.write_text(original_text)
            black.write_cache({}, [target], mode)
            invokeBlack([str(target)])
            assert target.read_text() == original_text

    @event_loop()
    def test_cache_multiple_files(self) -> None:
        """Only the uncached file of a directory is reformatted; both end up cached."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace, patch(
            "concurrent.futures.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            one = (workspace / "one.py").resolve()
            one.write_text("print('hello')")
            two = (workspace / "two.py").resolve()
            two.write_text("print('hello')")

            # Only `one` is recorded as already formatted.
            black.write_cache({}, [one], mode)
            invokeBlack([str(workspace)])

            assert one.read_text() == "print('hello')"
            assert two.read_text() == 'print("hello")\n'

            cache = black.read_cache(mode)
            for formatted in (one, two):
                assert str(formatted) in cache

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    def test_no_cache_when_writeback_diff(self, color: bool) -> None:
        """`--diff` (with or without `--color`) never touches the cache."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            target = (workspace / "test.py").resolve()
            target.write_text("print('hello')")

            with patch("black.read_cache") as read_cache, patch(
                "black.write_cache"
            ) as write_cache:
                cmd = [str(target), "--diff"]
                if color:
                    cmd += ["--color"]
                invokeBlack(cmd)

                assert not get_cache_file(mode).exists()
                write_cache.assert_not_called()
                read_cache.assert_not_called()

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    @event_loop()
    def test_output_locking_when_writeback_diff(self, color: bool) -> None:
        """`--diff` over several files goes through `black.Manager`."""
        with cache_dir() as workspace:
            for tag in range(4):
                (workspace / f"test{tag}.py").resolve().write_text("print('hello')")

            with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
                cmd = ["--diff", str(workspace)]
                if color:
                    cmd += ["--color"]
                invokeBlack(cmd, exit_code=0)
                # this isn't quite doing what we want, but if it _isn't_
                # called then we cannot be using the lock it provides
                mgr.assert_called()

    def test_no_cache_when_stdin(self) -> None:
        """Formatting stdin succeeds and leaves no cache file behind."""
        mode = DEFAULT_MODE
        with cache_dir():
            stdin_payload = BytesIO(b"print('hello')")
            outcome = CliRunner().invoke(black.main, ["-"], input=stdin_payload)
            assert not outcome.exit_code
            assert not get_cache_file(mode).exists()

    def test_read_cache_no_cachefile(self) -> None:
        """Reading a cache that was never written yields an empty dict."""
        mode = DEFAULT_MODE
        with cache_dir():
            assert black.read_cache(mode) == {}

    def test_write_cache_read_cache(self) -> None:
        """A written entry is read back keyed by path with its cache info."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            target = (workspace / "test.py").resolve()
            target.touch()
            black.write_cache({}, [target], mode)

            cache = black.read_cache(mode)
            key = str(target)
            assert key in cache
            assert cache[key] == black.get_cache_info(target)

    def test_filter_cached(self) -> None:
        """`filter_cached` splits paths into (todo, done) by cache freshness."""
        with TemporaryDirectory() as workspace:
            base = Path(workspace)
            uncached, cached, cached_but_changed = (
                (base / name).resolve() for name in ("uncached", "cached", "changed")
            )
            for entry in (uncached, cached, cached_but_changed):
                entry.touch()

            cache = {
                str(cached): black.get_cache_info(cached),
                # Deliberately bogus info so this entry no longer matches.
                str(cached_but_changed): (0.0, 0),
            }
            todo, done = black.filter_cached(
                cache, {uncached, cached, cached_but_changed}
            )
            assert todo == {uncached, cached_but_changed}
            assert done == {cached}

    def test_write_cache_creates_directory_if_needed(self) -> None:
        """Writing the cache creates the cache directory when it is missing."""
        mode = DEFAULT_MODE
        with cache_dir(exists=False) as workspace:
            assert not workspace.exists()
            black.write_cache({}, [], mode)
            assert workspace.exists()

    @event_loop()
    def test_failed_formatting_does_not_get_cached(self) -> None:
        """A file that fails to format stays out of the cache; a clean one is in."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace, patch(
            "concurrent.futures.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            failing = (workspace / "failing.py").resolve()
            failing.write_text("not actually python")
            clean = (workspace / "clean.py").resolve()
            clean.write_text('print("hello")\n')

            invokeBlack([str(workspace)], exit_code=123)

            cache = black.read_cache(mode)
            assert str(failing) not in cache
            assert str(clean) in cache

    def test_write_cache_write_fail(self) -> None:
        """An OSError from `Path.open` while writing the cache is not raised."""
        mode = DEFAULT_MODE
        with cache_dir(), patch.object(Path, "open", side_effect=OSError):
            black.write_cache({}, [], mode)

    def test_read_cache_line_lengths(self) -> None:
        """Entries written for one mode are not visible under another line length."""
        mode = DEFAULT_MODE
        short_mode = replace(DEFAULT_MODE, line_length=1)
        with cache_dir() as workspace:
            target = (workspace / "file.py").resolve()
            target.touch()
            black.write_cache({}, [target], mode)

            assert str(target) in black.read_cache(mode)
            assert str(target) not in black.read_cache(short_mode)
1829
1830
def assert_collected_sources(
    src: Sequence[Union[str, Path]],
    expected: Sequence[Union[str, Path]],
    *,
    ctx: Optional[FakeContext] = None,
    exclude: Optional[str] = None,
    include: Optional[str] = None,
    extend_exclude: Optional[str] = None,
    force_exclude: Optional[str] = None,
    stdin_filename: Optional[str] = None,
) -> None:
    """Assert that `black.get_sources` collects exactly `expected` from `src`.

    Every pattern argument is a regex string compiled with `compile_pattern`
    when given. A missing `include` falls back to `DEFAULT_INCLUDE`; the other
    patterns are passed through as None. A falsy `ctx` is replaced by a fresh
    `FakeContext()`. The comparison ignores ordering.
    """

    def compiled_or(pattern: Optional[str], fallback: Any) -> Any:
        # Compile a provided pattern, otherwise hand back the fallback as-is.
        return fallback if pattern is None else compile_pattern(pattern)

    wanted = [Path(item) for item in expected]
    collected = black.get_sources(
        ctx=ctx or FakeContext(),
        src=tuple(str(Path(item)) for item in src),
        quiet=False,
        verbose=False,
        include=compiled_or(include, DEFAULT_INCLUDE),
        exclude=compiled_or(exclude, None),
        extend_exclude=compiled_or(extend_exclude, None),
        force_exclude=compiled_or(force_exclude, None),
        report=black.Report(),
        stdin_filename=stdin_filename,
    )
    assert sorted(collected) == sorted(wanted)
1863
1864
class TestFileCollection:
    """Tests for which files black collects from directories, files and stdin."""

    def test_include_exclude(self) -> None:
        """Explicit include/exclude regexes select only `b/dont_exclude`."""
        path = THIS_DIR / "data" / "include_exclude_tests"
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        assert_collected_sources(
            [path],
            expected,
            include=r"\.pyi?$",
            exclude=r"/exclude/|/\.definitely_exclude/",
        )

    def test_gitignore_used_as_default(self) -> None:
        """With no explicit exclude, only `.definitely_exclude` files remain."""
        base = Path(DATA_DIR / "include_exclude_tests")
        expected = [
            base / "b/.definitely_exclude/a.py",
            base / "b/.definitely_exclude/a.pyi",
        ]
        src = [base / "b/"]
        ctx = FakeContext()
        # The project root is supplied through the context object.
        ctx.obj["root"] = base
        assert_collected_sources(src, expected, ctx=ctx, extend_exclude=r"/exclude/")

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_exclude_for_issue_1572(self) -> None:
        """An explicitly passed file is collected even if `exclude` matches it."""
        # Exclude shouldn't touch files that were explicitly given to Black through the
        # CLI. Exclude is supposed to only apply to the recursive discovery of files.
        # https://github.com/psf/black/issues/1572
        path = DATA_DIR / "include_exclude_tests"
        src = [path / "b/exclude/a.py"]
        expected = [path / "b/exclude/a.py"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    def test_gitignore_exclude(self) -> None:
        """Paths matched by the given gitignore spec are not generated."""
        path = THIS_DIR / "data" / "include_exclude_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        report = black.Report()
        gitignore = PathSpec.from_lines(
            "gitwildmatch", ["exclude/", ".definitely_exclude"]
        )
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        this_abs = THIS_DIR.resolve()
        sources: List[Path] = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                gitignore,
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_nested_gitignore(self) -> None:
        """File generation over the nested-gitignore fixture yields the expected set."""
        path = Path(THIS_DIR / "data" / "nested_gitignore_tests")
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        root_gitignore = black.files.get_gitignore(path)
        report = black.Report()
        expected: List[Path] = [
            path / "x.py",
            path / "root/b.py",
            path / "root/c.py",
            path / "root/child/c.py",
        ]
        this_abs = THIS_DIR.resolve()
        sources = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                root_gitignore,
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_invalid_gitignore(self) -> None:
        """An unparsable top-level .gitignore exits with 1 and is named on stderr."""
        path = THIS_DIR / "data" / "invalid_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_invalid_nested_gitignore(self) -> None:
        """An unparsable nested .gitignore exits with 1 and is named on stderr."""
        path = THIS_DIR / "data" / "invalid_nested_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / "a" / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_empty_include(self) -> None:
        """Empty include and exclude patterns collect every file in the fixture."""
        path = DATA_DIR / "include_exclude_tests"
        expected = [
            path / "b/exclude/a.pie",
            path / "b/exclude/a.py",
            path / "b/exclude/a.pyi",
            path / "b/dont_exclude/a.pie",
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
            path / "b/.definitely_exclude/a.pie",
            path / "b/.definitely_exclude/a.py",
            path / "b/.definitely_exclude/a.pyi",
            path / ".gitignore",
            path / "pyproject.toml",
        ]
        # Setting exclude explicitly to an empty string to block .gitignore usage.
        assert_collected_sources([path], expected, include="", exclude="")

    def test_extend_exclude(self) -> None:
        """`extend_exclude` is applied in addition to `exclude`."""
        path = DATA_DIR / "include_exclude_tests"
        expected = [
            path / "b/exclude/a.py",
            path / "b/dont_exclude/a.py",
        ]
        assert_collected_sources(
            [path], expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude"
        )

    @pytest.mark.incompatible_with_mypyc
    def test_symlink_out_of_root_directory(self) -> None:
        """A child resolving outside of `root` must not raise ValueError."""
        path = MagicMock()
        root = THIS_DIR.resolve()
        child = MagicMock()
        include = re.compile(black.DEFAULT_INCLUDES)
        exclude = re.compile(black.DEFAULT_EXCLUDES)
        report = black.Report()
        gitignore = PathSpec.from_lines("gitwildmatch", [])
        # `child` should behave like a symlink which resolved path is clearly
        # outside of the `root` directory.
        path.iterdir.return_value = [child]
        child.resolve.return_value = Path("/a/b/c")
        child.as_posix.return_value = "/a/b/c"
        try:
            list(
                black.gen_python_files(
                    path.iterdir(),
                    root,
                    include,
                    exclude,
                    None,
                    None,
                    report,
                    gitignore,
                    verbose=False,
                    quiet=False,
                )
            )
        except ValueError as ve:
            # Name the function that is actually exercised above.
            pytest.fail(f"`gen_python_files()` failed: {ve}")
        path.iterdir.assert_called_once()
        child.resolve.assert_called_once()

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin(self) -> None:
        """A bare "-" source is collected as "-" regardless of `exclude`."""
        src = ["-"]
        expected = ["-"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename(self) -> None:
        """With `stdin_filename`, "-" is collected as a prefixed placeholder path."""
        stdin_filename = str(THIS_DIR / "data/collections.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            ["-"],
            expected,
            exclude=r"/exclude/a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
        """`exclude` does not drop a file supplied through `stdin_filename`."""
        # Exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        path = DATA_DIR / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            ["-"],
            expected,
            exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
        """`extend_exclude` does not drop a file supplied through `stdin_filename`."""
        # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            ["-"],
            expected,
            extend_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
        """`force_exclude` does drop a file supplied through `stdin_filename`."""
        # Force exclude should exclude the file when passing it through
        # stdin_filename
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        assert_collected_sources(
            src=["-"],
            expected=[],
            force_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )
2110
2111
# Load the source of the `black` module as a list of lines; `tracefunc` below
# looks function signatures up in it by line number.
try:
    with open(black.__file__, encoding="utf-8") as _bf:
        black_source_lines = list(_bf)
except UnicodeDecodeError:
    # A decode failure is tolerated only when `black.COMPILED` is truthy; in
    # that case `black_source_lines` is left undefined.
    if black.COMPILED:
        pass
    else:
        raise
2118
2119
def tracefunc(
    frame: types.FrameType, event: str, arg: Any
) -> Callable[[types.FrameType, str, Any], Any]:
    """Show function calls `from black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    Only "call" events whose code object lives in a file path containing
    ``black/__init__.py`` produce output: one line holding the frame's line
    number and the source text of the function signature, indented by the
    current stack depth. `arg` is unused. The function always returns itself
    so tracing stays active.
    """
    if event != "call":
        return tracefunc

    filename = frame.f_code.co_filename
    if "black/__init__.py" not in filename:
        # Line numbers of frames from other files must not be used to index
        # `black_source_lines` (they may be out of range), and there is no
        # point paying for `inspect.stack()` when nothing will be printed.
        return tracefunc

    # Indentation: two spaces per stack level beyond a fixed base depth of 19.
    stack = len(inspect.stack()) - 19
    stack *= 2
    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1  # f_lineno is 1-based, the list is 0-based
    funcname = black_source_lines[func_sig_lineno].strip()
    # Skip decorator lines until the actual signature line is reached.
    while funcname.startswith("@"):
        func_sig_lineno += 1
        funcname = black_source_lines[func_sig_lineno].strip()
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc