git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

5647a00e48ba99a38184f180faa3ee578c8efc02
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import inspect
5 import io
6 import logging
7 import multiprocessing
8 import os
9 import sys
10 import types
11 import unittest
12 from concurrent.futures import ThreadPoolExecutor
13 from contextlib import contextmanager
14 from dataclasses import replace
15 from io import BytesIO
16 from pathlib import Path
17 from platform import system
18 from tempfile import TemporaryDirectory
19 from typing import (
20     Any,
21     Callable,
22     Dict,
23     Iterator,
24     List,
25     Optional,
26     Sequence,
27     TypeVar,
28     Union,
29 )
30 from unittest.mock import MagicMock, patch
31
32 import click
33 import pytest
34 import regex as re
35 from click import unstyle
36 from click.testing import CliRunner
37 from pathspec import PathSpec
38
39 import black
40 import black.files
41 from black import Feature, TargetVersion
42 from black import re_compile_maybe_verbose as compile_pattern
43 from black.cache import get_cache_file
44 from black.debug import DebugVisitor
45 from black.output import color_diff, diff
46 from black.report import Report
47
48 # Import other test classes
49 from tests.util import (
50     DATA_DIR,
51     DEFAULT_MODE,
52     DETERMINISTIC_HEADER,
53     PROJECT_ROOT,
54     PY36_VERSIONS,
55     THIS_DIR,
56     BlackBaseTestCase,
57     assert_format,
58     change_directory,
59     dump_to_stderr,
60     ff,
61     fs,
62     read_data,
63 )
64
THIS_FILE = Path(__file__)
# One ``--target-version=...`` CLI flag per version in PY36_VERSIONS.
PY36_ARGS = [f"--target-version={version.name.lower()}" for version in PY36_VERSIONS]
DEFAULT_EXCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_EXCLUDES)
DEFAULT_INCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_INCLUDES)
T = TypeVar("T")
R = TypeVar("R")

# Match the time output in a diff, but nothing else.
# The "-" inside the character class is escaped so it is always a literal hyphen.
# Unescaped, "\d-:" reads as a range starting at a class escape, which the stdlib
# ``re`` parser rejects as a "bad character range".
DIFF_TIME = re.compile(r"\t[\d\-:+\. ]+")
74
75
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch ``black.cache.CACHE_DIR`` to point into a throwaway temp directory.

    Args:
        exists: when true (the default) the patched path is the temporary
            directory itself. When false it is a ``new`` sub-path of it, which
            this helper does not create.

    Yields:
        The path that ``black.cache.CACHE_DIR`` is patched to for the duration
        of the ``with`` block.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
