git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

5be4ae8533c249ff03829a0d35a6dcfdcb7f46ac
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import inspect
5 import io
6 import logging
7 import multiprocessing
8 import os
9 import sys
10 import types
11 import unittest
12 from concurrent.futures import ThreadPoolExecutor
13 from contextlib import contextmanager
14 from dataclasses import replace
15 from io import BytesIO
16 from pathlib import Path
17 from platform import system
18 from tempfile import TemporaryDirectory
19 from typing import (
20     Any,
21     Callable,
22     Dict,
23     Iterator,
24     List,
25     Optional,
26     Sequence,
27     TypeVar,
28     Union,
29 )
30 from unittest.mock import MagicMock, patch
31
32 import click
33 import pytest
34 import re
35 from click import unstyle
36 from click.testing import CliRunner
37 from pathspec import PathSpec
38
39 import black
40 import black.files
41 from black import Feature, TargetVersion
42 from black import re_compile_maybe_verbose as compile_pattern
43 from black.cache import get_cache_file
44 from black.debug import DebugVisitor
45 from black.output import color_diff, diff
46 from black.report import Report
47
48 # Import other test classes
49 from tests.util import (
50     DATA_DIR,
51     DEFAULT_MODE,
52     DETERMINISTIC_HEADER,
53     PROJECT_ROOT,
54     PY36_VERSIONS,
55     THIS_DIR,
56     BlackBaseTestCase,
57     assert_format,
58     change_directory,
59     dump_to_stderr,
60     ff,
61     fs,
62     read_data,
63 )
64
THIS_FILE = Path(__file__)
# One ``--target-version=<name>`` CLI argument per entry in PY36_VERSIONS.
PY36_ARGS = [f"--target-version={target.name.lower()}" for target in PY36_VERSIONS]
# Black's default include/exclude patterns, compiled at import time.
DEFAULT_EXCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_EXCLUDES)
DEFAULT_INCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_INCLUDES)
# Generic type variables.
T = TypeVar("T")
R = TypeVar("R")

# Match the time output in a diff (a tab followed by a run of digits, "-",
# ":", "+", "." and spaces), but nothing else.
DIFF_TIME = re.compile(r"\t[\d\-:+\. ]+")
74
75
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch ``black.cache.CACHE_DIR`` to a throwaway directory for the duration.

    The directory lives inside a ``TemporaryDirectory`` that is removed on
    exit. With ``exists=False`` the patched path is a ``new`` subdirectory of
    that workspace which is not created here.

    Yields:
        The path that ``black.cache.CACHE_DIR`` is patched to.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
84
85
@contextmanager
def event_loop() -> Iterator[None]:
    """Run the body with a fresh asyncio event loop installed as current.

    A new loop is created from the active event loop policy and set as the
    current loop for this thread. On exit the loop is closed and the current
    loop is reset to ``None``, so a closed loop is not left installed for
    later code that looks up the current event loop.
    """
    policy = asyncio.get_event_loop_policy()
    loop = policy.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        yield
    finally:
        try:
            loop.close()
        finally:
            # Don't leave the (now closed) loop installed as the current one.
            asyncio.set_event_loop(None)
96
97
class FakeContext(click.Context):
    """Minimal stand-in for a click Context, for functions that require one.

    ``click.Context.__init__`` is intentionally not invoked; the only state
    provided is an empty ``default_map``.
    """

    def __init__(self) -> None:
        defaults: Dict[str, Any] = {}
        self.default_map = defaults
103
104
class FakeParameter(click.Parameter):
    """Minimal stand-in for a click Parameter, for functions that require one."""

    def __init__(self) -> None:
        """Deliberately skip ``click.Parameter.__init__``; no attributes are set."""
110
111
class BlackRunner(CliRunner):
    """CliRunner preconfigured for testing Black's CLI.

    STDOUT and STDERR are captured separately rather than interleaved.
    """

    def __init__(self) -> None:
        # mix_stderr=False keeps stderr apart from stdout on invocation results.
        super().__init__(mix_stderr=False)
117
118
def invokeBlack(
    args: List[str], exit_code: int = 0, ignore_config: bool = True
) -> None:
    """Run Black's CLI entry point on *args* and assert on its exit code.

    When *ignore_config* is true, ``--verbose`` and ``--config`` pointing at
    ``THIS_DIR / "empty.toml"`` are prepended to *args*. Exceptions raised by
    Black propagate (``catch_exceptions=False``). If the exit code differs
    from *exit_code*, the assertion message lists the arguments, the captured
    stdout and stderr, and the result's exception.
    """
    cli_args = args
    if ignore_config:
        cli_args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
    outcome = BlackRunner().invoke(black.main, cli_args, catch_exceptions=False)
    assert outcome.stdout_bytes is not None
    assert outcome.stderr_bytes is not None
    details = "\n".join(
        [
            f"Failed with args: {cli_args}",
            f"stdout: {outcome.stdout_bytes.decode()!r}",
            f"stderr: {outcome.stderr_bytes.decode()!r}",
            f"exception: {outcome.exception}",
        ]
    )
    assert outcome.exit_code == exit_code, details
135
136
137 class BlackTestCase(BlackBaseTestCase):
    # Expose the module-level ``invokeBlack`` helper as a method so tests can
    # call it as ``self.invokeBlack(...)``.
    invokeBlack = staticmethod(invokeBlack)
139
140     def test_empty_ff(self) -> None:
141         expected = ""
142         tmp_file = Path(black.dump_to_file())
143         try:
144             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
145             with open(tmp_file, encoding="utf8") as f:
146                 actual = f.read()
147         finally:
148             os.unlink(tmp_file)
149         self.assertFormatEqual(expected, actual)
150
151     def test_piping(self) -> None:
152         source, expected = read_data("src/black/__init__", data=False)
153         result = BlackRunner().invoke(
154             black.main,
155             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
156             input=BytesIO(source.encode("utf8")),
157         )
158         self.assertEqual(result.exit_code, 0)
159         self.assertFormatEqual(expected, result.output)
160         if source != result.output:
161             black.assert_equivalent(source, result.output)
162             black.assert_stable(source, result.output, DEFAULT_MODE)
163
164     def test_piping_diff(self) -> None:
165         diff_header = re.compile(
166             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
167             r"\+\d\d\d\d"
168         )
169         source, _ = read_data("expression.py")
170         expected, _ = read_data("expression.diff")
171         config = THIS_DIR / "data" / "empty_pyproject.toml"
172         args = [
173             "-",
174             "--fast",
175             f"--line-length={black.DEFAULT_LINE_LENGTH}",
176             "--diff",
177             f"--config={config}",
178         ]
179         result = BlackRunner().invoke(
180             black.main, args, input=BytesIO(source.encode("utf8"))
181         )
182         self.assertEqual(result.exit_code, 0)
183         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
184         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
185         self.assertEqual(expected, actual)
186
187     def test_piping_diff_with_color(self) -> None:
188         source, _ = read_data("expression.py")
189         config = THIS_DIR / "data" / "empty_pyproject.toml"
190         args = [
191             "-",
192             "--fast",
193             f"--line-length={black.DEFAULT_LINE_LENGTH}",
194             "--diff",
195             "--color",
196             f"--config={config}",
197         ]
198         result = BlackRunner().invoke(
199             black.main, args, input=BytesIO(source.encode("utf8"))
200         )
201         actual = result.output
202         # Again, the contents are checked in a different test, so only look for colors.
203         self.assertIn("\033[1m", actual)
204         self.assertIn("\033[36m", actual)
205         self.assertIn("\033[32m", actual)
206         self.assertIn("\033[31m", actual)
207         self.assertIn("\033[0m", actual)
208
209     @patch("black.dump_to_file", dump_to_stderr)
210     def _test_wip(self) -> None:
211         source, expected = read_data("wip")
212         sys.settrace(tracefunc)
213         mode = replace(
214             DEFAULT_MODE,
215             experimental_string_processing=False,
216             target_versions={black.TargetVersion.PY38},
217         )
218         actual = fs(source, mode=mode)
219         sys.settrace(None)
220         self.assertFormatEqual(expected, actual)
221         black.assert_equivalent(source, actual)
222         black.assert_stable(source, actual, black.FileMode())
223
224     @unittest.expectedFailure
225     @patch("black.dump_to_file", dump_to_stderr)
226     def test_trailing_comma_optional_parens_stability1(self) -> None:
227         source, _expected = read_data("trailing_comma_optional_parens1")
228         actual = fs(source)
229         black.assert_stable(source, actual, DEFAULT_MODE)
230
231     @unittest.expectedFailure
232     @patch("black.dump_to_file", dump_to_stderr)
233     def test_trailing_comma_optional_parens_stability2(self) -> None:
234         source, _expected = read_data("trailing_comma_optional_parens2")
235         actual = fs(source)
236         black.assert_stable(source, actual, DEFAULT_MODE)
237
238     @unittest.expectedFailure
239     @patch("black.dump_to_file", dump_to_stderr)
240     def test_trailing_comma_optional_parens_stability3(self) -> None:
241         source, _expected = read_data("trailing_comma_optional_parens3")
242         actual = fs(source)
243         black.assert_stable(source, actual, DEFAULT_MODE)
244
245     @patch("black.dump_to_file", dump_to_stderr)
246     def test_trailing_comma_optional_parens_stability1_pass2(self) -> None:
247         source, _expected = read_data("trailing_comma_optional_parens1")
248         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
249         black.assert_stable(source, actual, DEFAULT_MODE)
250
251     @patch("black.dump_to_file", dump_to_stderr)
252     def test_trailing_comma_optional_parens_stability2_pass2(self) -> None:
253         source, _expected = read_data("trailing_comma_optional_parens2")
254         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
255         black.assert_stable(source, actual, DEFAULT_MODE)
256
257     @patch("black.dump_to_file", dump_to_stderr)
258     def test_trailing_comma_optional_parens_stability3_pass2(self) -> None:
259         source, _expected = read_data("trailing_comma_optional_parens3")
260         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
261         black.assert_stable(source, actual, DEFAULT_MODE)
262
263     def test_pep_572_version_detection(self) -> None:
264         source, _ = read_data("pep_572")
265         root = black.lib2to3_parse(source)
266         features = black.get_features_used(root)
267         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
268         versions = black.detect_target_versions(root)
269         self.assertIn(black.TargetVersion.PY38, versions)
270
271     def test_expression_ff(self) -> None:
272         source, expected = read_data("expression")
273         tmp_file = Path(black.dump_to_file(source))
274         try:
275             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
276             with open(tmp_file, encoding="utf8") as f:
277                 actual = f.read()
278         finally:
279             os.unlink(tmp_file)
280         self.assertFormatEqual(expected, actual)
281         with patch("black.dump_to_file", dump_to_stderr):
282             black.assert_equivalent(source, actual)
283             black.assert_stable(source, actual, DEFAULT_MODE)
284
285     def test_expression_diff(self) -> None:
286         source, _ = read_data("expression.py")
287         config = THIS_DIR / "data" / "empty_pyproject.toml"
288         expected, _ = read_data("expression.diff")
289         tmp_file = Path(black.dump_to_file(source))
290         diff_header = re.compile(
291             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
292             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
293         )
294         try:
295             result = BlackRunner().invoke(
296                 black.main, ["--diff", str(tmp_file), f"--config={config}"]
297             )
298             self.assertEqual(result.exit_code, 0)
299         finally:
300             os.unlink(tmp_file)
301         actual = result.output
302         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
303         if expected != actual:
304             dump = black.dump_to_file(actual)
305             msg = (
306                 "Expected diff isn't equal to the actual. If you made changes to"
307                 " expression.py and this is an anticipated difference, overwrite"
308                 f" tests/data/expression.diff with {dump}"
309             )
310             self.assertEqual(expected, actual, msg)
311
312     def test_expression_diff_with_color(self) -> None:
313         source, _ = read_data("expression.py")
314         config = THIS_DIR / "data" / "empty_pyproject.toml"
315         expected, _ = read_data("expression.diff")
316         tmp_file = Path(black.dump_to_file(source))
317         try:
318             result = BlackRunner().invoke(
319                 black.main, ["--diff", "--color", str(tmp_file), f"--config={config}"]
320             )
321         finally:
322             os.unlink(tmp_file)
323         actual = result.output
324         # We check the contents of the diff in `test_expression_diff`. All
325         # we need to check here is that color codes exist in the result.
326         self.assertIn("\033[1m", actual)
327         self.assertIn("\033[36m", actual)
328         self.assertIn("\033[32m", actual)
329         self.assertIn("\033[31m", actual)
330         self.assertIn("\033[0m", actual)
331
332     def test_detect_pos_only_arguments(self) -> None:
333         source, _ = read_data("pep_570")
334         root = black.lib2to3_parse(source)
335         features = black.get_features_used(root)
336         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
337         versions = black.detect_target_versions(root)
338         self.assertIn(black.TargetVersion.PY38, versions)
339
340     @patch("black.dump_to_file", dump_to_stderr)
341     def test_string_quotes(self) -> None:
342         source, expected = read_data("string_quotes")
343         mode = black.Mode(experimental_string_processing=True)
344         assert_format(source, expected, mode)
345         mode = replace(mode, string_normalization=False)
346         not_normalized = fs(source, mode=mode)
347         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
348         black.assert_equivalent(source, not_normalized)
349         black.assert_stable(source, not_normalized, mode=mode)
350
351     def test_skip_magic_trailing_comma(self) -> None:
352         source, _ = read_data("expression.py")
353         expected, _ = read_data("expression_skip_magic_trailing_comma.diff")
354         tmp_file = Path(black.dump_to_file(source))
355         diff_header = re.compile(
356             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
357             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
358         )
359         try:
360             result = BlackRunner().invoke(black.main, ["-C", "--diff", str(tmp_file)])
361             self.assertEqual(result.exit_code, 0)
362         finally:
363             os.unlink(tmp_file)
364         actual = result.output
365         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
366         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
367         if expected != actual:
368             dump = black.dump_to_file(actual)
369             msg = (
370                 "Expected diff isn't equal to the actual. If you made changes to"
371                 " expression.py and this is an anticipated difference, overwrite"
372                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
373             )
374             self.assertEqual(expected, actual, msg)
375
376     @patch("black.dump_to_file", dump_to_stderr)
377     def test_async_as_identifier(self) -> None:
378         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
379         source, expected = read_data("async_as_identifier")
380         actual = fs(source)
381         self.assertFormatEqual(expected, actual)
382         major, minor = sys.version_info[:2]
383         if major < 3 or (major <= 3 and minor < 7):
384             black.assert_equivalent(source, actual)
385         black.assert_stable(source, actual, DEFAULT_MODE)
386         # ensure black can parse this when the target is 3.6
387         self.invokeBlack([str(source_path), "--target-version", "py36"])
388         # but not on 3.7, because async/await is no longer an identifier
389         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
390
391     @patch("black.dump_to_file", dump_to_stderr)
392     def test_python37(self) -> None:
393         source_path = (THIS_DIR / "data" / "python37.py").resolve()
394         source, expected = read_data("python37")
395         actual = fs(source)
396         self.assertFormatEqual(expected, actual)
397         major, minor = sys.version_info[:2]
398         if major > 3 or (major == 3 and minor >= 7):
399             black.assert_equivalent(source, actual)
400         black.assert_stable(source, actual, DEFAULT_MODE)
401         # ensure black can parse this when the target is 3.7
402         self.invokeBlack([str(source_path), "--target-version", "py37"])
403         # but not on 3.6, because we use async as a reserved keyword
404         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
405
406     def test_tab_comment_indentation(self) -> None:
407         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
408         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
409         self.assertFormatEqual(contents_spc, fs(contents_spc))
410         self.assertFormatEqual(contents_spc, fs(contents_tab))
411
412         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
413         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
414         self.assertFormatEqual(contents_spc, fs(contents_spc))
415         self.assertFormatEqual(contents_spc, fs(contents_tab))
416
417         # mixed tabs and spaces (valid Python 2 code)
418         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
419         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
420         self.assertFormatEqual(contents_spc, fs(contents_spc))
421         self.assertFormatEqual(contents_spc, fs(contents_tab))
422
423         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
424         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
425         self.assertFormatEqual(contents_spc, fs(contents_spc))
426         self.assertFormatEqual(contents_spc, fs(contents_tab))
427
428     def test_report_verbose(self) -> None:
429         report = Report(verbose=True)
430         out_lines = []
431         err_lines = []
432
433         def out(msg: str, **kwargs: Any) -> None:
434             out_lines.append(msg)
435
436         def err(msg: str, **kwargs: Any) -> None:
437             err_lines.append(msg)
438
439         with patch("black.output._out", out), patch("black.output._err", err):
440             report.done(Path("f1"), black.Changed.NO)
441             self.assertEqual(len(out_lines), 1)
442             self.assertEqual(len(err_lines), 0)
443             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
444             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
445             self.assertEqual(report.return_code, 0)
446             report.done(Path("f2"), black.Changed.YES)
447             self.assertEqual(len(out_lines), 2)
448             self.assertEqual(len(err_lines), 0)
449             self.assertEqual(out_lines[-1], "reformatted f2")
450             self.assertEqual(
451                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
452             )
453             report.done(Path("f3"), black.Changed.CACHED)
454             self.assertEqual(len(out_lines), 3)
455             self.assertEqual(len(err_lines), 0)
456             self.assertEqual(
457                 out_lines[-1], "f3 wasn't modified on disk since last run."
458             )
459             self.assertEqual(
460                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
461             )
462             self.assertEqual(report.return_code, 0)
463             report.check = True
464             self.assertEqual(report.return_code, 1)
465             report.check = False
466             report.failed(Path("e1"), "boom")
467             self.assertEqual(len(out_lines), 3)
468             self.assertEqual(len(err_lines), 1)
469             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
470             self.assertEqual(
471                 unstyle(str(report)),
472                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
473                 " reformat.",
474             )
475             self.assertEqual(report.return_code, 123)
476             report.done(Path("f3"), black.Changed.YES)
477             self.assertEqual(len(out_lines), 4)
478             self.assertEqual(len(err_lines), 1)
479             self.assertEqual(out_lines[-1], "reformatted f3")
480             self.assertEqual(
481                 unstyle(str(report)),
482                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
483                 " reformat.",
484             )
485             self.assertEqual(report.return_code, 123)
486             report.failed(Path("e2"), "boom")
487             self.assertEqual(len(out_lines), 4)
488             self.assertEqual(len(err_lines), 2)
489             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
490             self.assertEqual(
491                 unstyle(str(report)),
492                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
493                 " reformat.",
494             )
495             self.assertEqual(report.return_code, 123)
496             report.path_ignored(Path("wat"), "no match")
497             self.assertEqual(len(out_lines), 5)
498             self.assertEqual(len(err_lines), 2)
499             self.assertEqual(out_lines[-1], "wat ignored: no match")
500             self.assertEqual(
501                 unstyle(str(report)),
502                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
503                 " reformat.",
504             )
505             self.assertEqual(report.return_code, 123)
506             report.done(Path("f4"), black.Changed.NO)
507             self.assertEqual(len(out_lines), 6)
508             self.assertEqual(len(err_lines), 2)
509             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
510             self.assertEqual(
511                 unstyle(str(report)),
512                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
513                 " reformat.",
514             )
515             self.assertEqual(report.return_code, 123)
516             report.check = True
517             self.assertEqual(
518                 unstyle(str(report)),
519                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
520                 " would fail to reformat.",
521             )
522             report.check = False
523             report.diff = True
524             self.assertEqual(
525                 unstyle(str(report)),
526                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
527                 " would fail to reformat.",
528             )
529
530     def test_report_quiet(self) -> None:
531         report = Report(quiet=True)
532         out_lines = []
533         err_lines = []
534
535         def out(msg: str, **kwargs: Any) -> None:
536             out_lines.append(msg)
537
538         def err(msg: str, **kwargs: Any) -> None:
539             err_lines.append(msg)
540
541         with patch("black.output._out", out), patch("black.output._err", err):
542             report.done(Path("f1"), black.Changed.NO)
543             self.assertEqual(len(out_lines), 0)
544             self.assertEqual(len(err_lines), 0)
545             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
546             self.assertEqual(report.return_code, 0)
547             report.done(Path("f2"), black.Changed.YES)
548             self.assertEqual(len(out_lines), 0)
549             self.assertEqual(len(err_lines), 0)
550             self.assertEqual(
551                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
552             )
553             report.done(Path("f3"), black.Changed.CACHED)
554             self.assertEqual(len(out_lines), 0)
555             self.assertEqual(len(err_lines), 0)
556             self.assertEqual(
557                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
558             )
559             self.assertEqual(report.return_code, 0)
560             report.check = True
561             self.assertEqual(report.return_code, 1)
562             report.check = False
563             report.failed(Path("e1"), "boom")
564             self.assertEqual(len(out_lines), 0)
565             self.assertEqual(len(err_lines), 1)
566             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
567             self.assertEqual(
568                 unstyle(str(report)),
569                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
570                 " reformat.",
571             )
572             self.assertEqual(report.return_code, 123)
573             report.done(Path("f3"), black.Changed.YES)
574             self.assertEqual(len(out_lines), 0)
575             self.assertEqual(len(err_lines), 1)
576             self.assertEqual(
577                 unstyle(str(report)),
578                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
579                 " reformat.",
580             )
581             self.assertEqual(report.return_code, 123)
582             report.failed(Path("e2"), "boom")
583             self.assertEqual(len(out_lines), 0)
584             self.assertEqual(len(err_lines), 2)
585             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
586             self.assertEqual(
587                 unstyle(str(report)),
588                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
589                 " reformat.",
590             )
591             self.assertEqual(report.return_code, 123)
592             report.path_ignored(Path("wat"), "no match")
593             self.assertEqual(len(out_lines), 0)
594             self.assertEqual(len(err_lines), 2)
595             self.assertEqual(
596                 unstyle(str(report)),
597                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
598                 " reformat.",
599             )
600             self.assertEqual(report.return_code, 123)
601             report.done(Path("f4"), black.Changed.NO)
602             self.assertEqual(len(out_lines), 0)
603             self.assertEqual(len(err_lines), 2)
604             self.assertEqual(
605                 unstyle(str(report)),
606                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
607                 " reformat.",
608             )
609             self.assertEqual(report.return_code, 123)
610             report.check = True
611             self.assertEqual(
612                 unstyle(str(report)),
613                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
614                 " would fail to reformat.",
615             )
616             report.check = False
617             report.diff = True
618             self.assertEqual(
619                 unstyle(str(report)),
620                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
621                 " would fail to reformat.",
622             )
623
624     def test_report_normal(self) -> None:
625         report = black.Report()
626         out_lines = []
627         err_lines = []
628
629         def out(msg: str, **kwargs: Any) -> None:
630             out_lines.append(msg)
631
632         def err(msg: str, **kwargs: Any) -> None:
633             err_lines.append(msg)
634
635         with patch("black.output._out", out), patch("black.output._err", err):
636             report.done(Path("f1"), black.Changed.NO)
637             self.assertEqual(len(out_lines), 0)
638             self.assertEqual(len(err_lines), 0)
639             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
640             self.assertEqual(report.return_code, 0)
641             report.done(Path("f2"), black.Changed.YES)
642             self.assertEqual(len(out_lines), 1)
643             self.assertEqual(len(err_lines), 0)
644             self.assertEqual(out_lines[-1], "reformatted f2")
645             self.assertEqual(
646                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
647             )
648             report.done(Path("f3"), black.Changed.CACHED)
649             self.assertEqual(len(out_lines), 1)
650             self.assertEqual(len(err_lines), 0)
651             self.assertEqual(out_lines[-1], "reformatted f2")
652             self.assertEqual(
653                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
654             )
655             self.assertEqual(report.return_code, 0)
656             report.check = True
657             self.assertEqual(report.return_code, 1)
658             report.check = False
659             report.failed(Path("e1"), "boom")
660             self.assertEqual(len(out_lines), 1)
661             self.assertEqual(len(err_lines), 1)
662             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
663             self.assertEqual(
664                 unstyle(str(report)),
665                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
666                 " reformat.",
667             )
668             self.assertEqual(report.return_code, 123)
669             report.done(Path("f3"), black.Changed.YES)
670             self.assertEqual(len(out_lines), 2)
671             self.assertEqual(len(err_lines), 1)
672             self.assertEqual(out_lines[-1], "reformatted f3")
673             self.assertEqual(
674                 unstyle(str(report)),
675                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
676                 " reformat.",
677             )
678             self.assertEqual(report.return_code, 123)
679             report.failed(Path("e2"), "boom")
680             self.assertEqual(len(out_lines), 2)
681             self.assertEqual(len(err_lines), 2)
682             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
683             self.assertEqual(
684                 unstyle(str(report)),
685                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
686                 " reformat.",
687             )
688             self.assertEqual(report.return_code, 123)
689             report.path_ignored(Path("wat"), "no match")
690             self.assertEqual(len(out_lines), 2)
691             self.assertEqual(len(err_lines), 2)
692             self.assertEqual(
693                 unstyle(str(report)),
694                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
695                 " reformat.",
696             )
697             self.assertEqual(report.return_code, 123)
698             report.done(Path("f4"), black.Changed.NO)
699             self.assertEqual(len(out_lines), 2)
700             self.assertEqual(len(err_lines), 2)
701             self.assertEqual(
702                 unstyle(str(report)),
703                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
704                 " reformat.",
705             )
706             self.assertEqual(report.return_code, 123)
707             report.check = True
708             self.assertEqual(
709                 unstyle(str(report)),
710                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
711                 " would fail to reformat.",
712             )
713             report.check = False
714             report.diff = True
715             self.assertEqual(
716                 unstyle(str(report)),
717                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
718                 " would fail to reformat.",
719             )
720
721     def test_lib2to3_parse(self) -> None:
722         with self.assertRaises(black.InvalidInput):
723             black.lib2to3_parse("invalid syntax")
724
725         straddling = "x + y"
726         black.lib2to3_parse(straddling)
727         black.lib2to3_parse(straddling, {TargetVersion.PY36})
728
729         py2_only = "print x"
730         with self.assertRaises(black.InvalidInput):
731             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
732
733         py3_only = "exec(x, end=y)"
734         black.lib2to3_parse(py3_only)
735         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
736
737     def test_get_features_used_decorator(self) -> None:
738         # Test the feature detection of new decorator syntax
739         # since this makes some test cases of test_get_features_used()
740         # fails if it fails, this is tested first so that a useful case
741         # is identified
742         simples, relaxed = read_data("decorators")
743         # skip explanation comments at the top of the file
744         for simple_test in simples.split("##")[1:]:
745             node = black.lib2to3_parse(simple_test)
746             decorator = str(node.children[0].children[0]).strip()
747             self.assertNotIn(
748                 Feature.RELAXED_DECORATORS,
749                 black.get_features_used(node),
750                 msg=(
751                     f"decorator '{decorator}' follows python<=3.8 syntax"
752                     "but is detected as 3.9+"
753                     # f"The full node is\n{node!r}"
754                 ),
755             )
756         # skip the '# output' comment at the top of the output part
757         for relaxed_test in relaxed.split("##")[1:]:
758             node = black.lib2to3_parse(relaxed_test)
759             decorator = str(node.children[0].children[0]).strip()
760             self.assertIn(
761                 Feature.RELAXED_DECORATORS,
762                 black.get_features_used(node),
763                 msg=(
764                     f"decorator '{decorator}' uses python3.9+ syntax"
765                     "but is detected as python<=3.8"
766                     # f"The full node is\n{node!r}"
767                 ),
768             )
769
770     def test_get_features_used(self) -> None:
771         node = black.lib2to3_parse("def f(*, arg): ...\n")
772         self.assertEqual(black.get_features_used(node), set())
773         node = black.lib2to3_parse("def f(*, arg,): ...\n")
774         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
775         node = black.lib2to3_parse("f(*arg,)\n")
776         self.assertEqual(
777             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
778         )
779         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
780         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
781         node = black.lib2to3_parse("123_456\n")
782         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
783         node = black.lib2to3_parse("123456\n")
784         self.assertEqual(black.get_features_used(node), set())
785         source, expected = read_data("function")
786         node = black.lib2to3_parse(source)
787         expected_features = {
788             Feature.TRAILING_COMMA_IN_CALL,
789             Feature.TRAILING_COMMA_IN_DEF,
790             Feature.F_STRINGS,
791         }
792         self.assertEqual(black.get_features_used(node), expected_features)
793         node = black.lib2to3_parse(expected)
794         self.assertEqual(black.get_features_used(node), expected_features)
795         source, expected = read_data("expression")
796         node = black.lib2to3_parse(source)
797         self.assertEqual(black.get_features_used(node), set())
798         node = black.lib2to3_parse(expected)
799         self.assertEqual(black.get_features_used(node), set())
800         node = black.lib2to3_parse("lambda a, /, b: ...")
801         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
802         node = black.lib2to3_parse("def fn(a, /, b): ...")
803         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
804         node = black.lib2to3_parse("def fn(): yield a, b")
805         self.assertEqual(black.get_features_used(node), set())
806         node = black.lib2to3_parse("def fn(): return a, b")
807         self.assertEqual(black.get_features_used(node), set())
808         node = black.lib2to3_parse("def fn(): yield *b, c")
809         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
810         node = black.lib2to3_parse("def fn(): return a, *b, c")
811         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
812         node = black.lib2to3_parse("x = a, *b, c")
813         self.assertEqual(black.get_features_used(node), set())
814         node = black.lib2to3_parse("x: Any = regular")
815         self.assertEqual(black.get_features_used(node), set())
816         node = black.lib2to3_parse("x: Any = (regular, regular)")
817         self.assertEqual(black.get_features_used(node), set())
818         node = black.lib2to3_parse("x: Any = Complex(Type(1))[something]")
819         self.assertEqual(black.get_features_used(node), set())
820         node = black.lib2to3_parse("x: Tuple[int, ...] = a, b, c")
821         self.assertEqual(
822             black.get_features_used(node), {Feature.ANN_ASSIGN_EXTENDED_RHS}
823         )
824
825     def test_get_features_used_for_future_flags(self) -> None:
826         for src, features in [
827             ("from __future__ import annotations", {Feature.FUTURE_ANNOTATIONS}),
828             (
829                 "from __future__ import (other, annotations)",
830                 {Feature.FUTURE_ANNOTATIONS},
831             ),
832             ("a = 1 + 2\nfrom something import annotations", set()),
833             ("from __future__ import x, y", set()),
834         ]:
835             with self.subTest(src=src, features=features):
836                 node = black.lib2to3_parse(src)
837                 future_imports = black.get_future_imports(node)
838                 self.assertEqual(
839                     black.get_features_used(node, future_imports=future_imports),
840                     features,
841                 )
842
843     def test_get_future_imports(self) -> None:
844         node = black.lib2to3_parse("\n")
845         self.assertEqual(set(), black.get_future_imports(node))
846         node = black.lib2to3_parse("from __future__ import black\n")
847         self.assertEqual({"black"}, black.get_future_imports(node))
848         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
849         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
850         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
851         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
852         node = black.lib2to3_parse(
853             "from __future__ import multiple\nfrom __future__ import imports\n"
854         )
855         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
856         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
857         self.assertEqual({"black"}, black.get_future_imports(node))
858         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
859         self.assertEqual({"black"}, black.get_future_imports(node))
860         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
861         self.assertEqual(set(), black.get_future_imports(node))
862         node = black.lib2to3_parse("from some.module import black\n")
863         self.assertEqual(set(), black.get_future_imports(node))
864         node = black.lib2to3_parse(
865             "from __future__ import unicode_literals as _unicode_literals"
866         )
867         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
868         node = black.lib2to3_parse(
869             "from __future__ import unicode_literals as _lol, print"
870         )
871         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
872
873     @pytest.mark.incompatible_with_mypyc
874     def test_debug_visitor(self) -> None:
875         source, _ = read_data("debug_visitor.py")
876         expected, _ = read_data("debug_visitor.out")
877         out_lines = []
878         err_lines = []
879
880         def out(msg: str, **kwargs: Any) -> None:
881             out_lines.append(msg)
882
883         def err(msg: str, **kwargs: Any) -> None:
884             err_lines.append(msg)
885
886         with patch("black.debug.out", out):
887             DebugVisitor.show(source)
888         actual = "\n".join(out_lines) + "\n"
889         log_name = ""
890         if expected != actual:
891             log_name = black.dump_to_file(*out_lines)
892         self.assertEqual(
893             expected,
894             actual,
895             f"AST print out is different. Actual version dumped to {log_name}",
896         )
897
898     def test_format_file_contents(self) -> None:
899         empty = ""
900         mode = DEFAULT_MODE
901         with self.assertRaises(black.NothingChanged):
902             black.format_file_contents(empty, mode=mode, fast=False)
903         just_nl = "\n"
904         with self.assertRaises(black.NothingChanged):
905             black.format_file_contents(just_nl, mode=mode, fast=False)
906         same = "j = [1, 2, 3]\n"
907         with self.assertRaises(black.NothingChanged):
908             black.format_file_contents(same, mode=mode, fast=False)
909         different = "j = [1,2,3]"
910         expected = same
911         actual = black.format_file_contents(different, mode=mode, fast=False)
912         self.assertEqual(expected, actual)
913         invalid = "return if you can"
914         with self.assertRaises(black.InvalidInput) as e:
915             black.format_file_contents(invalid, mode=mode, fast=False)
916         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
917
918     def test_endmarker(self) -> None:
919         n = black.lib2to3_parse("\n")
920         self.assertEqual(n.type, black.syms.file_input)
921         self.assertEqual(len(n.children), 1)
922         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
923
924     @pytest.mark.incompatible_with_mypyc
925     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
926     def test_assertFormatEqual(self) -> None:
927         out_lines = []
928         err_lines = []
929
930         def out(msg: str, **kwargs: Any) -> None:
931             out_lines.append(msg)
932
933         def err(msg: str, **kwargs: Any) -> None:
934             err_lines.append(msg)
935
936         with patch("black.output._out", out), patch("black.output._err", err):
937             with self.assertRaises(AssertionError):
938                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
939
940         out_str = "".join(out_lines)
941         self.assertTrue("Expected tree:" in out_str)
942         self.assertTrue("Actual tree:" in out_str)
943         self.assertEqual("".join(err_lines), "")
944
945     @event_loop()
946     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
947     def test_works_in_mono_process_only_environment(self) -> None:
948         with cache_dir() as workspace:
949             for f in [
950                 (workspace / "one.py").resolve(),
951                 (workspace / "two.py").resolve(),
952             ]:
953                 f.write_text('print("hello")\n')
954             self.invokeBlack([str(workspace)])
955
956     @event_loop()
957     def test_check_diff_use_together(self) -> None:
958         with cache_dir():
959             # Files which will be reformatted.
960             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
961             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
962             # Files which will not be reformatted.
963             src2 = (THIS_DIR / "data" / "composition.py").resolve()
964             self.invokeBlack([str(src2), "--diff", "--check"])
965             # Multi file command.
966             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
967
968     def test_no_files(self) -> None:
969         with cache_dir():
970             # Without an argument, black exits with error code 0.
971             self.invokeBlack([])
972
973     def test_broken_symlink(self) -> None:
974         with cache_dir() as workspace:
975             symlink = workspace / "broken_link.py"
976             try:
977                 symlink.symlink_to("nonexistent.py")
978             except (OSError, NotImplementedError) as e:
979                 self.skipTest(f"Can't create symlinks: {e}")
980             self.invokeBlack([str(workspace.resolve())])
981
982     def test_single_file_force_pyi(self) -> None:
983         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
984         contents, expected = read_data("force_pyi")
985         with cache_dir() as workspace:
986             path = (workspace / "file.py").resolve()
987             with open(path, "w") as fh:
988                 fh.write(contents)
989             self.invokeBlack([str(path), "--pyi"])
990             with open(path, "r") as fh:
991                 actual = fh.read()
992             # verify cache with --pyi is separate
993             pyi_cache = black.read_cache(pyi_mode)
994             self.assertIn(str(path), pyi_cache)
995             normal_cache = black.read_cache(DEFAULT_MODE)
996             self.assertNotIn(str(path), normal_cache)
997         self.assertFormatEqual(expected, actual)
998         black.assert_equivalent(contents, actual)
999         black.assert_stable(contents, actual, pyi_mode)
1000
1001     @event_loop()
1002     def test_multi_file_force_pyi(self) -> None:
1003         reg_mode = DEFAULT_MODE
1004         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1005         contents, expected = read_data("force_pyi")
1006         with cache_dir() as workspace:
1007             paths = [
1008                 (workspace / "file1.py").resolve(),
1009                 (workspace / "file2.py").resolve(),
1010             ]
1011             for path in paths:
1012                 with open(path, "w") as fh:
1013                     fh.write(contents)
1014             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1015             for path in paths:
1016                 with open(path, "r") as fh:
1017                     actual = fh.read()
1018                 self.assertEqual(actual, expected)
1019             # verify cache with --pyi is separate
1020             pyi_cache = black.read_cache(pyi_mode)
1021             normal_cache = black.read_cache(reg_mode)
1022             for path in paths:
1023                 self.assertIn(str(path), pyi_cache)
1024                 self.assertNotIn(str(path), normal_cache)
1025
1026     def test_pipe_force_pyi(self) -> None:
1027         source, expected = read_data("force_pyi")
1028         result = CliRunner().invoke(
1029             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1030         )
1031         self.assertEqual(result.exit_code, 0)
1032         actual = result.output
1033         self.assertFormatEqual(actual, expected)
1034
1035     def test_single_file_force_py36(self) -> None:
1036         reg_mode = DEFAULT_MODE
1037         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1038         source, expected = read_data("force_py36")
1039         with cache_dir() as workspace:
1040             path = (workspace / "file.py").resolve()
1041             with open(path, "w") as fh:
1042                 fh.write(source)
1043             self.invokeBlack([str(path), *PY36_ARGS])
1044             with open(path, "r") as fh:
1045                 actual = fh.read()
1046             # verify cache with --target-version is separate
1047             py36_cache = black.read_cache(py36_mode)
1048             self.assertIn(str(path), py36_cache)
1049             normal_cache = black.read_cache(reg_mode)
1050             self.assertNotIn(str(path), normal_cache)
1051         self.assertEqual(actual, expected)
1052
1053     @event_loop()
1054     def test_multi_file_force_py36(self) -> None:
1055         reg_mode = DEFAULT_MODE
1056         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1057         source, expected = read_data("force_py36")
1058         with cache_dir() as workspace:
1059             paths = [
1060                 (workspace / "file1.py").resolve(),
1061                 (workspace / "file2.py").resolve(),
1062             ]
1063             for path in paths:
1064                 with open(path, "w") as fh:
1065                     fh.write(source)
1066             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1067             for path in paths:
1068                 with open(path, "r") as fh:
1069                     actual = fh.read()
1070                 self.assertEqual(actual, expected)
1071             # verify cache with --target-version is separate
1072             pyi_cache = black.read_cache(py36_mode)
1073             normal_cache = black.read_cache(reg_mode)
1074             for path in paths:
1075                 self.assertIn(str(path), pyi_cache)
1076                 self.assertNotIn(str(path), normal_cache)
1077
1078     def test_pipe_force_py36(self) -> None:
1079         source, expected = read_data("force_py36")
1080         result = CliRunner().invoke(
1081             black.main,
1082             ["-", "-q", "--target-version=py36"],
1083             input=BytesIO(source.encode("utf8")),
1084         )
1085         self.assertEqual(result.exit_code, 0)
1086         actual = result.output
1087         self.assertFormatEqual(actual, expected)
1088
1089     @pytest.mark.incompatible_with_mypyc
1090     def test_reformat_one_with_stdin(self) -> None:
1091         with patch(
1092             "black.format_stdin_to_stdout",
1093             return_value=lambda *args, **kwargs: black.Changed.YES,
1094         ) as fsts:
1095             report = MagicMock()
1096             path = Path("-")
1097             black.reformat_one(
1098                 path,
1099                 fast=True,
1100                 write_back=black.WriteBack.YES,
1101                 mode=DEFAULT_MODE,
1102                 report=report,
1103             )
1104             fsts.assert_called_once()
1105             report.done.assert_called_with(path, black.Changed.YES)
1106
1107     @pytest.mark.incompatible_with_mypyc
1108     def test_reformat_one_with_stdin_filename(self) -> None:
1109         with patch(
1110             "black.format_stdin_to_stdout",
1111             return_value=lambda *args, **kwargs: black.Changed.YES,
1112         ) as fsts:
1113             report = MagicMock()
1114             p = "foo.py"
1115             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1116             expected = Path(p)
1117             black.reformat_one(
1118                 path,
1119                 fast=True,
1120                 write_back=black.WriteBack.YES,
1121                 mode=DEFAULT_MODE,
1122                 report=report,
1123             )
1124             fsts.assert_called_once_with(
1125                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1126             )
1127             # __BLACK_STDIN_FILENAME__ should have been stripped
1128             report.done.assert_called_with(expected, black.Changed.YES)
1129
1130     @pytest.mark.incompatible_with_mypyc
1131     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1132         with patch(
1133             "black.format_stdin_to_stdout",
1134             return_value=lambda *args, **kwargs: black.Changed.YES,
1135         ) as fsts:
1136             report = MagicMock()
1137             p = "foo.pyi"
1138             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1139             expected = Path(p)
1140             black.reformat_one(
1141                 path,
1142                 fast=True,
1143                 write_back=black.WriteBack.YES,
1144                 mode=DEFAULT_MODE,
1145                 report=report,
1146             )
1147             fsts.assert_called_once_with(
1148                 fast=True,
1149                 write_back=black.WriteBack.YES,
1150                 mode=replace(DEFAULT_MODE, is_pyi=True),
1151             )
1152             # __BLACK_STDIN_FILENAME__ should have been stripped
1153             report.done.assert_called_with(expected, black.Changed.YES)
1154
1155     @pytest.mark.incompatible_with_mypyc
1156     def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
1157         with patch(
1158             "black.format_stdin_to_stdout",
1159             return_value=lambda *args, **kwargs: black.Changed.YES,
1160         ) as fsts:
1161             report = MagicMock()
1162             p = "foo.ipynb"
1163             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1164             expected = Path(p)
1165             black.reformat_one(
1166                 path,
1167                 fast=True,
1168                 write_back=black.WriteBack.YES,
1169                 mode=DEFAULT_MODE,
1170                 report=report,
1171             )
1172             fsts.assert_called_once_with(
1173                 fast=True,
1174                 write_back=black.WriteBack.YES,
1175                 mode=replace(DEFAULT_MODE, is_ipynb=True),
1176             )
1177             # __BLACK_STDIN_FILENAME__ should have been stripped
1178             report.done.assert_called_with(expected, black.Changed.YES)
1179
1180     @pytest.mark.incompatible_with_mypyc
1181     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1182         with patch(
1183             "black.format_stdin_to_stdout",
1184             return_value=lambda *args, **kwargs: black.Changed.YES,
1185         ) as fsts:
1186             report = MagicMock()
1187             # Even with an existing file, since we are forcing stdin, black
1188             # should output to stdout and not modify the file inplace
1189             p = Path(str(THIS_DIR / "data/collections.py"))
1190             # Make sure is_file actually returns True
1191             self.assertTrue(p.is_file())
1192             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1193             expected = Path(p)
1194             black.reformat_one(
1195                 path,
1196                 fast=True,
1197                 write_back=black.WriteBack.YES,
1198                 mode=DEFAULT_MODE,
1199                 report=report,
1200             )
1201             fsts.assert_called_once()
1202             # __BLACK_STDIN_FILENAME__ should have been stripped
1203             report.done.assert_called_with(expected, black.Changed.YES)
1204
1205     def test_reformat_one_with_stdin_empty(self) -> None:
1206         output = io.StringIO()
1207         with patch("io.TextIOWrapper", lambda *args, **kwargs: output):
1208             try:
1209                 black.format_stdin_to_stdout(
1210                     fast=True,
1211                     content="",
1212                     write_back=black.WriteBack.YES,
1213                     mode=DEFAULT_MODE,
1214                 )
1215             except io.UnsupportedOperation:
1216                 pass  # StringIO does not support detach
1217             assert output.getvalue() == ""
1218
1219     def test_invalid_cli_regex(self) -> None:
1220         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1221             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1222
1223     def test_required_version_matches_version(self) -> None:
1224         self.invokeBlack(
1225             ["--required-version", black.__version__], exit_code=0, ignore_config=True
1226         )
1227
1228     def test_required_version_does_not_match_version(self) -> None:
1229         self.invokeBlack(
1230             ["--required-version", "20.99b"], exit_code=1, ignore_config=True
1231         )
1232
1233     def test_preserves_line_endings(self) -> None:
1234         with TemporaryDirectory() as workspace:
1235             test_file = Path(workspace) / "test.py"
1236             for nl in ["\n", "\r\n"]:
1237                 contents = nl.join(["def f(  ):", "    pass"])
1238                 test_file.write_bytes(contents.encode())
1239                 ff(test_file, write_back=black.WriteBack.YES)
1240                 updated_contents: bytes = test_file.read_bytes()
1241                 self.assertIn(nl.encode(), updated_contents)
1242                 if nl == "\n":
1243                     self.assertNotIn(b"\r\n", updated_contents)
1244
1245     def test_preserves_line_endings_via_stdin(self) -> None:
1246         for nl in ["\n", "\r\n"]:
1247             contents = nl.join(["def f(  ):", "    pass"])
1248             runner = BlackRunner()
1249             result = runner.invoke(
1250                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1251             )
1252             self.assertEqual(result.exit_code, 0)
1253             output = result.stdout_bytes
1254             self.assertIn(nl.encode("utf8"), output)
1255             if nl == "\n":
1256                 self.assertNotIn(b"\r\n", output)
1257
1258     def test_assert_equivalent_different_asts(self) -> None:
1259         with self.assertRaises(AssertionError):
1260             black.assert_equivalent("{}", "None")
1261
1262     def test_shhh_click(self) -> None:
1263         try:
1264             from click import _unicodefun
1265         except ModuleNotFoundError:
1266             self.skipTest("Incompatible Click version")
1267         if not hasattr(_unicodefun, "_verify_python3_env"):
1268             self.skipTest("Incompatible Click version")
1269         # First, let's see if Click is crashing with a preferred ASCII charset.
1270         with patch("locale.getpreferredencoding") as gpe:
1271             gpe.return_value = "ASCII"
1272             with self.assertRaises(RuntimeError):
1273                 _unicodefun._verify_python3_env()  # type: ignore
1274         # Now, let's silence Click...
1275         black.patch_click()
1276         # ...and confirm it's silent.
1277         with patch("locale.getpreferredencoding") as gpe:
1278             gpe.return_value = "ASCII"
1279             try:
1280                 _unicodefun._verify_python3_env()  # type: ignore
1281             except RuntimeError as re:
1282                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1283
1284     def test_root_logger_not_used_directly(self) -> None:
1285         def fail(*args: Any, **kwargs: Any) -> None:
1286             self.fail("Record created with root logger")
1287
1288         with patch.multiple(
1289             logging.root,
1290             debug=fail,
1291             info=fail,
1292             warning=fail,
1293             error=fail,
1294             critical=fail,
1295             log=fail,
1296         ):
1297             ff(THIS_DIR / "util.py")
1298
1299     def test_invalid_config_return_code(self) -> None:
1300         tmp_file = Path(black.dump_to_file())
1301         try:
1302             tmp_config = Path(black.dump_to_file())
1303             tmp_config.unlink()
1304             args = ["--config", str(tmp_config), str(tmp_file)]
1305             self.invokeBlack(args, exit_code=2, ignore_config=False)
1306         finally:
1307             tmp_file.unlink()
1308
1309     def test_parse_pyproject_toml(self) -> None:
1310         test_toml_file = THIS_DIR / "test.toml"
1311         config = black.parse_pyproject_toml(str(test_toml_file))
1312         self.assertEqual(config["verbose"], 1)
1313         self.assertEqual(config["check"], "no")
1314         self.assertEqual(config["diff"], "y")
1315         self.assertEqual(config["color"], True)
1316         self.assertEqual(config["line_length"], 79)
1317         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1318         self.assertEqual(config["exclude"], r"\.pyi?$")
1319         self.assertEqual(config["include"], r"\.py?$")
1320
1321     def test_read_pyproject_toml(self) -> None:
1322         test_toml_file = THIS_DIR / "test.toml"
1323         fake_ctx = FakeContext()
1324         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1325         config = fake_ctx.default_map
1326         self.assertEqual(config["verbose"], "1")
1327         self.assertEqual(config["check"], "no")
1328         self.assertEqual(config["diff"], "y")
1329         self.assertEqual(config["color"], "True")
1330         self.assertEqual(config["line_length"], "79")
1331         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1332         self.assertEqual(config["exclude"], r"\.pyi?$")
1333         self.assertEqual(config["include"], r"\.py?$")
1334
1335     @pytest.mark.incompatible_with_mypyc
1336     def test_find_project_root(self) -> None:
1337         with TemporaryDirectory() as workspace:
1338             root = Path(workspace)
1339             test_dir = root / "test"
1340             test_dir.mkdir()
1341
1342             src_dir = root / "src"
1343             src_dir.mkdir()
1344
1345             root_pyproject = root / "pyproject.toml"
1346             root_pyproject.touch()
1347             src_pyproject = src_dir / "pyproject.toml"
1348             src_pyproject.touch()
1349             src_python = src_dir / "foo.py"
1350             src_python.touch()
1351
1352             self.assertEqual(
1353                 black.find_project_root((src_dir, test_dir)), root.resolve()
1354             )
1355             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
1356             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
1357
1358     @patch(
1359         "black.files.find_user_pyproject_toml",
1360         black.files.find_user_pyproject_toml.__wrapped__,
1361     )
1362     def test_find_user_pyproject_toml_linux(self) -> None:
1363         if system() == "Windows":
1364             return
1365
1366         # Test if XDG_CONFIG_HOME is checked
1367         with TemporaryDirectory() as workspace:
1368             tmp_user_config = Path(workspace) / "black"
1369             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1370                 self.assertEqual(
1371                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
1372                 )
1373
1374         # Test fallback for XDG_CONFIG_HOME
1375         with patch.dict("os.environ"):
1376             os.environ.pop("XDG_CONFIG_HOME", None)
1377             fallback_user_config = Path("~/.config").expanduser() / "black"
1378             self.assertEqual(
1379                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
1380             )
1381
1382     def test_find_user_pyproject_toml_windows(self) -> None:
1383         if system() != "Windows":
1384             return
1385
1386         user_config_path = Path.home() / ".black"
1387         self.assertEqual(
1388             black.files.find_user_pyproject_toml(), user_config_path.resolve()
1389         )
1390
1391     def test_bpo_33660_workaround(self) -> None:
1392         if system() == "Windows":
1393             return
1394
1395         # https://bugs.python.org/issue33660
1396         root = Path("/")
1397         with change_directory(root):
1398             path = Path("workspace") / "project"
1399             report = black.Report(verbose=True)
1400             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
1401             self.assertEqual(normalized_path, "workspace/project")
1402
1403     def test_newline_comment_interaction(self) -> None:
1404         source = "class A:\\\r\n# type: ignore\n pass\n"
1405         output = black.format_str(source, mode=DEFAULT_MODE)
1406         black.assert_stable(source, output, mode=DEFAULT_MODE)
1407
1408     def test_bpo_2142_workaround(self) -> None:
1409
1410         # https://bugs.python.org/issue2142
1411
1412         source, _ = read_data("missing_final_newline.py")
1413         # read_data adds a trailing newline
1414         source = source.rstrip()
1415         expected, _ = read_data("missing_final_newline.diff")
1416         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
1417         diff_header = re.compile(
1418             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
1419             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
1420         )
1421         try:
1422             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
1423             self.assertEqual(result.exit_code, 0)
1424         finally:
1425             os.unlink(tmp_file)
1426         actual = result.output
1427         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
1428         self.assertEqual(actual, expected)
1429
1430     @staticmethod
1431     def compare_results(
1432         result: click.testing.Result, expected_value: str, expected_exit_code: int
1433     ) -> None:
1434         """Helper method to test the value and exit code of a click Result."""
1435         assert (
1436             result.output == expected_value
1437         ), "The output did not match the expected value."
1438         assert result.exit_code == expected_exit_code, "The exit code is incorrect."
1439
1440     def test_code_option(self) -> None:
1441         """Test the code option with no changes."""
1442         code = 'print("Hello world")\n'
1443         args = ["--code", code]
1444         result = CliRunner().invoke(black.main, args)
1445
1446         self.compare_results(result, code, 0)
1447
1448     def test_code_option_changed(self) -> None:
1449         """Test the code option when changes are required."""
1450         code = "print('hello world')"
1451         formatted = black.format_str(code, mode=DEFAULT_MODE)
1452
1453         args = ["--code", code]
1454         result = CliRunner().invoke(black.main, args)
1455
1456         self.compare_results(result, formatted, 0)
1457
1458     def test_code_option_check(self) -> None:
1459         """Test the code option when check is passed."""
1460         args = ["--check", "--code", 'print("Hello world")\n']
1461         result = CliRunner().invoke(black.main, args)
1462         self.compare_results(result, "", 0)
1463
1464     def test_code_option_check_changed(self) -> None:
1465         """Test the code option when changes are required, and check is passed."""
1466         args = ["--check", "--code", "print('hello world')"]
1467         result = CliRunner().invoke(black.main, args)
1468         self.compare_results(result, "", 1)
1469
1470     def test_code_option_diff(self) -> None:
1471         """Test the code option when diff is passed."""
1472         code = "print('hello world')"
1473         formatted = black.format_str(code, mode=DEFAULT_MODE)
1474         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1475
1476         args = ["--diff", "--code", code]
1477         result = CliRunner().invoke(black.main, args)
1478
1479         # Remove time from diff
1480         output = DIFF_TIME.sub("", result.output)
1481
1482         assert output == result_diff, "The output did not match the expected value."
1483         assert result.exit_code == 0, "The exit code is incorrect."
1484
1485     def test_code_option_color_diff(self) -> None:
1486         """Test the code option when color and diff are passed."""
1487         code = "print('hello world')"
1488         formatted = black.format_str(code, mode=DEFAULT_MODE)
1489
1490         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1491         result_diff = color_diff(result_diff)
1492
1493         args = ["--diff", "--color", "--code", code]
1494         result = CliRunner().invoke(black.main, args)
1495
1496         # Remove time from diff
1497         output = DIFF_TIME.sub("", result.output)
1498
1499         assert output == result_diff, "The output did not match the expected value."
1500         assert result.exit_code == 0, "The exit code is incorrect."
1501
1502     @pytest.mark.incompatible_with_mypyc
1503     def test_code_option_safe(self) -> None:
1504         """Test that the code option throws an error when the sanity checks fail."""
1505         # Patch black.assert_equivalent to ensure the sanity checks fail
1506         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1507             code = 'print("Hello world")'
1508             error_msg = f"{code}\nerror: cannot format <string>: \n"
1509
1510             args = ["--safe", "--code", code]
1511             result = CliRunner().invoke(black.main, args)
1512
1513             self.compare_results(result, error_msg, 123)
1514
1515     def test_code_option_fast(self) -> None:
1516         """Test that the code option ignores errors when the sanity checks fail."""
1517         # Patch black.assert_equivalent to ensure the sanity checks fail
1518         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1519             code = 'print("Hello world")'
1520             formatted = black.format_str(code, mode=DEFAULT_MODE)
1521
1522             args = ["--fast", "--code", code]
1523             result = CliRunner().invoke(black.main, args)
1524
1525             self.compare_results(result, formatted, 0)
1526
1527     @pytest.mark.incompatible_with_mypyc
1528     def test_code_option_config(self) -> None:
1529         """
1530         Test that the code option finds the pyproject.toml in the current directory.
1531         """
1532         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1533             args = ["--code", "print"]
1534             # This is the only directory known to contain a pyproject.toml
1535             with change_directory(PROJECT_ROOT):
1536                 CliRunner().invoke(black.main, args)
1537                 pyproject_path = Path(Path.cwd(), "pyproject.toml").resolve()
1538
1539             assert (
1540                 len(parse.mock_calls) >= 1
1541             ), "Expected config parse to be called with the current directory."
1542
1543             _, call_args, _ = parse.mock_calls[0]
1544             assert (
1545                 call_args[0].lower() == str(pyproject_path).lower()
1546             ), "Incorrect config loaded."
1547
1548     @pytest.mark.incompatible_with_mypyc
1549     def test_code_option_parent_config(self) -> None:
1550         """
1551         Test that the code option finds the pyproject.toml in the parent directory.
1552         """
1553         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1554             with change_directory(THIS_DIR):
1555                 args = ["--code", "print"]
1556                 CliRunner().invoke(black.main, args)
1557
1558                 pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
1559                 assert (
1560                     len(parse.mock_calls) >= 1
1561                 ), "Expected config parse to be called with the current directory."
1562
1563                 _, call_args, _ = parse.mock_calls[0]
1564                 assert (
1565                     call_args[0].lower() == str(pyproject_path).lower()
1566                 ), "Incorrect config loaded."
1567
1568     def test_for_handled_unexpected_eof_error(self) -> None:
1569         """
1570         Test that an unexpected EOF SyntaxError is nicely presented.
1571         """
1572         with pytest.raises(black.parsing.InvalidInput) as exc_info:
1573             black.lib2to3_parse("print(", {})
1574
1575         exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")
1576
1577     def test_equivalency_ast_parse_failure_includes_error(self) -> None:
1578         with pytest.raises(AssertionError) as err:
1579             black.assert_equivalent("a«»a  = 1", "a«»a  = 1")
1580
1581         err.match("--safe")
1582         # Unfortunately the SyntaxError message has changed in newer versions so we
1583         # can't match it directly.
1584         err.match("invalid character")
1585         err.match(r"\(<unknown>, line 1\)")
1586
1587
1588 class TestCaching:
1589     def test_cache_broken_file(self) -> None:
1590         mode = DEFAULT_MODE
1591         with cache_dir() as workspace:
1592             cache_file = get_cache_file(mode)
1593             cache_file.write_text("this is not a pickle")
1594             assert black.read_cache(mode) == {}
1595             src = (workspace / "test.py").resolve()
1596             src.write_text("print('hello')")
1597             invokeBlack([str(src)])
1598             cache = black.read_cache(mode)
1599             assert str(src) in cache
1600
1601     def test_cache_single_file_already_cached(self) -> None:
1602         mode = DEFAULT_MODE
1603         with cache_dir() as workspace:
1604             src = (workspace / "test.py").resolve()
1605             src.write_text("print('hello')")
1606             black.write_cache({}, [src], mode)
1607             invokeBlack([str(src)])
1608             assert src.read_text() == "print('hello')"
1609
1610     @event_loop()
1611     def test_cache_multiple_files(self) -> None:
1612         mode = DEFAULT_MODE
1613         with cache_dir() as workspace, patch(
1614             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1615         ):
1616             one = (workspace / "one.py").resolve()
1617             with one.open("w") as fobj:
1618                 fobj.write("print('hello')")
1619             two = (workspace / "two.py").resolve()
1620             with two.open("w") as fobj:
1621                 fobj.write("print('hello')")
1622             black.write_cache({}, [one], mode)
1623             invokeBlack([str(workspace)])
1624             with one.open("r") as fobj:
1625                 assert fobj.read() == "print('hello')"
1626             with two.open("r") as fobj:
1627                 assert fobj.read() == 'print("hello")\n'
1628             cache = black.read_cache(mode)
1629             assert str(one) in cache
1630             assert str(two) in cache
1631
1632     @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
1633     def test_no_cache_when_writeback_diff(self, color: bool) -> None:
1634         mode = DEFAULT_MODE
1635         with cache_dir() as workspace:
1636             src = (workspace / "test.py").resolve()
1637             with src.open("w") as fobj:
1638                 fobj.write("print('hello')")
1639             with patch("black.read_cache") as read_cache, patch(
1640                 "black.write_cache"
1641             ) as write_cache:
1642                 cmd = [str(src), "--diff"]
1643                 if color:
1644                     cmd.append("--color")
1645                 invokeBlack(cmd)
1646                 cache_file = get_cache_file(mode)
1647                 assert cache_file.exists() is False
1648                 write_cache.assert_not_called()
1649                 read_cache.assert_not_called()
1650
1651     @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
1652     @event_loop()
1653     def test_output_locking_when_writeback_diff(self, color: bool) -> None:
1654         with cache_dir() as workspace:
1655             for tag in range(0, 4):
1656                 src = (workspace / f"test{tag}.py").resolve()
1657                 with src.open("w") as fobj:
1658                     fobj.write("print('hello')")
1659             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1660                 cmd = ["--diff", str(workspace)]
1661                 if color:
1662                     cmd.append("--color")
1663                 invokeBlack(cmd, exit_code=0)
1664                 # this isn't quite doing what we want, but if it _isn't_
1665                 # called then we cannot be using the lock it provides
1666                 mgr.assert_called()
1667
1668     def test_no_cache_when_stdin(self) -> None:
1669         mode = DEFAULT_MODE
1670         with cache_dir():
1671             result = CliRunner().invoke(
1672                 black.main, ["-"], input=BytesIO(b"print('hello')")
1673             )
1674             assert not result.exit_code
1675             cache_file = get_cache_file(mode)
1676             assert not cache_file.exists()
1677
1678     def test_read_cache_no_cachefile(self) -> None:
1679         mode = DEFAULT_MODE
1680         with cache_dir():
1681             assert black.read_cache(mode) == {}
1682
1683     def test_write_cache_read_cache(self) -> None:
1684         mode = DEFAULT_MODE
1685         with cache_dir() as workspace:
1686             src = (workspace / "test.py").resolve()
1687             src.touch()
1688             black.write_cache({}, [src], mode)
1689             cache = black.read_cache(mode)
1690             assert str(src) in cache
1691             assert cache[str(src)] == black.get_cache_info(src)
1692
1693     def test_filter_cached(self) -> None:
1694         with TemporaryDirectory() as workspace:
1695             path = Path(workspace)
1696             uncached = (path / "uncached").resolve()
1697             cached = (path / "cached").resolve()
1698             cached_but_changed = (path / "changed").resolve()
1699             uncached.touch()
1700             cached.touch()
1701             cached_but_changed.touch()
1702             cache = {
1703                 str(cached): black.get_cache_info(cached),
1704                 str(cached_but_changed): (0.0, 0),
1705             }
1706             todo, done = black.filter_cached(
1707                 cache, {uncached, cached, cached_but_changed}
1708             )
1709             assert todo == {uncached, cached_but_changed}
1710             assert done == {cached}
1711
1712     def test_write_cache_creates_directory_if_needed(self) -> None:
1713         mode = DEFAULT_MODE
1714         with cache_dir(exists=False) as workspace:
1715             assert not workspace.exists()
1716             black.write_cache({}, [], mode)
1717             assert workspace.exists()
1718
1719     @event_loop()
1720     def test_failed_formatting_does_not_get_cached(self) -> None:
1721         mode = DEFAULT_MODE
1722         with cache_dir() as workspace, patch(
1723             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1724         ):
1725             failing = (workspace / "failing.py").resolve()
1726             with failing.open("w") as fobj:
1727                 fobj.write("not actually python")
1728             clean = (workspace / "clean.py").resolve()
1729             with clean.open("w") as fobj:
1730                 fobj.write('print("hello")\n')
1731             invokeBlack([str(workspace)], exit_code=123)
1732             cache = black.read_cache(mode)
1733             assert str(failing) not in cache
1734             assert str(clean) in cache
1735
1736     def test_write_cache_write_fail(self) -> None:
1737         mode = DEFAULT_MODE
1738         with cache_dir(), patch.object(Path, "open") as mock:
1739             mock.side_effect = OSError
1740             black.write_cache({}, [], mode)
1741
1742     def test_read_cache_line_lengths(self) -> None:
1743         mode = DEFAULT_MODE
1744         short_mode = replace(DEFAULT_MODE, line_length=1)
1745         with cache_dir() as workspace:
1746             path = (workspace / "file.py").resolve()
1747             path.touch()
1748             black.write_cache({}, [path], mode)
1749             one = black.read_cache(mode)
1750             assert str(path) in one
1751             two = black.read_cache(short_mode)
1752             assert str(path) not in two
1753
1754
1755 def assert_collected_sources(
1756     src: Sequence[Union[str, Path]],
1757     expected: Sequence[Union[str, Path]],
1758     *,
1759     exclude: Optional[str] = None,
1760     include: Optional[str] = None,
1761     extend_exclude: Optional[str] = None,
1762     force_exclude: Optional[str] = None,
1763     stdin_filename: Optional[str] = None,
1764 ) -> None:
1765     gs_src = tuple(str(Path(s)) for s in src)
1766     gs_expected = [Path(s) for s in expected]
1767     gs_exclude = None if exclude is None else compile_pattern(exclude)
1768     gs_include = DEFAULT_INCLUDE if include is None else compile_pattern(include)
1769     gs_extend_exclude = (
1770         None if extend_exclude is None else compile_pattern(extend_exclude)
1771     )
1772     gs_force_exclude = None if force_exclude is None else compile_pattern(force_exclude)
1773     collected = black.get_sources(
1774         ctx=FakeContext(),
1775         src=gs_src,
1776         quiet=False,
1777         verbose=False,
1778         include=gs_include,
1779         exclude=gs_exclude,
1780         extend_exclude=gs_extend_exclude,
1781         force_exclude=gs_force_exclude,
1782         report=black.Report(),
1783         stdin_filename=stdin_filename,
1784     )
1785     assert sorted(collected) == sorted(gs_expected)
1786
1787
1788 class TestFileCollection:
1789     def test_include_exclude(self) -> None:
1790         path = THIS_DIR / "data" / "include_exclude_tests"
1791         src = [path]
1792         expected = [
1793             Path(path / "b/dont_exclude/a.py"),
1794             Path(path / "b/dont_exclude/a.pyi"),
1795         ]
1796         assert_collected_sources(
1797             src,
1798             expected,
1799             include=r"\.pyi?$",
1800             exclude=r"/exclude/|/\.definitely_exclude/",
1801         )
1802
1803     def test_gitignore_used_as_default(self) -> None:
1804         base = Path(DATA_DIR / "include_exclude_tests")
1805         expected = [
1806             base / "b/.definitely_exclude/a.py",
1807             base / "b/.definitely_exclude/a.pyi",
1808         ]
1809         src = [base / "b/"]
1810         assert_collected_sources(src, expected, extend_exclude=r"/exclude/")
1811
1812     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1813     def test_exclude_for_issue_1572(self) -> None:
1814         # Exclude shouldn't touch files that were explicitly given to Black through the
1815         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
1816         # https://github.com/psf/black/issues/1572
1817         path = DATA_DIR / "include_exclude_tests"
1818         src = [path / "b/exclude/a.py"]
1819         expected = [path / "b/exclude/a.py"]
1820         assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")
1821
1822     def test_gitignore_exclude(self) -> None:
1823         path = THIS_DIR / "data" / "include_exclude_tests"
1824         include = re.compile(r"\.pyi?$")
1825         exclude = re.compile(r"")
1826         report = black.Report()
1827         gitignore = PathSpec.from_lines(
1828             "gitwildmatch", ["exclude/", ".definitely_exclude"]
1829         )
1830         sources: List[Path] = []
1831         expected = [
1832             Path(path / "b/dont_exclude/a.py"),
1833             Path(path / "b/dont_exclude/a.pyi"),
1834         ]
1835         this_abs = THIS_DIR.resolve()
1836         sources.extend(
1837             black.gen_python_files(
1838                 path.iterdir(),
1839                 this_abs,
1840                 include,
1841                 exclude,
1842                 None,
1843                 None,
1844                 report,
1845                 gitignore,
1846                 verbose=False,
1847                 quiet=False,
1848             )
1849         )
1850         assert sorted(expected) == sorted(sources)
1851
1852     def test_nested_gitignore(self) -> None:
1853         path = Path(THIS_DIR / "data" / "nested_gitignore_tests")
1854         include = re.compile(r"\.pyi?$")
1855         exclude = re.compile(r"")
1856         root_gitignore = black.files.get_gitignore(path)
1857         report = black.Report()
1858         expected: List[Path] = [
1859             Path(path / "x.py"),
1860             Path(path / "root/b.py"),
1861             Path(path / "root/c.py"),
1862             Path(path / "root/child/c.py"),
1863         ]
1864         this_abs = THIS_DIR.resolve()
1865         sources = list(
1866             black.gen_python_files(
1867                 path.iterdir(),
1868                 this_abs,
1869                 include,
1870                 exclude,
1871                 None,
1872                 None,
1873                 report,
1874                 root_gitignore,
1875                 verbose=False,
1876                 quiet=False,
1877             )
1878         )
1879         assert sorted(expected) == sorted(sources)
1880
1881     def test_invalid_gitignore(self) -> None:
1882         path = THIS_DIR / "data" / "invalid_gitignore_tests"
1883         empty_config = path / "pyproject.toml"
1884         result = BlackRunner().invoke(
1885             black.main, ["--verbose", "--config", str(empty_config), str(path)]
1886         )
1887         assert result.exit_code == 1
1888         assert result.stderr_bytes is not None
1889
1890         gitignore = path / ".gitignore"
1891         assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()
1892
1893     def test_invalid_nested_gitignore(self) -> None:
1894         path = THIS_DIR / "data" / "invalid_nested_gitignore_tests"
1895         empty_config = path / "pyproject.toml"
1896         result = BlackRunner().invoke(
1897             black.main, ["--verbose", "--config", str(empty_config), str(path)]
1898         )
1899         assert result.exit_code == 1
1900         assert result.stderr_bytes is not None
1901
1902         gitignore = path / "a" / ".gitignore"
1903         assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()
1904
1905     def test_empty_include(self) -> None:
1906         path = DATA_DIR / "include_exclude_tests"
1907         src = [path]
1908         expected = [
1909             Path(path / "b/exclude/a.pie"),
1910             Path(path / "b/exclude/a.py"),
1911             Path(path / "b/exclude/a.pyi"),
1912             Path(path / "b/dont_exclude/a.pie"),
1913             Path(path / "b/dont_exclude/a.py"),
1914             Path(path / "b/dont_exclude/a.pyi"),
1915             Path(path / "b/.definitely_exclude/a.pie"),
1916             Path(path / "b/.definitely_exclude/a.py"),
1917             Path(path / "b/.definitely_exclude/a.pyi"),
1918             Path(path / ".gitignore"),
1919             Path(path / "pyproject.toml"),
1920         ]
1921         # Setting exclude explicitly to an empty string to block .gitignore usage.
1922         assert_collected_sources(src, expected, include="", exclude="")
1923
1924     def test_extend_exclude(self) -> None:
1925         path = DATA_DIR / "include_exclude_tests"
1926         src = [path]
1927         expected = [
1928             Path(path / "b/exclude/a.py"),
1929             Path(path / "b/dont_exclude/a.py"),
1930         ]
1931         assert_collected_sources(
1932             src, expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude"
1933         )
1934
1935     @pytest.mark.incompatible_with_mypyc
1936     def test_symlink_out_of_root_directory(self) -> None:
1937         path = MagicMock()
1938         root = THIS_DIR.resolve()
1939         child = MagicMock()
1940         include = re.compile(black.DEFAULT_INCLUDES)
1941         exclude = re.compile(black.DEFAULT_EXCLUDES)
1942         report = black.Report()
1943         gitignore = PathSpec.from_lines("gitwildmatch", [])
1944         # `child` should behave like a symlink which resolved path is clearly
1945         # outside of the `root` directory.
1946         path.iterdir.return_value = [child]
1947         child.resolve.return_value = Path("/a/b/c")
1948         child.as_posix.return_value = "/a/b/c"
1949         child.is_symlink.return_value = True
1950         try:
1951             list(
1952                 black.gen_python_files(
1953                     path.iterdir(),
1954                     root,
1955                     include,
1956                     exclude,
1957                     None,
1958                     None,
1959                     report,
1960                     gitignore,
1961                     verbose=False,
1962                     quiet=False,
1963                 )
1964             )
1965         except ValueError as ve:
1966             pytest.fail(f"`get_python_files_in_dir()` failed: {ve}")
1967         path.iterdir.assert_called_once()
1968         child.resolve.assert_called_once()
1969         child.is_symlink.assert_called_once()
1970         # `child` should behave like a strange file which resolved path is clearly
1971         # outside of the `root` directory.
1972         child.is_symlink.return_value = False
1973         with pytest.raises(ValueError):
1974             list(
1975                 black.gen_python_files(
1976                     path.iterdir(),
1977                     root,
1978                     include,
1979                     exclude,
1980                     None,
1981                     None,
1982                     report,
1983                     gitignore,
1984                     verbose=False,
1985                     quiet=False,
1986                 )
1987             )
1988         path.iterdir.assert_called()
1989         assert path.iterdir.call_count == 2
1990         child.resolve.assert_called()
1991         assert child.resolve.call_count == 2
1992         child.is_symlink.assert_called()
1993         assert child.is_symlink.call_count == 2
1994
1995     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1996     def test_get_sources_with_stdin(self) -> None:
1997         src = ["-"]
1998         expected = ["-"]
1999         assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")
2000
2001     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
2002     def test_get_sources_with_stdin_filename(self) -> None:
2003         src = ["-"]
2004         stdin_filename = str(THIS_DIR / "data/collections.py")
2005         expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
2006         assert_collected_sources(
2007             src,
2008             expected,
2009             exclude=r"/exclude/a\.py",
2010             stdin_filename=stdin_filename,
2011         )
2012
2013     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
2014     def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
2015         # Exclude shouldn't exclude stdin_filename since it is mimicking the
2016         # file being passed directly. This is the same as
2017         # test_exclude_for_issue_1572
2018         path = DATA_DIR / "include_exclude_tests"
2019         src = ["-"]
2020         stdin_filename = str(path / "b/exclude/a.py")
2021         expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
2022         assert_collected_sources(
2023             src,
2024             expected,
2025             exclude=r"/exclude/|a\.py",
2026             stdin_filename=stdin_filename,
2027         )
2028
2029     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
2030     def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
2031         # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
2032         # file being passed directly. This is the same as
2033         # test_exclude_for_issue_1572
2034         src = ["-"]
2035         path = THIS_DIR / "data" / "include_exclude_tests"
2036         stdin_filename = str(path / "b/exclude/a.py")
2037         expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
2038         assert_collected_sources(
2039             src,
2040             expected,
2041             extend_exclude=r"/exclude/|a\.py",
2042             stdin_filename=stdin_filename,
2043         )
2044
2045     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
2046     def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
2047         # Force exclude should exclude the file when passing it through
2048         # stdin_filename
2049         path = THIS_DIR / "data" / "include_exclude_tests"
2050         stdin_filename = str(path / "b/exclude/a.py")
2051         assert_collected_sources(
2052             src=["-"],
2053             expected=[],
2054             force_exclude=r"/exclude/|a\.py",
2055             stdin_filename=stdin_filename,
2056         )
2057
2058
2059 try:
2060     with open(black.__file__, "r", encoding="utf-8") as _bf:
2061         black_source_lines = _bf.readlines()
2062 except UnicodeDecodeError:
2063     if not black.COMPILED:
2064         raise
2065
2066
2067 def tracefunc(
2068     frame: types.FrameType, event: str, arg: Any
2069 ) -> Callable[[types.FrameType, str, Any], Any]:
2070     """Show function calls `from black/__init__.py` as they happen.
2071
2072     Register this with `sys.settrace()` in a test you're debugging.
2073     """
2074     if event != "call":
2075         return tracefunc
2076
2077     stack = len(inspect.stack()) - 19
2078     stack *= 2
2079     filename = frame.f_code.co_filename
2080     lineno = frame.f_lineno
2081     func_sig_lineno = lineno - 1
2082     funcname = black_source_lines[func_sig_lineno].strip()
2083     while funcname.startswith("@"):
2084         func_sig_lineno += 1
2085         funcname = black_source_lines[func_sig_lineno].strip()
2086     if "black/__init__.py" in filename:
2087         print(f"{' ' * stack}{lineno}:{funcname}")
2088     return tracefunc