84
85
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a fresh asyncio event loop as the current loop for a ``with`` block.

    The loop comes from the active event loop policy and is closed on exit, even
    if the body raises. It is not uninstalled afterwards, so the closed loop
    stays registered as the current loop.
    """
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        fresh_loop.close()
96
97
class FakeContext(click.Context):
    """Stand-in click Context for calling functions that require one.

    ``click.Context.__init__`` is intentionally not called. The only state set
    up is an empty ``default_map``.
    """

    def __init__(self) -> None:
        self.default_map: Dict[str, Any] = dict()
103
104
class FakeParameter(click.Parameter):
    """Stand-in click Parameter for calling functions that require one.

    Construction is a no-op: ``click.Parameter.__init__`` is deliberately not
    called, so no parameter attributes are initialised.
    """

    def __init__(self) -> None:
        """Skip the base-class initializer entirely."""
110
111
class BlackRunner(CliRunner):
    """Make sure STDOUT and STDERR are kept separate when testing Black via its CLI.

    Older click releases mix stderr into stdout unless ``mix_stderr=False`` is
    passed. Newer releases (click 8.2+) removed that keyword and always capture
    stderr separately, so it is only passed when the installed ``CliRunner``
    still accepts it.

    NOTE(review): on click 8.2+, ``Result.output`` may interleave stderr with
    stdout; tests that need stdout only may have to read ``result.stdout``.
    Confirm against the click changelog for the pinned version.
    """

    def __init__(self) -> None:
        accepted = inspect.signature(CliRunner.__init__).parameters
        if "mix_stderr" in accepted:
            super().__init__(mix_stderr=False)
        else:
            # Keyword removed upstream; streams are already kept separate.
            super().__init__()
117
118
def invokeBlack(
    args: List[str], exit_code: int = 0, ignore_config: bool = True
) -> None:
    """Run Black's CLI in-process and assert that it exits with ``exit_code``.

    Args:
        args: command-line arguments passed to ``black.main``.
        exit_code: the exit status the run is expected to produce.
        ignore_config: when true, ``--verbose`` and ``--config <THIS_DIR>/empty.toml``
            are prepended to ``args``.

    Raises:
        AssertionError: if stdout/stderr were not captured as bytes, or the exit
            status differs. The message includes the arguments, both output
            streams and any exception raised by the CLI.
    """
    if ignore_config:
        empty_config = str(THIS_DIR / "empty.toml")
        args = ["--verbose", "--config", empty_config, *args]
    outcome = BlackRunner().invoke(black.main, args)
    stdout_raw = outcome.stdout_bytes
    stderr_raw = outcome.stderr_bytes
    assert stdout_raw is not None
    assert stderr_raw is not None
    details = "\n".join(
        [
            f"Failed with args: {args}",
            f"stdout: {stdout_raw.decode()!r}",
            f"stderr: {stderr_raw.decode()!r}",
            f"exception: {outcome.exception}",
        ]
    )
    assert outcome.exit_code == exit_code, details
135
136
137 class BlackTestCase(BlackBaseTestCase):
    # Expose the module-level ``invokeBlack`` helper as ``self.invokeBlack``;
    # ``staticmethod`` prevents the instance from being passed as ``args``.
    invokeBlack = staticmethod(invokeBlack)
139
140     def test_empty_ff(self) -> None:
141         expected = ""
142         tmp_file = Path(black.dump_to_file())
143         try:
144             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
145             with open(tmp_file, encoding="utf8") as f:
146                 actual = f.read()
147         finally:
148             os.unlink(tmp_file)
149         self.assertFormatEqual(expected, actual)
150
151     def test_piping(self) -> None:
152         source, expected = read_data("src/black/__init__", data=False)
153         result = BlackRunner().invoke(
154             black.main,
155             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
156             input=BytesIO(source.encode("utf8")),
157         )
158         self.assertEqual(result.exit_code, 0)
159         self.assertFormatEqual(expected, result.output)
160         if source != result.output:
161             black.assert_equivalent(source, result.output)
162             black.assert_stable(source, result.output, DEFAULT_MODE)
163
164     def test_piping_diff(self) -> None:
165         diff_header = re.compile(
166             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
167             r"\+\d\d\d\d"
168         )
169         source, _ = read_data("expression.py")
170         expected, _ = read_data("expression.diff")
171         config = THIS_DIR / "data" / "empty_pyproject.toml"
172         args = [
173             "-",
174             "--fast",
175             f"--line-length={black.DEFAULT_LINE_LENGTH}",
176             "--diff",
177             f"--config={config}",
178         ]
179         result = BlackRunner().invoke(
180             black.main, args, input=BytesIO(source.encode("utf8"))
181         )
182         self.assertEqual(result.exit_code, 0)
183         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
184         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
185         self.assertEqual(expected, actual)
186
187     def test_piping_diff_with_color(self) -> None:
188         source, _ = read_data("expression.py")
189         config = THIS_DIR / "data" / "empty_pyproject.toml"
190         args = [
191             "-",
192             "--fast",
193             f"--line-length={black.DEFAULT_LINE_LENGTH}",
194             "--diff",
195             "--color",
196             f"--config={config}",
197         ]
198         result = BlackRunner().invoke(
199             black.main, args, input=BytesIO(source.encode("utf8"))
200         )
201         actual = result.output
202         # Again, the contents are checked in a different test, so only look for colors.
203         self.assertIn("\033[1;37m", actual)
204         self.assertIn("\033[36m", actual)
205         self.assertIn("\033[32m", actual)
206         self.assertIn("\033[31m", actual)
207         self.assertIn("\033[0m", actual)
208
209     @patch("black.dump_to_file", dump_to_stderr)
210     def _test_wip(self) -> None:
211         source, expected = read_data("wip")
212         sys.settrace(tracefunc)
213         mode = replace(
214             DEFAULT_MODE,
215             experimental_string_processing=False,
216             target_versions={black.TargetVersion.PY38},
217         )
218         actual = fs(source, mode=mode)
219         sys.settrace(None)
220         self.assertFormatEqual(expected, actual)
221         black.assert_equivalent(source, actual)
222         black.assert_stable(source, actual, black.FileMode())
223
224     @unittest.expectedFailure
225     @patch("black.dump_to_file", dump_to_stderr)
226     def test_trailing_comma_optional_parens_stability1(self) -> None:
227         source, _expected = read_data("trailing_comma_optional_parens1")
228         actual = fs(source)
229         black.assert_stable(source, actual, DEFAULT_MODE)
230
231     @unittest.expectedFailure
232     @patch("black.dump_to_file", dump_to_stderr)
233     def test_trailing_comma_optional_parens_stability2(self) -> None:
234         source, _expected = read_data("trailing_comma_optional_parens2")
235         actual = fs(source)
236         black.assert_stable(source, actual, DEFAULT_MODE)
237
238     @unittest.expectedFailure
239     @patch("black.dump_to_file", dump_to_stderr)
240     def test_trailing_comma_optional_parens_stability3(self) -> None:
241         source, _expected = read_data("trailing_comma_optional_parens3")
242         actual = fs(source)
243         black.assert_stable(source, actual, DEFAULT_MODE)
244
245     @patch("black.dump_to_file", dump_to_stderr)
246     def test_trailing_comma_optional_parens_stability1_pass2(self) -> None:
247         source, _expected = read_data("trailing_comma_optional_parens1")
248         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
249         black.assert_stable(source, actual, DEFAULT_MODE)
250
251     @patch("black.dump_to_file", dump_to_stderr)
252     def test_trailing_comma_optional_parens_stability2_pass2(self) -> None:
253         source, _expected = read_data("trailing_comma_optional_parens2")
254         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
255         black.assert_stable(source, actual, DEFAULT_MODE)
256
257     @patch("black.dump_to_file", dump_to_stderr)
258     def test_trailing_comma_optional_parens_stability3_pass2(self) -> None:
259         source, _expected = read_data("trailing_comma_optional_parens3")
260         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
261         black.assert_stable(source, actual, DEFAULT_MODE)
262
263     def test_pep_572_version_detection(self) -> None:
264         source, _ = read_data("pep_572")
265         root = black.lib2to3_parse(source)
266         features = black.get_features_used(root)
267         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
268         versions = black.detect_target_versions(root)
269         self.assertIn(black.TargetVersion.PY38, versions)
270
271     def test_expression_ff(self) -> None:
272         source, expected = read_data("expression")
273         tmp_file = Path(black.dump_to_file(source))
274         try:
275             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
276             with open(tmp_file, encoding="utf8") as f:
277                 actual = f.read()
278         finally:
279             os.unlink(tmp_file)
280         self.assertFormatEqual(expected, actual)
281         with patch("black.dump_to_file", dump_to_stderr):
282             black.assert_equivalent(source, actual)
283             black.assert_stable(source, actual, DEFAULT_MODE)
284
285     def test_expression_diff(self) -> None:
286         source, _ = read_data("expression.py")
287         config = THIS_DIR / "data" / "empty_pyproject.toml"
288         expected, _ = read_data("expression.diff")
289         tmp_file = Path(black.dump_to_file(source))
290         diff_header = re.compile(
291             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
292             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
293         )
294         try:
295             result = BlackRunner().invoke(
296                 black.main, ["--diff", str(tmp_file), f"--config={config}"]
297             )
298             self.assertEqual(result.exit_code, 0)
299         finally:
300             os.unlink(tmp_file)
301         actual = result.output
302         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
303         if expected != actual:
304             dump = black.dump_to_file(actual)
305             msg = (
306                 "Expected diff isn't equal to the actual. If you made changes to"
307                 " expression.py and this is an anticipated difference, overwrite"
308                 f" tests/data/expression.diff with {dump}"
309             )
310             self.assertEqual(expected, actual, msg)
311
312     def test_expression_diff_with_color(self) -> None:
313         source, _ = read_data("expression.py")
314         config = THIS_DIR / "data" / "empty_pyproject.toml"
315         expected, _ = read_data("expression.diff")
316         tmp_file = Path(black.dump_to_file(source))
317         try:
318             result = BlackRunner().invoke(
319                 black.main, ["--diff", "--color", str(tmp_file), f"--config={config}"]
320             )
321         finally:
322             os.unlink(tmp_file)
323         actual = result.output
324         # We check the contents of the diff in `test_expression_diff`. All
325         # we need to check here is that color codes exist in the result.
326         self.assertIn("\033[1;37m", actual)
327         self.assertIn("\033[36m", actual)
328         self.assertIn("\033[32m", actual)
329         self.assertIn("\033[31m", actual)
330         self.assertIn("\033[0m", actual)
331
332     def test_detect_pos_only_arguments(self) -> None:
333         source, _ = read_data("pep_570")
334         root = black.lib2to3_parse(source)
335         features = black.get_features_used(root)
336         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
337         versions = black.detect_target_versions(root)
338         self.assertIn(black.TargetVersion.PY38, versions)
339
340     @patch("black.dump_to_file", dump_to_stderr)
341     def test_string_quotes(self) -> None:
342         source, expected = read_data("string_quotes")
343         mode = black.Mode(experimental_string_processing=True)
344         assert_format(source, expected, mode)
345         mode = replace(mode, string_normalization=False)
346         not_normalized = fs(source, mode=mode)
347         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
348         black.assert_equivalent(source, not_normalized)
349         black.assert_stable(source, not_normalized, mode=mode)
350
351     def test_skip_magic_trailing_comma(self) -> None:
352         source, _ = read_data("expression.py")
353         expected, _ = read_data("expression_skip_magic_trailing_comma.diff")
354         tmp_file = Path(black.dump_to_file(source))
355         diff_header = re.compile(
356             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
357             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
358         )
359         try:
360             result = BlackRunner().invoke(black.main, ["-C", "--diff", str(tmp_file)])
361             self.assertEqual(result.exit_code, 0)
362         finally:
363             os.unlink(tmp_file)
364         actual = result.output
365         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
366         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
367         if expected != actual:
368             dump = black.dump_to_file(actual)
369             msg = (
370                 "Expected diff isn't equal to the actual. If you made changes to"
371                 " expression.py and this is an anticipated difference, overwrite"
372                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
373             )
374             self.assertEqual(expected, actual, msg)
375
376     @patch("black.dump_to_file", dump_to_stderr)
377     def test_async_as_identifier(self) -> None:
378         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
379         source, expected = read_data("async_as_identifier")
380         actual = fs(source)
381         self.assertFormatEqual(expected, actual)
382         major, minor = sys.version_info[:2]
383         if major < 3 or (major <= 3 and minor < 7):
384             black.assert_equivalent(source, actual)
385         black.assert_stable(source, actual, DEFAULT_MODE)
386         # ensure black can parse this when the target is 3.6
387         self.invokeBlack([str(source_path), "--target-version", "py36"])
388         # but not on 3.7, because async/await is no longer an identifier
389         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
390
391     @patch("black.dump_to_file", dump_to_stderr)
392     def test_python37(self) -> None:
393         source_path = (THIS_DIR / "data" / "python37.py").resolve()
394         source, expected = read_data("python37")
395         actual = fs(source)
396         self.assertFormatEqual(expected, actual)
397         major, minor = sys.version_info[:2]
398         if major > 3 or (major == 3 and minor >= 7):
399             black.assert_equivalent(source, actual)
400         black.assert_stable(source, actual, DEFAULT_MODE)
401         # ensure black can parse this when the target is 3.7
402         self.invokeBlack([str(source_path), "--target-version", "py37"])
403         # but not on 3.6, because we use async as a reserved keyword
404         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
405
406     def test_tab_comment_indentation(self) -> None:
407         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
408         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
409         self.assertFormatEqual(contents_spc, fs(contents_spc))
410         self.assertFormatEqual(contents_spc, fs(contents_tab))
411
412         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
413         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
414         self.assertFormatEqual(contents_spc, fs(contents_spc))
415         self.assertFormatEqual(contents_spc, fs(contents_tab))
416
417         # mixed tabs and spaces (valid Python 2 code)
418         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
419         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
420         self.assertFormatEqual(contents_spc, fs(contents_spc))
421         self.assertFormatEqual(contents_spc, fs(contents_tab))
422
423         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
424         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
425         self.assertFormatEqual(contents_spc, fs(contents_spc))
426         self.assertFormatEqual(contents_spc, fs(contents_tab))
427
    def test_report_verbose(self) -> None:
        """Exercise ``Report`` in verbose mode.

        In this mode every ``done`` and ``path_ignored`` call writes one line
        through ``black.output._out``, and every ``failed`` call writes one line
        through ``black.output._err``. The summary string and ``return_code``
        are checked after each step.
        """
        report = Report(verbose=True)
        # Messages captured from the patched output helpers.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        # Redirect Black's output helpers into the lists above.
        with patch("black.output._out", out), patch("black.output._err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file counts as "left unchanged" in the summary.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, a pending reformat yields return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # After a failure the return code is 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Reporting f3 again as changed bumps "reformatted" without lowering
            # the "left unchanged" count.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are logged but do not appear in the summary counts.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check mode and diff mode switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
529
    def test_report_quiet(self) -> None:
        """Exercise ``Report`` in quiet mode.

        Nothing is written through ``black.output._out``. Only ``failed`` calls
        produce output, one line each, through ``black.output._err``. The
        summary string and ``return_code`` are tracked exactly as in the other
        modes.
        """
        report = Report(quiet=True)
        # Messages captured from the patched output helpers.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        # Redirect Black's output helpers into the lists above.
        with patch("black.output._out", out), patch("black.output._err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, a pending reformat yields return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures are still reported on the error stream in quiet mode.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are neither logged nor counted in the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check mode and diff mode switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
623
    def test_report_normal(self) -> None:
        """Exercise ``Report`` with default verbosity.

        Only ``Changed.YES`` results write a line through ``black.output._out``.
        Unchanged, cached and ignored paths are silent, and each ``failed`` call
        writes one line through ``black.output._err``.
        """
        report = black.Report()
        # Messages captured from the patched output helpers.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        # Redirect Black's output helpers into the lists above.
        with patch("black.output._out", out), patch("black.output._err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # Cached files produce no output, so the last line is still f2's.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, a pending reformat yields return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # After a failure the return code is 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are neither logged nor counted in the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check mode and diff mode switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
720
721     def test_lib2to3_parse(self) -> None:
722         with self.assertRaises(black.InvalidInput):
723             black.lib2to3_parse("invalid syntax")
724
725         straddling = "x + y"
726         black.lib2to3_parse(straddling)
727         black.lib2to3_parse(straddling, {TargetVersion.PY27})
728         black.lib2to3_parse(straddling, {TargetVersion.PY36})
729         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
730
731         py2_only = "print x"
732         black.lib2to3_parse(py2_only)
733         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
734         with self.assertRaises(black.InvalidInput):
735             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
736         with self.assertRaises(black.InvalidInput):
737             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
738
739         py3_only = "exec(x, end=y)"
740         black.lib2to3_parse(py3_only)
741         with self.assertRaises(black.InvalidInput):
742             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
743         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
744         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
745
746     def test_get_features_used_decorator(self) -> None:
747         # Test the feature detection of new decorator syntax
748         # since this makes some test cases of test_get_features_used()
749         # fails if it fails, this is tested first so that a useful case
750         # is identified
751         simples, relaxed = read_data("decorators")
752         # skip explanation comments at the top of the file
753         for simple_test in simples.split("##")[1:]:
754             node = black.lib2to3_parse(simple_test)
755             decorator = str(node.children[0].children[0]).strip()
756             self.assertNotIn(
757                 Feature.RELAXED_DECORATORS,
758                 black.get_features_used(node),
759                 msg=(
760                     f"decorator '{decorator}' follows python<=3.8 syntax"
761                     "but is detected as 3.9+"
762                     # f"The full node is\n{node!r}"
763                 ),
764             )
765         # skip the '# output' comment at the top of the output part
766         for relaxed_test in relaxed.split("##")[1:]:
767             node = black.lib2to3_parse(relaxed_test)
768             decorator = str(node.children[0].children[0]).strip()
769             self.assertIn(
770                 Feature.RELAXED_DECORATORS,
771                 black.get_features_used(node),
772                 msg=(
773                     f"decorator '{decorator}' uses python3.9+ syntax"
774                     "but is detected as python<=3.8"
775                     # f"The full node is\n{node!r}"
776                 ),
777             )
778
779     def test_get_features_used(self) -> None:
780         node = black.lib2to3_parse("def f(*, arg): ...\n")
781         self.assertEqual(black.get_features_used(node), set())
782         node = black.lib2to3_parse("def f(*, arg,): ...\n")
783         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
784         node = black.lib2to3_parse("f(*arg,)\n")
785         self.assertEqual(
786             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
787         )
788         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
789         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
790         node = black.lib2to3_parse("123_456\n")
791         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
792         node = black.lib2to3_parse("123456\n")
793         self.assertEqual(black.get_features_used(node), set())
794         source, expected = read_data("function")
795         node = black.lib2to3_parse(source)
796         expected_features = {
797             Feature.TRAILING_COMMA_IN_CALL,
798             Feature.TRAILING_COMMA_IN_DEF,
799             Feature.F_STRINGS,
800         }
801         self.assertEqual(black.get_features_used(node), expected_features)
802         node = black.lib2to3_parse(expected)
803         self.assertEqual(black.get_features_used(node), expected_features)
804         source, expected = read_data("expression")
805         node = black.lib2to3_parse(source)
806         self.assertEqual(black.get_features_used(node), set())
807         node = black.lib2to3_parse(expected)
808         self.assertEqual(black.get_features_used(node), set())
809         node = black.lib2to3_parse("lambda a, /, b: ...")
810         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
811         node = black.lib2to3_parse("def fn(a, /, b): ...")
812         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
813
814     def test_get_future_imports(self) -> None:
815         node = black.lib2to3_parse("\n")
816         self.assertEqual(set(), black.get_future_imports(node))
817         node = black.lib2to3_parse("from __future__ import black\n")
818         self.assertEqual({"black"}, black.get_future_imports(node))
819         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
820         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
821         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
822         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
823         node = black.lib2to3_parse(
824             "from __future__ import multiple\nfrom __future__ import imports\n"
825         )
826         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
827         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
828         self.assertEqual({"black"}, black.get_future_imports(node))
829         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
830         self.assertEqual({"black"}, black.get_future_imports(node))
831         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
832         self.assertEqual(set(), black.get_future_imports(node))
833         node = black.lib2to3_parse("from some.module import black\n")
834         self.assertEqual(set(), black.get_future_imports(node))
835         node = black.lib2to3_parse(
836             "from __future__ import unicode_literals as _unicode_literals"
837         )
838         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
839         node = black.lib2to3_parse(
840             "from __future__ import unicode_literals as _lol, print"
841         )
842         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
843
844     def test_debug_visitor(self) -> None:
845         source, _ = read_data("debug_visitor.py")
846         expected, _ = read_data("debug_visitor.out")
847         out_lines = []
848         err_lines = []
849
850         def out(msg: str, **kwargs: Any) -> None:
851             out_lines.append(msg)
852
853         def err(msg: str, **kwargs: Any) -> None:
854             err_lines.append(msg)
855
856         with patch("black.debug.out", out):
857             DebugVisitor.show(source)
858         actual = "\n".join(out_lines) + "\n"
859         log_name = ""
860         if expected != actual:
861             log_name = black.dump_to_file(*out_lines)
862         self.assertEqual(
863             expected,
864             actual,
865             f"AST print out is different. Actual version dumped to {log_name}",
866         )
867
868     def test_format_file_contents(self) -> None:
869         empty = ""
870         mode = DEFAULT_MODE
871         with self.assertRaises(black.NothingChanged):
872             black.format_file_contents(empty, mode=mode, fast=False)
873         just_nl = "\n"
874         with self.assertRaises(black.NothingChanged):
875             black.format_file_contents(just_nl, mode=mode, fast=False)
876         same = "j = [1, 2, 3]\n"
877         with self.assertRaises(black.NothingChanged):
878             black.format_file_contents(same, mode=mode, fast=False)
879         different = "j = [1,2,3]"
880         expected = same
881         actual = black.format_file_contents(different, mode=mode, fast=False)
882         self.assertEqual(expected, actual)
883         invalid = "return if you can"
884         with self.assertRaises(black.InvalidInput) as e:
885             black.format_file_contents(invalid, mode=mode, fast=False)
886         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
887
888     def test_endmarker(self) -> None:
889         n = black.lib2to3_parse("\n")
890         self.assertEqual(n.type, black.syms.file_input)
891         self.assertEqual(len(n.children), 1)
892         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
893
894     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
895     def test_assertFormatEqual(self) -> None:
896         out_lines = []
897         err_lines = []
898
899         def out(msg: str, **kwargs: Any) -> None:
900             out_lines.append(msg)
901
902         def err(msg: str, **kwargs: Any) -> None:
903             err_lines.append(msg)
904
905         with patch("black.output._out", out), patch("black.output._err", err):
906             with self.assertRaises(AssertionError):
907                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
908
909         out_str = "".join(out_lines)
910         self.assertTrue("Expected tree:" in out_str)
911         self.assertTrue("Actual tree:" in out_str)
912         self.assertEqual("".join(err_lines), "")
913
914     @event_loop()
915     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
916     def test_works_in_mono_process_only_environment(self) -> None:
917         with cache_dir() as workspace:
918             for f in [
919                 (workspace / "one.py").resolve(),
920                 (workspace / "two.py").resolve(),
921             ]:
922                 f.write_text('print("hello")\n')
923             self.invokeBlack([str(workspace)])
924
925     @event_loop()
926     def test_check_diff_use_together(self) -> None:
927         with cache_dir():
928             # Files which will be reformatted.
929             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
930             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
931             # Files which will not be reformatted.
932             src2 = (THIS_DIR / "data" / "composition.py").resolve()
933             self.invokeBlack([str(src2), "--diff", "--check"])
934             # Multi file command.
935             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
936
937     def test_no_files(self) -> None:
938         with cache_dir():
939             # Without an argument, black exits with error code 0.
940             self.invokeBlack([])
941
942     def test_broken_symlink(self) -> None:
943         with cache_dir() as workspace:
944             symlink = workspace / "broken_link.py"
945             try:
946                 symlink.symlink_to("nonexistent.py")
947             except OSError as e:
948                 self.skipTest(f"Can't create symlinks: {e}")
949             self.invokeBlack([str(workspace.resolve())])
950
951     def test_single_file_force_pyi(self) -> None:
952         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
953         contents, expected = read_data("force_pyi")
954         with cache_dir() as workspace:
955             path = (workspace / "file.py").resolve()
956             with open(path, "w") as fh:
957                 fh.write(contents)
958             self.invokeBlack([str(path), "--pyi"])
959             with open(path, "r") as fh:
960                 actual = fh.read()
961             # verify cache with --pyi is separate
962             pyi_cache = black.read_cache(pyi_mode)
963             self.assertIn(str(path), pyi_cache)
964             normal_cache = black.read_cache(DEFAULT_MODE)
965             self.assertNotIn(str(path), normal_cache)
966         self.assertFormatEqual(expected, actual)
967         black.assert_equivalent(contents, actual)
968         black.assert_stable(contents, actual, pyi_mode)
969
970     @event_loop()
971     def test_multi_file_force_pyi(self) -> None:
972         reg_mode = DEFAULT_MODE
973         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
974         contents, expected = read_data("force_pyi")
975         with cache_dir() as workspace:
976             paths = [
977                 (workspace / "file1.py").resolve(),
978                 (workspace / "file2.py").resolve(),
979             ]
980             for path in paths:
981                 with open(path, "w") as fh:
982                     fh.write(contents)
983             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
984             for path in paths:
985                 with open(path, "r") as fh:
986                     actual = fh.read()
987                 self.assertEqual(actual, expected)
988             # verify cache with --pyi is separate
989             pyi_cache = black.read_cache(pyi_mode)
990             normal_cache = black.read_cache(reg_mode)
991             for path in paths:
992                 self.assertIn(str(path), pyi_cache)
993                 self.assertNotIn(str(path), normal_cache)
994
995     def test_pipe_force_pyi(self) -> None:
996         source, expected = read_data("force_pyi")
997         result = CliRunner().invoke(
998             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
999         )
1000         self.assertEqual(result.exit_code, 0)
1001         actual = result.output
1002         self.assertFormatEqual(actual, expected)
1003
1004     def test_single_file_force_py36(self) -> None:
1005         reg_mode = DEFAULT_MODE
1006         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1007         source, expected = read_data("force_py36")
1008         with cache_dir() as workspace:
1009             path = (workspace / "file.py").resolve()
1010             with open(path, "w") as fh:
1011                 fh.write(source)
1012             self.invokeBlack([str(path), *PY36_ARGS])
1013             with open(path, "r") as fh:
1014                 actual = fh.read()
1015             # verify cache with --target-version is separate
1016             py36_cache = black.read_cache(py36_mode)
1017             self.assertIn(str(path), py36_cache)
1018             normal_cache = black.read_cache(reg_mode)
1019             self.assertNotIn(str(path), normal_cache)
1020         self.assertEqual(actual, expected)
1021
1022     @event_loop()
1023     def test_multi_file_force_py36(self) -> None:
1024         reg_mode = DEFAULT_MODE
1025         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1026         source, expected = read_data("force_py36")
1027         with cache_dir() as workspace:
1028             paths = [
1029                 (workspace / "file1.py").resolve(),
1030                 (workspace / "file2.py").resolve(),
1031             ]
1032             for path in paths:
1033                 with open(path, "w") as fh:
1034                     fh.write(source)
1035             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1036             for path in paths:
1037                 with open(path, "r") as fh:
1038                     actual = fh.read()
1039                 self.assertEqual(actual, expected)
1040             # verify cache with --target-version is separate
1041             pyi_cache = black.read_cache(py36_mode)
1042             normal_cache = black.read_cache(reg_mode)
1043             for path in paths:
1044                 self.assertIn(str(path), pyi_cache)
1045                 self.assertNotIn(str(path), normal_cache)
1046
1047     def test_pipe_force_py36(self) -> None:
1048         source, expected = read_data("force_py36")
1049         result = CliRunner().invoke(
1050             black.main,
1051             ["-", "-q", "--target-version=py36"],
1052             input=BytesIO(source.encode("utf8")),
1053         )
1054         self.assertEqual(result.exit_code, 0)
1055         actual = result.output
1056         self.assertFormatEqual(actual, expected)
1057
1058     def test_reformat_one_with_stdin(self) -> None:
1059         with patch(
1060             "black.format_stdin_to_stdout",
1061             return_value=lambda *args, **kwargs: black.Changed.YES,
1062         ) as fsts:
1063             report = MagicMock()
1064             path = Path("-")
1065             black.reformat_one(
1066                 path,
1067                 fast=True,
1068                 write_back=black.WriteBack.YES,
1069                 mode=DEFAULT_MODE,
1070                 report=report,
1071             )
1072             fsts.assert_called_once()
1073             report.done.assert_called_with(path, black.Changed.YES)
1074
1075     def test_reformat_one_with_stdin_filename(self) -> None:
1076         with patch(
1077             "black.format_stdin_to_stdout",
1078             return_value=lambda *args, **kwargs: black.Changed.YES,
1079         ) as fsts:
1080             report = MagicMock()
1081             p = "foo.py"
1082             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1083             expected = Path(p)
1084             black.reformat_one(
1085                 path,
1086                 fast=True,
1087                 write_back=black.WriteBack.YES,
1088                 mode=DEFAULT_MODE,
1089                 report=report,
1090             )
1091             fsts.assert_called_once_with(
1092                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1093             )
1094             # __BLACK_STDIN_FILENAME__ should have been stripped
1095             report.done.assert_called_with(expected, black.Changed.YES)
1096
1097     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1098         with patch(
1099             "black.format_stdin_to_stdout",
1100             return_value=lambda *args, **kwargs: black.Changed.YES,
1101         ) as fsts:
1102             report = MagicMock()
1103             p = "foo.pyi"
1104             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1105             expected = Path(p)
1106             black.reformat_one(
1107                 path,
1108                 fast=True,
1109                 write_back=black.WriteBack.YES,
1110                 mode=DEFAULT_MODE,
1111                 report=report,
1112             )
1113             fsts.assert_called_once_with(
1114                 fast=True,
1115                 write_back=black.WriteBack.YES,
1116                 mode=replace(DEFAULT_MODE, is_pyi=True),
1117             )
1118             # __BLACK_STDIN_FILENAME__ should have been stripped
1119             report.done.assert_called_with(expected, black.Changed.YES)
1120
1121     def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
1122         with patch(
1123             "black.format_stdin_to_stdout",
1124             return_value=lambda *args, **kwargs: black.Changed.YES,
1125         ) as fsts:
1126             report = MagicMock()
1127             p = "foo.ipynb"
1128             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1129             expected = Path(p)
1130             black.reformat_one(
1131                 path,
1132                 fast=True,
1133                 write_back=black.WriteBack.YES,
1134                 mode=DEFAULT_MODE,
1135                 report=report,
1136             )
1137             fsts.assert_called_once_with(
1138                 fast=True,
1139                 write_back=black.WriteBack.YES,
1140                 mode=replace(DEFAULT_MODE, is_ipynb=True),
1141             )
1142             # __BLACK_STDIN_FILENAME__ should have been stripped
1143             report.done.assert_called_with(expected, black.Changed.YES)
1144
1145     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1146         with patch(
1147             "black.format_stdin_to_stdout",
1148             return_value=lambda *args, **kwargs: black.Changed.YES,
1149         ) as fsts:
1150             report = MagicMock()
1151             # Even with an existing file, since we are forcing stdin, black
1152             # should output to stdout and not modify the file inplace
1153             p = Path(str(THIS_DIR / "data/collections.py"))
1154             # Make sure is_file actually returns True
1155             self.assertTrue(p.is_file())
1156             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1157             expected = Path(p)
1158             black.reformat_one(
1159                 path,
1160                 fast=True,
1161                 write_back=black.WriteBack.YES,
1162                 mode=DEFAULT_MODE,
1163                 report=report,
1164             )
1165             fsts.assert_called_once()
1166             # __BLACK_STDIN_FILENAME__ should have been stripped
1167             report.done.assert_called_with(expected, black.Changed.YES)
1168
1169     def test_reformat_one_with_stdin_empty(self) -> None:
1170         output = io.StringIO()
1171         with patch("io.TextIOWrapper", lambda *args, **kwargs: output):
1172             try:
1173                 black.format_stdin_to_stdout(
1174                     fast=True,
1175                     content="",
1176                     write_back=black.WriteBack.YES,
1177                     mode=DEFAULT_MODE,
1178                 )
1179             except io.UnsupportedOperation:
1180                 pass  # StringIO does not support detach
1181             assert output.getvalue() == ""
1182
1183     def test_invalid_cli_regex(self) -> None:
1184         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1185             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1186
1187     def test_required_version_matches_version(self) -> None:
1188         self.invokeBlack(
1189             ["--required-version", black.__version__], exit_code=0, ignore_config=True
1190         )
1191
1192     def test_required_version_does_not_match_version(self) -> None:
1193         self.invokeBlack(
1194             ["--required-version", "20.99b"], exit_code=1, ignore_config=True
1195         )
1196
1197     def test_preserves_line_endings(self) -> None:
1198         with TemporaryDirectory() as workspace:
1199             test_file = Path(workspace) / "test.py"
1200             for nl in ["\n", "\r\n"]:
1201                 contents = nl.join(["def f(  ):", "    pass"])
1202                 test_file.write_bytes(contents.encode())
1203                 ff(test_file, write_back=black.WriteBack.YES)
1204                 updated_contents: bytes = test_file.read_bytes()
1205                 self.assertIn(nl.encode(), updated_contents)
1206                 if nl == "\n":
1207                     self.assertNotIn(b"\r\n", updated_contents)
1208
1209     def test_preserves_line_endings_via_stdin(self) -> None:
1210         for nl in ["\n", "\r\n"]:
1211             contents = nl.join(["def f(  ):", "    pass"])
1212             runner = BlackRunner()
1213             result = runner.invoke(
1214                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1215             )
1216             self.assertEqual(result.exit_code, 0)
1217             output = result.stdout_bytes
1218             self.assertIn(nl.encode("utf8"), output)
1219             if nl == "\n":
1220                 self.assertNotIn(b"\r\n", output)
1221
1222     def test_assert_equivalent_different_asts(self) -> None:
1223         with self.assertRaises(AssertionError):
1224             black.assert_equivalent("{}", "None")
1225
1226     def test_shhh_click(self) -> None:
1227         try:
1228             from click import _unicodefun
1229         except ModuleNotFoundError:
1230             self.skipTest("Incompatible Click version")
1231         if not hasattr(_unicodefun, "_verify_python3_env"):
1232             self.skipTest("Incompatible Click version")
1233         # First, let's see if Click is crashing with a preferred ASCII charset.
1234         with patch("locale.getpreferredencoding") as gpe:
1235             gpe.return_value = "ASCII"
1236             with self.assertRaises(RuntimeError):
1237                 _unicodefun._verify_python3_env()  # type: ignore
1238         # Now, let's silence Click...
1239         black.patch_click()
1240         # ...and confirm it's silent.
1241         with patch("locale.getpreferredencoding") as gpe:
1242             gpe.return_value = "ASCII"
1243             try:
1244                 _unicodefun._verify_python3_env()  # type: ignore
1245             except RuntimeError as re:
1246                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1247
1248     def test_root_logger_not_used_directly(self) -> None:
1249         def fail(*args: Any, **kwargs: Any) -> None:
1250             self.fail("Record created with root logger")
1251
1252         with patch.multiple(
1253             logging.root,
1254             debug=fail,
1255             info=fail,
1256             warning=fail,
1257             error=fail,
1258             critical=fail,
1259             log=fail,
1260         ):
1261             ff(THIS_DIR / "util.py")
1262
1263     def test_invalid_config_return_code(self) -> None:
1264         tmp_file = Path(black.dump_to_file())
1265         try:
1266             tmp_config = Path(black.dump_to_file())
1267             tmp_config.unlink()
1268             args = ["--config", str(tmp_config), str(tmp_file)]
1269             self.invokeBlack(args, exit_code=2, ignore_config=False)
1270         finally:
1271             tmp_file.unlink()
1272
1273     def test_parse_pyproject_toml(self) -> None:
1274         test_toml_file = THIS_DIR / "test.toml"
1275         config = black.parse_pyproject_toml(str(test_toml_file))
1276         self.assertEqual(config["verbose"], 1)
1277         self.assertEqual(config["check"], "no")
1278         self.assertEqual(config["diff"], "y")
1279         self.assertEqual(config["color"], True)
1280         self.assertEqual(config["line_length"], 79)
1281         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1282         self.assertEqual(config["exclude"], r"\.pyi?$")
1283         self.assertEqual(config["include"], r"\.py?$")
1284
1285     def test_read_pyproject_toml(self) -> None:
1286         test_toml_file = THIS_DIR / "test.toml"
1287         fake_ctx = FakeContext()
1288         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1289         config = fake_ctx.default_map
1290         self.assertEqual(config["verbose"], "1")
1291         self.assertEqual(config["check"], "no")
1292         self.assertEqual(config["diff"], "y")
1293         self.assertEqual(config["color"], "True")
1294         self.assertEqual(config["line_length"], "79")
1295         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1296         self.assertEqual(config["exclude"], r"\.pyi?$")
1297         self.assertEqual(config["include"], r"\.py?$")
1298
1299     def test_find_project_root(self) -> None:
1300         with TemporaryDirectory() as workspace:
1301             root = Path(workspace)
1302             test_dir = root / "test"
1303             test_dir.mkdir()
1304
1305             src_dir = root / "src"
1306             src_dir.mkdir()
1307
1308             root_pyproject = root / "pyproject.toml"
1309             root_pyproject.touch()
1310             src_pyproject = src_dir / "pyproject.toml"
1311             src_pyproject.touch()
1312             src_python = src_dir / "foo.py"
1313             src_python.touch()
1314
1315             self.assertEqual(
1316                 black.find_project_root((src_dir, test_dir)), root.resolve()
1317             )
1318             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
1319             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
1320
1321     @patch(
1322         "black.files.find_user_pyproject_toml",
1323         black.files.find_user_pyproject_toml.__wrapped__,
1324     )
1325     def test_find_user_pyproject_toml_linux(self) -> None:
1326         if system() == "Windows":
1327             return
1328
1329         # Test if XDG_CONFIG_HOME is checked
1330         with TemporaryDirectory() as workspace:
1331             tmp_user_config = Path(workspace) / "black"
1332             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1333                 self.assertEqual(
1334                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
1335                 )
1336
1337         # Test fallback for XDG_CONFIG_HOME
1338         with patch.dict("os.environ"):
1339             os.environ.pop("XDG_CONFIG_HOME", None)
1340             fallback_user_config = Path("~/.config").expanduser() / "black"
1341             self.assertEqual(
1342                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
1343             )
1344
1345     def test_find_user_pyproject_toml_windows(self) -> None:
1346         if system() != "Windows":
1347             return
1348
1349         user_config_path = Path.home() / ".black"
1350         self.assertEqual(
1351             black.files.find_user_pyproject_toml(), user_config_path.resolve()
1352         )
1353
1354     def test_bpo_33660_workaround(self) -> None:
1355         if system() == "Windows":
1356             return
1357
1358         # https://bugs.python.org/issue33660
1359         root = Path("/")
1360         with change_directory(root):
1361             path = Path("workspace") / "project"
1362             report = black.Report(verbose=True)
1363             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
1364             self.assertEqual(normalized_path, "workspace/project")
1365
1366     def test_newline_comment_interaction(self) -> None:
1367         source = "class A:\\\r\n# type: ignore\n pass\n"
1368         output = black.format_str(source, mode=DEFAULT_MODE)
1369         black.assert_stable(source, output, mode=DEFAULT_MODE)
1370
1371     def test_bpo_2142_workaround(self) -> None:
1372
1373         # https://bugs.python.org/issue2142
1374
1375         source, _ = read_data("missing_final_newline.py")
1376         # read_data adds a trailing newline
1377         source = source.rstrip()
1378         expected, _ = read_data("missing_final_newline.diff")
1379         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
1380         diff_header = re.compile(
1381             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
1382             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
1383         )
1384         try:
1385             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
1386             self.assertEqual(result.exit_code, 0)
1387         finally:
1388             os.unlink(tmp_file)
1389         actual = result.output
1390         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
1391         self.assertEqual(actual, expected)
1392
1393     @pytest.mark.python2
1394     def test_docstring_reformat_for_py27(self) -> None:
1395         """
1396         Check that stripping trailing whitespace from Python 2 docstrings
1397         doesn't trigger a "not equivalent to source" error
1398         """
1399         source = (
1400             b'def foo():\r\n    """Testing\r\n    Testing """\r\n    print "Foo"\r\n'
1401         )
1402         expected = 'def foo():\n    """Testing\n    Testing"""\n    print "Foo"\n'
1403
1404         result = CliRunner().invoke(
1405             black.main,
1406             ["-", "-q", "--target-version=py27"],
1407             input=BytesIO(source),
1408         )
1409
1410         self.assertEqual(result.exit_code, 0)
1411         actual = result.output
1412         self.assertFormatEqual(actual, expected)
1413
1414     @staticmethod
1415     def compare_results(
1416         result: click.testing.Result, expected_value: str, expected_exit_code: int
1417     ) -> None:
1418         """Helper method to test the value and exit code of a click Result."""
1419         assert (
1420             result.output == expected_value
1421         ), "The output did not match the expected value."
1422         assert result.exit_code == expected_exit_code, "The exit code is incorrect."
1423
1424     def test_code_option(self) -> None:
1425         """Test the code option with no changes."""
1426         code = 'print("Hello world")\n'
1427         args = ["--code", code]
1428         result = CliRunner().invoke(black.main, args)
1429
1430         self.compare_results(result, code, 0)
1431
1432     def test_code_option_changed(self) -> None:
1433         """Test the code option when changes are required."""
1434         code = "print('hello world')"
1435         formatted = black.format_str(code, mode=DEFAULT_MODE)
1436
1437         args = ["--code", code]
1438         result = CliRunner().invoke(black.main, args)
1439
1440         self.compare_results(result, formatted, 0)
1441
1442     def test_code_option_check(self) -> None:
1443         """Test the code option when check is passed."""
1444         args = ["--check", "--code", 'print("Hello world")\n']
1445         result = CliRunner().invoke(black.main, args)
1446         self.compare_results(result, "", 0)
1447
1448     def test_code_option_check_changed(self) -> None:
1449         """Test the code option when changes are required, and check is passed."""
1450         args = ["--check", "--code", "print('hello world')"]
1451         result = CliRunner().invoke(black.main, args)
1452         self.compare_results(result, "", 1)
1453
1454     def test_code_option_diff(self) -> None:
1455         """Test the code option when diff is passed."""
1456         code = "print('hello world')"
1457         formatted = black.format_str(code, mode=DEFAULT_MODE)
1458         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1459
1460         args = ["--diff", "--code", code]
1461         result = CliRunner().invoke(black.main, args)
1462
1463         # Remove time from diff
1464         output = DIFF_TIME.sub("", result.output)
1465
1466         assert output == result_diff, "The output did not match the expected value."
1467         assert result.exit_code == 0, "The exit code is incorrect."
1468
1469     def test_code_option_color_diff(self) -> None:
1470         """Test the code option when color and diff are passed."""
1471         code = "print('hello world')"
1472         formatted = black.format_str(code, mode=DEFAULT_MODE)
1473
1474         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1475         result_diff = color_diff(result_diff)
1476
1477         args = ["--diff", "--color", "--code", code]
1478         result = CliRunner().invoke(black.main, args)
1479
1480         # Remove time from diff
1481         output = DIFF_TIME.sub("", result.output)
1482
1483         assert output == result_diff, "The output did not match the expected value."
1484         assert result.exit_code == 0, "The exit code is incorrect."
1485
1486     def test_code_option_safe(self) -> None:
1487         """Test that the code option throws an error when the sanity checks fail."""
1488         # Patch black.assert_equivalent to ensure the sanity checks fail
1489         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1490             code = 'print("Hello world")'
1491             error_msg = f"{code}\nerror: cannot format <string>: \n"
1492
1493             args = ["--safe", "--code", code]
1494             result = CliRunner().invoke(black.main, args)
1495
1496             self.compare_results(result, error_msg, 123)
1497
1498     def test_code_option_fast(self) -> None:
1499         """Test that the code option ignores errors when the sanity checks fail."""
1500         # Patch black.assert_equivalent to ensure the sanity checks fail
1501         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1502             code = 'print("Hello world")'
1503             formatted = black.format_str(code, mode=DEFAULT_MODE)
1504
1505             args = ["--fast", "--code", code]
1506             result = CliRunner().invoke(black.main, args)
1507
1508             self.compare_results(result, formatted, 0)
1509
1510     def test_code_option_config(self) -> None:
1511         """
1512         Test that the code option finds the pyproject.toml in the current directory.
1513         """
1514         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1515             args = ["--code", "print"]
1516             # This is the only directory known to contain a pyproject.toml
1517             with change_directory(PROJECT_ROOT):
1518                 CliRunner().invoke(black.main, args)
1519                 pyproject_path = Path(Path.cwd(), "pyproject.toml").resolve()
1520
1521             assert (
1522                 len(parse.mock_calls) >= 1
1523             ), "Expected config parse to be called with the current directory."
1524
1525             _, call_args, _ = parse.mock_calls[0]
1526             assert (
1527                 call_args[0].lower() == str(pyproject_path).lower()
1528             ), "Incorrect config loaded."
1529
1530     def test_code_option_parent_config(self) -> None:
1531         """
1532         Test that the code option finds the pyproject.toml in the parent directory.
1533         """
1534         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1535             with change_directory(THIS_DIR):
1536                 args = ["--code", "print"]
1537                 CliRunner().invoke(black.main, args)
1538
1539                 pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
1540                 assert (
1541                     len(parse.mock_calls) >= 1
1542                 ), "Expected config parse to be called with the current directory."
1543
1544                 _, call_args, _ = parse.mock_calls[0]
1545                 assert (
1546                     call_args[0].lower() == str(pyproject_path).lower()
1547                 ), "Incorrect config loaded."
1548
1549
class TestCaching:
    """Tests for black's on-disk cache of already-formatted files."""

    def test_cache_broken_file(self) -> None:
        """A corrupt cache file reads as empty and is rebuilt by the next run."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            cache_file = get_cache_file(mode)
            cache_file.write_text("this is not a pickle")
            assert black.read_cache(mode) == {}
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            invokeBlack([str(src)])
            cache = black.read_cache(mode)
            assert str(src) in cache

    def test_cache_single_file_already_cached(self) -> None:
        """A file already recorded in the cache is not reformatted."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            black.write_cache({}, [src], mode)
            invokeBlack([str(src)])
            # Still single-quoted: the cached file was skipped.
            assert src.read_text() == "print('hello')"

    @event_loop()
    def test_cache_multiple_files(self) -> None:
        """Only the uncached file is reformatted, and both end up cached."""
        mode = DEFAULT_MODE
        # Swap the process pool for a thread pool so workers run in-process.
        with cache_dir() as workspace, patch(
            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            one = (workspace / "one.py").resolve()
            with one.open("w") as fobj:
                fobj.write("print('hello')")
            two = (workspace / "two.py").resolve()
            with two.open("w") as fobj:
                fobj.write("print('hello')")
            black.write_cache({}, [one], mode)
            invokeBlack([str(workspace)])
            with one.open("r") as fobj:
                assert fobj.read() == "print('hello')"
            with two.open("r") as fobj:
                assert fobj.read() == 'print("hello")\n'
            cache = black.read_cache(mode)
            assert str(one) in cache
            assert str(two) in cache

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    def test_no_cache_when_writeback_diff(self, color: bool) -> None:
        """``--diff`` (with or without ``--color``) neither reads nor writes cache."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            with src.open("w") as fobj:
                fobj.write("print('hello')")
            with patch("black.read_cache") as read_cache, patch(
                "black.write_cache"
            ) as write_cache:
                cmd = [str(src), "--diff"]
                if color:
                    cmd.append("--color")
                invokeBlack(cmd)
                cache_file = get_cache_file(mode)
                assert not cache_file.exists()
                write_cache.assert_not_called()
                read_cache.assert_not_called()

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    @event_loop()
    def test_output_locking_when_writeback_diff(self, color: bool) -> None:
        """Multi-file ``--diff`` must create a ``Manager`` (source of the lock)."""
        with cache_dir() as workspace:
            for tag in range(4):
                src = (workspace / f"test{tag}.py").resolve()
                with src.open("w") as fobj:
                    fobj.write("print('hello')")
            with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
                cmd = ["--diff", str(workspace)]
                if color:
                    cmd.append("--color")
                invokeBlack(cmd, exit_code=0)
                # this isn't quite doing what we want, but if it _isn't_
                # called then we cannot be using the lock it provides
                mgr.assert_called()

    def test_no_cache_when_stdin(self) -> None:
        """Formatting stdin does not create a cache file."""
        mode = DEFAULT_MODE
        with cache_dir():
            result = CliRunner().invoke(
                black.main, ["-"], input=BytesIO(b"print('hello')")
            )
            assert not result.exit_code
            cache_file = get_cache_file(mode)
            assert not cache_file.exists()

    def test_read_cache_no_cachefile(self) -> None:
        """Reading a cache that does not exist yet yields an empty dict."""
        mode = DEFAULT_MODE
        with cache_dir():
            assert black.read_cache(mode) == {}

    def test_write_cache_read_cache(self) -> None:
        """A written entry is read back with the file's current cache info."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.touch()
            black.write_cache({}, [src], mode)
            cache = black.read_cache(mode)
            assert str(src) in cache
            assert cache[str(src)] == black.get_cache_info(src)

    def test_filter_cached(self) -> None:
        """``filter_cached`` splits files into still-to-do and unchanged-cached."""
        with TemporaryDirectory() as workspace:
            path = Path(workspace)
            uncached = (path / "uncached").resolve()
            cached = (path / "cached").resolve()
            cached_but_changed = (path / "changed").resolve()
            uncached.touch()
            cached.touch()
            cached_but_changed.touch()
            cache = {
                str(cached): black.get_cache_info(cached),
                # Stale info that cannot match the real file.
                str(cached_but_changed): (0.0, 0),
            }
            todo, done = black.filter_cached(
                cache, {uncached, cached, cached_but_changed}
            )
            assert todo == {uncached, cached_but_changed}
            assert done == {cached}

    def test_write_cache_creates_directory_if_needed(self) -> None:
        """``write_cache`` creates the cache directory when it is missing."""
        mode = DEFAULT_MODE
        with cache_dir(exists=False) as workspace:
            assert not workspace.exists()
            black.write_cache({}, [], mode)
            assert workspace.exists()

    @event_loop()
    def test_failed_formatting_does_not_get_cached(self) -> None:
        """Files that fail to format stay out of the cache; clean ones go in."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace, patch(
            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            failing = (workspace / "failing.py").resolve()
            with failing.open("w") as fobj:
                fobj.write("not actually python")
            clean = (workspace / "clean.py").resolve()
            with clean.open("w") as fobj:
                fobj.write('print("hello")\n')
            invokeBlack([str(workspace)], exit_code=123)
            cache = black.read_cache(mode)
            assert str(failing) not in cache
            assert str(clean) in cache

    def test_write_cache_write_fail(self) -> None:
        """``write_cache`` does not raise when ``Path.open`` raises ``OSError``."""
        mode = DEFAULT_MODE
        with cache_dir(), patch.object(Path, "open") as mock:
            mock.side_effect = OSError
            black.write_cache({}, [], mode)

    def test_read_cache_line_lengths(self) -> None:
        """A cache written for one mode is not seen under another line length."""
        mode = DEFAULT_MODE
        short_mode = replace(DEFAULT_MODE, line_length=1)
        with cache_dir() as workspace:
            path = (workspace / "file.py").resolve()
            path.touch()
            black.write_cache({}, [path], mode)
            one = black.read_cache(mode)
            assert str(path) in one
            two = black.read_cache(short_mode)
            assert str(path) not in two
1715
1716
def assert_collected_sources(
    src: Sequence[Union[str, Path]],
    expected: Sequence[Union[str, Path]],
    *,
    exclude: Optional[str] = None,
    include: Optional[str] = None,
    extend_exclude: Optional[str] = None,
    force_exclude: Optional[str] = None,
    stdin_filename: Optional[str] = None,
) -> None:
    """Assert that ``black.get_sources`` collects exactly ``expected`` from ``src``.

    Every pattern argument is a regex string compiled with
    ``re_compile_maybe_verbose``.  ``None`` means "not given", passed through
    as ``None``, except for ``include``, which then falls back to
    ``DEFAULT_INCLUDE``.  The comparison ignores ordering.
    """
    gs_src = tuple(str(Path(s)) for s in src)
    gs_expected = [Path(s) for s in expected]
    gs_exclude = None if exclude is None else compile_pattern(exclude)
    gs_include = DEFAULT_INCLUDE if include is None else compile_pattern(include)
    gs_extend_exclude = (
        None if extend_exclude is None else compile_pattern(extend_exclude)
    )
    gs_force_exclude = None if force_exclude is None else compile_pattern(force_exclude)
    collected = black.get_sources(
        ctx=FakeContext(),
        src=gs_src,
        quiet=False,
        verbose=False,
        include=gs_include,
        exclude=gs_exclude,
        extend_exclude=gs_extend_exclude,
        force_exclude=gs_force_exclude,
        report=black.Report(),
        stdin_filename=stdin_filename,
    )
    # sorted() accepts any iterable; no intermediate list is needed.
    assert sorted(collected) == sorted(gs_expected)
1748
1749
class TestFileCollection:
    """Tests for source discovery: include/exclude regexes, .gitignore, stdin."""

    def test_include_exclude(self) -> None:
        """``include``/``exclude`` regexes filter recursively discovered files."""
        path = THIS_DIR / "data" / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        assert_collected_sources(
            src,
            expected,
            include=r"\.pyi?$",
            exclude=r"/exclude/|/\.definitely_exclude/",
        )

    def test_gitignore_used_as_default(self) -> None:
        """Without an explicit ``exclude``, the .gitignore rules are applied."""
        base = DATA_DIR / "include_exclude_tests"
        expected = [
            base / "b/.definitely_exclude/a.py",
            base / "b/.definitely_exclude/a.pyi",
        ]
        src = [base / "b/"]
        assert_collected_sources(src, expected, extend_exclude=r"/exclude/")

    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
    def test_exclude_for_issue_1572(self) -> None:
        # Exclude shouldn't touch files that were explicitly given to Black through the
        # CLI. Exclude is supposed to only apply to the recursive discovery of files.
        # https://github.com/psf/black/issues/1572
        path = DATA_DIR / "include_exclude_tests"
        src = [path / "b/exclude/a.py"]
        expected = [path / "b/exclude/a.py"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    def test_gitignore_exclude(self) -> None:
        """A PathSpec passed to ``gen_python_files`` excludes matching paths."""
        path = THIS_DIR / "data" / "include_exclude_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        report = black.Report()
        gitignore = PathSpec.from_lines(
            "gitwildmatch", ["exclude/", ".definitely_exclude"]
        )
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        this_abs = THIS_DIR.resolve()
        sources: List[Path] = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                gitignore,
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_nested_gitignore(self) -> None:
        """.gitignore files in subdirectories are honored during discovery."""
        path = THIS_DIR / "data" / "nested_gitignore_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        root_gitignore = black.files.get_gitignore(path)
        report = black.Report()
        expected: List[Path] = [
            path / "x.py",
            path / "root/b.py",
            path / "root/c.py",
            path / "root/child/c.py",
        ]
        this_abs = THIS_DIR.resolve()
        sources = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                root_gitignore,
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_invalid_gitignore(self) -> None:
        """An unparsable top-level .gitignore exits 1 and names the file."""
        path = THIS_DIR / "data" / "invalid_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_invalid_nested_gitignore(self) -> None:
        """An unparsable nested .gitignore exits 1 and names the file."""
        path = THIS_DIR / "data" / "invalid_nested_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / "a" / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_empty_include(self) -> None:
        """An empty ``include`` pattern collects every file, whatever its suffix."""
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.pie",
            path / "b/exclude/a.py",
            path / "b/exclude/a.pyi",
            path / "b/dont_exclude/a.pie",
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
            path / "b/.definitely_exclude/a.pie",
            path / "b/.definitely_exclude/a.py",
            path / "b/.definitely_exclude/a.pyi",
            path / ".gitignore",
            path / "pyproject.toml",
        ]
        # Setting exclude explicitly to an empty string to block .gitignore usage.
        assert_collected_sources(src, expected, include="", exclude="")

    def test_extend_exclude(self) -> None:
        """``extend_exclude`` is applied on top of ``exclude``."""
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.py",
            path / "b/dont_exclude/a.py",
        ]
        assert_collected_sources(
            src, expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude"
        )

    def test_symlink_out_of_root_directory(self) -> None:
        """A symlink resolving outside root is tolerated; a non-symlink raises."""
        path = MagicMock()
        root = THIS_DIR.resolve()
        child = MagicMock()
        include = re.compile(black.DEFAULT_INCLUDES)
        exclude = re.compile(black.DEFAULT_EXCLUDES)
        report = black.Report()
        gitignore = PathSpec.from_lines("gitwildmatch", [])

        def collect() -> List[Path]:
            # Calls path.iterdir() afresh each time; the call-count
            # assertions below depend on that.
            return list(
                black.gen_python_files(
                    path.iterdir(),
                    root,
                    include,
                    exclude,
                    None,
                    None,
                    report,
                    gitignore,
                    verbose=False,
                    quiet=False,
                )
            )

        # `child` should behave like a symlink whose resolved path is clearly
        # outside of the `root` directory.
        path.iterdir.return_value = [child]
        child.resolve.return_value = Path("/a/b/c")
        child.as_posix.return_value = "/a/b/c"
        child.is_symlink.return_value = True
        try:
            collect()
        except ValueError as ve:
            pytest.fail(f"`gen_python_files()` failed: {ve}")
        path.iterdir.assert_called_once()
        child.resolve.assert_called_once()
        child.is_symlink.assert_called_once()
        # `child` should behave like a strange file whose resolved path is
        # clearly outside of the `root` directory.
        child.is_symlink.return_value = False
        with pytest.raises(ValueError):
            collect()
        path.iterdir.assert_called()
        assert path.iterdir.call_count == 2
        child.resolve.assert_called()
        assert child.resolve.call_count == 2
        child.is_symlink.assert_called()
        assert child.is_symlink.call_count == 2

    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
    def test_get_sources_with_stdin(self) -> None:
        """A bare ``-`` is collected as-is despite include/exclude patterns."""
        src = ["-"]
        expected = ["-"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
    def test_get_sources_with_stdin_filename(self) -> None:
        """``stdin_filename`` is reported behind the stdin-filename marker."""
        src = ["-"]
        stdin_filename = str(THIS_DIR / "data/collections.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
    def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
        # Exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        path = DATA_DIR / "include_exclude_tests"
        src = ["-"]
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
    def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
        # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        src = ["-"]
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            extend_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
    def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
        # Force exclude should exclude the file when passing it through
        # stdin_filename
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        assert_collected_sources(
            src=["-"],
            expected=[],
            force_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )
2018
2019
# All lines of black's package __init__ (newlines kept), used by `tracefunc`
# below to show the source line of each traced call.
with open(black.__file__, encoding="utf-8") as _bf:
    black_source_lines = list(_bf)
2022
2023
def tracefunc(
    frame: types.FrameType, event: str, arg: Any
) -> Callable[[types.FrameType, str, Any], Any]:
    """Show function calls from ``black/__init__.py`` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.  For each
    "call" event in ``black/__init__.py``, print the line number and the
    first non-decorator source line (from ``black_source_lines``), indented
    by the current stack depth.  Always returns itself so tracing continues.
    """
    if event != "call":
        return tracefunc

    filename = frame.f_code.co_filename
    # Filter by file first: the line-number lookup below is only meaningful
    # for frames from black/__init__.py (other files can have line numbers
    # past the end of black_source_lines), and inspect.stack() is costly.
    # Separators are normalized so the check also matches on Windows.
    if "black/__init__.py" not in filename.replace(os.sep, "/"):
        return tracefunc

    # NOTE(review): 19 presumably approximates the number of test-runner
    # frames below the traced code; a negative result just yields no indent.
    stack = len(inspect.stack()) - 19
    stack *= 2
    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1
    funcname = black_source_lines[func_sig_lineno].strip()
    # Skip decorator lines to reach the line that follows them.
    while funcname.startswith("@"):
        func_sig_lineno += 1
        funcname = black_source_lines[func_sig_lineno].strip()
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc