]> git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Fix changelog entries in the wrong release (#2825)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import inspect
5 import io
6 import logging
7 import multiprocessing
8 import os
9 import sys
10 import types
11 import unittest
12 from concurrent.futures import ThreadPoolExecutor
13 from contextlib import contextmanager
14 from dataclasses import replace
15 from io import BytesIO
16 from pathlib import Path
17 from platform import system
18 from tempfile import TemporaryDirectory
19 from typing import (
20     Any,
21     Callable,
22     Dict,
23     Iterator,
24     List,
25     Optional,
26     Sequence,
27     TypeVar,
28     Union,
29 )
30 from unittest.mock import MagicMock, patch
31
32 import click
33 import pytest
34 import re
35 from click import unstyle
36 from click.testing import CliRunner
37 from pathspec import PathSpec
38
39 import black
40 import black.files
41 from black import Feature, TargetVersion
42 from black import re_compile_maybe_verbose as compile_pattern
43 from black.cache import get_cache_dir, get_cache_file
44 from black.debug import DebugVisitor
45 from black.output import color_diff, diff
46 from black.report import Report
47
48 # Import other test classes
49 from tests.util import (
50     DATA_DIR,
51     DEFAULT_MODE,
52     DETERMINISTIC_HEADER,
53     PROJECT_ROOT,
54     PY36_VERSIONS,
55     THIS_DIR,
56     BlackBaseTestCase,
57     assert_format,
58     change_directory,
59     dump_to_stderr,
60     ff,
61     fs,
62     read_data,
63 )
64
THIS_FILE = Path(__file__)
# One "--target-version=<name>" CLI argument per version in PY36_VERSIONS.
PY36_ARGS = ["--target-version=" + v.name.lower() for v in PY36_VERSIONS]
# Black's default exclude/include patterns, compiled (in that order) with
# black.re_compile_maybe_verbose.
DEFAULT_EXCLUDE, DEFAULT_INCLUDE = (
    black.re_compile_maybe_verbose(pattern)
    for pattern in (black.const.DEFAULT_EXCLUDES, black.const.DEFAULT_INCLUDES)
)
# Generic type variables (presumably used by helpers later in this file).
T = TypeVar("T")
R = TypeVar("R")

# Match the time output in a diff, but nothing else
DIFF_TIME = re.compile(r"\t[\d\-:+\. ]+")
74
75
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point Black's cache at a throwaway temporary directory.

    Yields the path that is patched in as ``black.cache.CACHE_DIR``. When
    *exists* is false, the yielded path is a ``new`` subdirectory of the
    temporary directory. This function does not create it, so code under test
    sees a cache directory that is missing.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
84
85
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a fresh asyncio event loop for the duration of the ``with`` block.

    The loop is created from the current event loop policy and made the
    current loop. On exit it is closed and then unset, so the closed loop is
    not left behind as the current event loop for later code.
    """
    policy = asyncio.get_event_loop_policy()
    loop = policy.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        yield
    finally:
        loop.close()
        # Mirror asyncio.run(): don't leave a closed loop registered as current.
        asyncio.set_event_loop(None)
96
97
class FakeContext(click.Context):
    """Stand-in click Context for calling functions that require one.

    ``click.Context.__init__`` is deliberately not called; only ``obj`` and
    ``default_map`` are populated.
    """

    def __init__(self) -> None:
        # Placeholder root: most tests never look at it.
        self.obj: Dict[str, Any] = dict(root=PROJECT_ROOT)
        self.default_map: Dict[str, Any] = dict()
105
106
class FakeParameter(click.Parameter):
    """Stand-in click Parameter for calling functions that require one."""

    def __init__(self) -> None:
        """Deliberately skip ``click.Parameter.__init__``; no attributes are set."""
112
113
class BlackRunner(CliRunner):
    """Make sure STDOUT and STDERR are kept separate when testing Black via its CLI."""

    def __init__(self) -> None:
        # Click < 8.2 mixes stderr into stdout unless ``mix_stderr=False`` is
        # given. Click 8.2 removed that parameter (passing it raises TypeError)
        # and always captures stderr separately. Only pass it when accepted.
        # NOTE: on Click >= 8.2, ``Result.output`` interleaves stderr as well;
        # ``Result.stdout`` is stdout only.
        if "mix_stderr" in inspect.signature(CliRunner.__init__).parameters:
            super().__init__(mix_stderr=False)
        else:
            super().__init__()
119
120
def invokeBlack(
    args: List[str], exit_code: int = 0, ignore_config: bool = True
) -> None:
    """Run Black's CLI with *args* and assert it exits with *exit_code*.

    When *ignore_config* is true, ``--verbose`` and ``--config`` pointing at
    ``empty.toml`` in THIS_DIR are prepended to *args*. Exceptions raised by
    Black are not caught by the runner. On an unexpected exit code, the
    assertion message contains the args, the captured stdout and stderr, and
    the result's exception.
    """
    if ignore_config:
        args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
    result = BlackRunner().invoke(black.main, args, catch_exceptions=False)
    out_bytes, err_bytes = result.stdout_bytes, result.stderr_bytes
    assert out_bytes is not None
    assert err_bytes is not None
    details = "\n".join(
        [
            f"Failed with args: {args}",
            f"stdout: {out_bytes.decode()!r}",
            f"stderr: {err_bytes.decode()!r}",
            f"exception: {result.exception}",
        ]
    )
    assert result.exit_code == exit_code, details
137
138
139 class BlackTestCase(BlackBaseTestCase):
    # Expose the module-level helper on the class. staticmethod keeps it
    # unbound, so ``self.invokeBlack(args, ...)`` receives ``args`` first.
    invokeBlack = staticmethod(invokeBlack)
141
142     def test_empty_ff(self) -> None:
143         expected = ""
144         tmp_file = Path(black.dump_to_file())
145         try:
146             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
147             with open(tmp_file, encoding="utf8") as f:
148                 actual = f.read()
149         finally:
150             os.unlink(tmp_file)
151         self.assertFormatEqual(expected, actual)
152
153     def test_experimental_string_processing_warns(self) -> None:
154         self.assertWarns(
155             black.mode.Deprecated, black.Mode, experimental_string_processing=True
156         )
157
158     def test_piping(self) -> None:
159         source, expected = read_data("src/black/__init__", data=False)
160         result = BlackRunner().invoke(
161             black.main,
162             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
163             input=BytesIO(source.encode("utf8")),
164         )
165         self.assertEqual(result.exit_code, 0)
166         self.assertFormatEqual(expected, result.output)
167         if source != result.output:
168             black.assert_equivalent(source, result.output)
169             black.assert_stable(source, result.output, DEFAULT_MODE)
170
171     def test_piping_diff(self) -> None:
172         diff_header = re.compile(
173             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
174             r"\+\d\d\d\d"
175         )
176         source, _ = read_data("expression.py")
177         expected, _ = read_data("expression.diff")
178         config = THIS_DIR / "data" / "empty_pyproject.toml"
179         args = [
180             "-",
181             "--fast",
182             f"--line-length={black.DEFAULT_LINE_LENGTH}",
183             "--diff",
184             f"--config={config}",
185         ]
186         result = BlackRunner().invoke(
187             black.main, args, input=BytesIO(source.encode("utf8"))
188         )
189         self.assertEqual(result.exit_code, 0)
190         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
191         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
192         self.assertEqual(expected, actual)
193
194     def test_piping_diff_with_color(self) -> None:
195         source, _ = read_data("expression.py")
196         config = THIS_DIR / "data" / "empty_pyproject.toml"
197         args = [
198             "-",
199             "--fast",
200             f"--line-length={black.DEFAULT_LINE_LENGTH}",
201             "--diff",
202             "--color",
203             f"--config={config}",
204         ]
205         result = BlackRunner().invoke(
206             black.main, args, input=BytesIO(source.encode("utf8"))
207         )
208         actual = result.output
209         # Again, the contents are checked in a different test, so only look for colors.
210         self.assertIn("\033[1m", actual)
211         self.assertIn("\033[36m", actual)
212         self.assertIn("\033[32m", actual)
213         self.assertIn("\033[31m", actual)
214         self.assertIn("\033[0m", actual)
215
216     @patch("black.dump_to_file", dump_to_stderr)
217     def _test_wip(self) -> None:
218         source, expected = read_data("wip")
219         sys.settrace(tracefunc)
220         mode = replace(
221             DEFAULT_MODE,
222             experimental_string_processing=False,
223             target_versions={black.TargetVersion.PY38},
224         )
225         actual = fs(source, mode=mode)
226         sys.settrace(None)
227         self.assertFormatEqual(expected, actual)
228         black.assert_equivalent(source, actual)
229         black.assert_stable(source, actual, black.FileMode())
230
231     def test_pep_572_version_detection(self) -> None:
232         source, _ = read_data("pep_572")
233         root = black.lib2to3_parse(source)
234         features = black.get_features_used(root)
235         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
236         versions = black.detect_target_versions(root)
237         self.assertIn(black.TargetVersion.PY38, versions)
238
239     def test_expression_ff(self) -> None:
240         source, expected = read_data("expression")
241         tmp_file = Path(black.dump_to_file(source))
242         try:
243             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
244             with open(tmp_file, encoding="utf8") as f:
245                 actual = f.read()
246         finally:
247             os.unlink(tmp_file)
248         self.assertFormatEqual(expected, actual)
249         with patch("black.dump_to_file", dump_to_stderr):
250             black.assert_equivalent(source, actual)
251             black.assert_stable(source, actual, DEFAULT_MODE)
252
253     def test_expression_diff(self) -> None:
254         source, _ = read_data("expression.py")
255         config = THIS_DIR / "data" / "empty_pyproject.toml"
256         expected, _ = read_data("expression.diff")
257         tmp_file = Path(black.dump_to_file(source))
258         diff_header = re.compile(
259             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
260             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
261         )
262         try:
263             result = BlackRunner().invoke(
264                 black.main, ["--diff", str(tmp_file), f"--config={config}"]
265             )
266             self.assertEqual(result.exit_code, 0)
267         finally:
268             os.unlink(tmp_file)
269         actual = result.output
270         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
271         if expected != actual:
272             dump = black.dump_to_file(actual)
273             msg = (
274                 "Expected diff isn't equal to the actual. If you made changes to"
275                 " expression.py and this is an anticipated difference, overwrite"
276                 f" tests/data/expression.diff with {dump}"
277             )
278             self.assertEqual(expected, actual, msg)
279
280     def test_expression_diff_with_color(self) -> None:
281         source, _ = read_data("expression.py")
282         config = THIS_DIR / "data" / "empty_pyproject.toml"
283         expected, _ = read_data("expression.diff")
284         tmp_file = Path(black.dump_to_file(source))
285         try:
286             result = BlackRunner().invoke(
287                 black.main, ["--diff", "--color", str(tmp_file), f"--config={config}"]
288             )
289         finally:
290             os.unlink(tmp_file)
291         actual = result.output
292         # We check the contents of the diff in `test_expression_diff`. All
293         # we need to check here is that color codes exist in the result.
294         self.assertIn("\033[1m", actual)
295         self.assertIn("\033[36m", actual)
296         self.assertIn("\033[32m", actual)
297         self.assertIn("\033[31m", actual)
298         self.assertIn("\033[0m", actual)
299
300     def test_detect_pos_only_arguments(self) -> None:
301         source, _ = read_data("pep_570")
302         root = black.lib2to3_parse(source)
303         features = black.get_features_used(root)
304         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
305         versions = black.detect_target_versions(root)
306         self.assertIn(black.TargetVersion.PY38, versions)
307
308     @patch("black.dump_to_file", dump_to_stderr)
309     def test_string_quotes(self) -> None:
310         source, expected = read_data("string_quotes")
311         mode = black.Mode(preview=True)
312         assert_format(source, expected, mode)
313         mode = replace(mode, string_normalization=False)
314         not_normalized = fs(source, mode=mode)
315         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
316         black.assert_equivalent(source, not_normalized)
317         black.assert_stable(source, not_normalized, mode=mode)
318
319     def test_skip_magic_trailing_comma(self) -> None:
320         source, _ = read_data("expression.py")
321         expected, _ = read_data("expression_skip_magic_trailing_comma.diff")
322         tmp_file = Path(black.dump_to_file(source))
323         diff_header = re.compile(
324             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
325             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
326         )
327         try:
328             result = BlackRunner().invoke(black.main, ["-C", "--diff", str(tmp_file)])
329             self.assertEqual(result.exit_code, 0)
330         finally:
331             os.unlink(tmp_file)
332         actual = result.output
333         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
334         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
335         if expected != actual:
336             dump = black.dump_to_file(actual)
337             msg = (
338                 "Expected diff isn't equal to the actual. If you made changes to"
339                 " expression.py and this is an anticipated difference, overwrite"
340                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
341             )
342             self.assertEqual(expected, actual, msg)
343
344     @patch("black.dump_to_file", dump_to_stderr)
345     def test_async_as_identifier(self) -> None:
346         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
347         source, expected = read_data("async_as_identifier")
348         actual = fs(source)
349         self.assertFormatEqual(expected, actual)
350         major, minor = sys.version_info[:2]
351         if major < 3 or (major <= 3 and minor < 7):
352             black.assert_equivalent(source, actual)
353         black.assert_stable(source, actual, DEFAULT_MODE)
354         # ensure black can parse this when the target is 3.6
355         self.invokeBlack([str(source_path), "--target-version", "py36"])
356         # but not on 3.7, because async/await is no longer an identifier
357         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
358
359     @patch("black.dump_to_file", dump_to_stderr)
360     def test_python37(self) -> None:
361         source_path = (THIS_DIR / "data" / "python37.py").resolve()
362         source, expected = read_data("python37")
363         actual = fs(source)
364         self.assertFormatEqual(expected, actual)
365         major, minor = sys.version_info[:2]
366         if major > 3 or (major == 3 and minor >= 7):
367             black.assert_equivalent(source, actual)
368         black.assert_stable(source, actual, DEFAULT_MODE)
369         # ensure black can parse this when the target is 3.7
370         self.invokeBlack([str(source_path), "--target-version", "py37"])
371         # but not on 3.6, because we use async as a reserved keyword
372         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
373
374     def test_tab_comment_indentation(self) -> None:
375         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
376         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
377         self.assertFormatEqual(contents_spc, fs(contents_spc))
378         self.assertFormatEqual(contents_spc, fs(contents_tab))
379
380         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
381         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
382         self.assertFormatEqual(contents_spc, fs(contents_spc))
383         self.assertFormatEqual(contents_spc, fs(contents_tab))
384
385         # mixed tabs and spaces (valid Python 2 code)
386         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
387         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
388         self.assertFormatEqual(contents_spc, fs(contents_spc))
389         self.assertFormatEqual(contents_spc, fs(contents_tab))
390
391         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
392         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
393         self.assertFormatEqual(contents_spc, fs(contents_spc))
394         self.assertFormatEqual(contents_spc, fs(contents_tab))
395
    def test_report_verbose(self) -> None:
        # Drive a verbose Report through a fixed sequence of done / failed /
        # path_ignored events. After each event, check the captured output
        # lines, the summary string and the return code. Every step depends on
        # the state left by the previous ones, so the order matters.
        report = Report(verbose=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        # Capture messages sent to black.output._out / _err instead of printing.
        with patch("black.output._out", out), patch("black.output._err", err):
            # Verbose mode emits a line for every outcome, including unchanged files.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # Cached files count as "left unchanged".
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode a reformatted file turns the return code into 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Any failure makes the return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Report counts events, not unique paths: f3 is counted again here.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are printed in verbose mode but not counted in the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff switch the summary to "would be ..." wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
497
    def test_report_quiet(self) -> None:
        # Same event sequence as the verbose test, but with a quiet Report:
        # nothing reaches _out, errors still reach _err, and the summary and
        # return codes are unchanged. Steps are order-dependent.
        report = Report(quiet=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        # Capture messages sent to black.output._out / _err instead of printing.
        with patch("black.output._out", out), patch("black.output._err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode a reformatted file turns the return code into 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Errors are still reported in quiet mode.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff switch the summary to "would be ..." wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
591
    def test_report_normal(self) -> None:
        # Same event sequence with a default Report: only "reformatted ..."
        # lines go to _out and errors go to _err. Unchanged, cached and ignored
        # files print nothing. Steps are order-dependent.
        report = black.Report()
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        # Capture messages sent to black.output._out / _err instead of printing.
        with patch("black.output._out", out), patch("black.output._err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file prints nothing: the last output line is still f2's.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode a reformatted file turns the return code into 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff switch the summary to "would be ..." wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
688
689     def test_lib2to3_parse(self) -> None:
690         with self.assertRaises(black.InvalidInput):
691             black.lib2to3_parse("invalid syntax")
692
693         straddling = "x + y"
694         black.lib2to3_parse(straddling)
695         black.lib2to3_parse(straddling, {TargetVersion.PY36})
696
697         py2_only = "print x"
698         with self.assertRaises(black.InvalidInput):
699             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
700
701         py3_only = "exec(x, end=y)"
702         black.lib2to3_parse(py3_only)
703         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
704
705     def test_get_features_used_decorator(self) -> None:
706         # Test the feature detection of new decorator syntax
707         # since this makes some test cases of test_get_features_used()
708         # fails if it fails, this is tested first so that a useful case
709         # is identified
710         simples, relaxed = read_data("decorators")
711         # skip explanation comments at the top of the file
712         for simple_test in simples.split("##")[1:]:
713             node = black.lib2to3_parse(simple_test)
714             decorator = str(node.children[0].children[0]).strip()
715             self.assertNotIn(
716                 Feature.RELAXED_DECORATORS,
717                 black.get_features_used(node),
718                 msg=(
719                     f"decorator '{decorator}' follows python<=3.8 syntax"
720                     "but is detected as 3.9+"
721                     # f"The full node is\n{node!r}"
722                 ),
723             )
724         # skip the '# output' comment at the top of the output part
725         for relaxed_test in relaxed.split("##")[1:]:
726             node = black.lib2to3_parse(relaxed_test)
727             decorator = str(node.children[0].children[0]).strip()
728             self.assertIn(
729                 Feature.RELAXED_DECORATORS,
730                 black.get_features_used(node),
731                 msg=(
732                     f"decorator '{decorator}' uses python3.9+ syntax"
733                     "but is detected as python<=3.8"
734                     # f"The full node is\n{node!r}"
735                 ),
736             )
737
738     def test_get_features_used(self) -> None:
739         node = black.lib2to3_parse("def f(*, arg): ...\n")
740         self.assertEqual(black.get_features_used(node), set())
741         node = black.lib2to3_parse("def f(*, arg,): ...\n")
742         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
743         node = black.lib2to3_parse("f(*arg,)\n")
744         self.assertEqual(
745             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
746         )
747         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
748         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
749         node = black.lib2to3_parse("123_456\n")
750         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
751         node = black.lib2to3_parse("123456\n")
752         self.assertEqual(black.get_features_used(node), set())
753         source, expected = read_data("function")
754         node = black.lib2to3_parse(source)
755         expected_features = {
756             Feature.TRAILING_COMMA_IN_CALL,
757             Feature.TRAILING_COMMA_IN_DEF,
758             Feature.F_STRINGS,
759         }
760         self.assertEqual(black.get_features_used(node), expected_features)
761         node = black.lib2to3_parse(expected)
762         self.assertEqual(black.get_features_used(node), expected_features)
763         source, expected = read_data("expression")
764         node = black.lib2to3_parse(source)
765         self.assertEqual(black.get_features_used(node), set())
766         node = black.lib2to3_parse(expected)
767         self.assertEqual(black.get_features_used(node), set())
768         node = black.lib2to3_parse("lambda a, /, b: ...")
769         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
770         node = black.lib2to3_parse("def fn(a, /, b): ...")
771         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
772         node = black.lib2to3_parse("def fn(): yield a, b")
773         self.assertEqual(black.get_features_used(node), set())
774         node = black.lib2to3_parse("def fn(): return a, b")
775         self.assertEqual(black.get_features_used(node), set())
776         node = black.lib2to3_parse("def fn(): yield *b, c")
777         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
778         node = black.lib2to3_parse("def fn(): return a, *b, c")
779         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
780         node = black.lib2to3_parse("x = a, *b, c")
781         self.assertEqual(black.get_features_used(node), set())
782         node = black.lib2to3_parse("x: Any = regular")
783         self.assertEqual(black.get_features_used(node), set())
784         node = black.lib2to3_parse("x: Any = (regular, regular)")
785         self.assertEqual(black.get_features_used(node), set())
786         node = black.lib2to3_parse("x: Any = Complex(Type(1))[something]")
787         self.assertEqual(black.get_features_used(node), set())
788         node = black.lib2to3_parse("x: Tuple[int, ...] = a, b, c")
789         self.assertEqual(
790             black.get_features_used(node), {Feature.ANN_ASSIGN_EXTENDED_RHS}
791         )
792
793     def test_get_features_used_for_future_flags(self) -> None:
794         for src, features in [
795             ("from __future__ import annotations", {Feature.FUTURE_ANNOTATIONS}),
796             (
797                 "from __future__ import (other, annotations)",
798                 {Feature.FUTURE_ANNOTATIONS},
799             ),
800             ("a = 1 + 2\nfrom something import annotations", set()),
801             ("from __future__ import x, y", set()),
802         ]:
803             with self.subTest(src=src, features=features):
804                 node = black.lib2to3_parse(src)
805                 future_imports = black.get_future_imports(node)
806                 self.assertEqual(
807                     black.get_features_used(node, future_imports=future_imports),
808                     features,
809                 )
810
811     def test_get_future_imports(self) -> None:
812         node = black.lib2to3_parse("\n")
813         self.assertEqual(set(), black.get_future_imports(node))
814         node = black.lib2to3_parse("from __future__ import black\n")
815         self.assertEqual({"black"}, black.get_future_imports(node))
816         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
817         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
818         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
819         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
820         node = black.lib2to3_parse(
821             "from __future__ import multiple\nfrom __future__ import imports\n"
822         )
823         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
824         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
825         self.assertEqual({"black"}, black.get_future_imports(node))
826         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
827         self.assertEqual({"black"}, black.get_future_imports(node))
828         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
829         self.assertEqual(set(), black.get_future_imports(node))
830         node = black.lib2to3_parse("from some.module import black\n")
831         self.assertEqual(set(), black.get_future_imports(node))
832         node = black.lib2to3_parse(
833             "from __future__ import unicode_literals as _unicode_literals"
834         )
835         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
836         node = black.lib2to3_parse(
837             "from __future__ import unicode_literals as _lol, print"
838         )
839         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
840
841     @pytest.mark.incompatible_with_mypyc
842     def test_debug_visitor(self) -> None:
843         source, _ = read_data("debug_visitor.py")
844         expected, _ = read_data("debug_visitor.out")
845         out_lines = []
846         err_lines = []
847
848         def out(msg: str, **kwargs: Any) -> None:
849             out_lines.append(msg)
850
851         def err(msg: str, **kwargs: Any) -> None:
852             err_lines.append(msg)
853
854         with patch("black.debug.out", out):
855             DebugVisitor.show(source)
856         actual = "\n".join(out_lines) + "\n"
857         log_name = ""
858         if expected != actual:
859             log_name = black.dump_to_file(*out_lines)
860         self.assertEqual(
861             expected,
862             actual,
863             f"AST print out is different. Actual version dumped to {log_name}",
864         )
865
866     def test_format_file_contents(self) -> None:
867         empty = ""
868         mode = DEFAULT_MODE
869         with self.assertRaises(black.NothingChanged):
870             black.format_file_contents(empty, mode=mode, fast=False)
871         just_nl = "\n"
872         with self.assertRaises(black.NothingChanged):
873             black.format_file_contents(just_nl, mode=mode, fast=False)
874         same = "j = [1, 2, 3]\n"
875         with self.assertRaises(black.NothingChanged):
876             black.format_file_contents(same, mode=mode, fast=False)
877         different = "j = [1,2,3]"
878         expected = same
879         actual = black.format_file_contents(different, mode=mode, fast=False)
880         self.assertEqual(expected, actual)
881         invalid = "return if you can"
882         with self.assertRaises(black.InvalidInput) as e:
883             black.format_file_contents(invalid, mode=mode, fast=False)
884         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
885
886     def test_endmarker(self) -> None:
887         n = black.lib2to3_parse("\n")
888         self.assertEqual(n.type, black.syms.file_input)
889         self.assertEqual(len(n.children), 1)
890         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
891
892     @pytest.mark.incompatible_with_mypyc
893     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
894     def test_assertFormatEqual(self) -> None:
895         out_lines = []
896         err_lines = []
897
898         def out(msg: str, **kwargs: Any) -> None:
899             out_lines.append(msg)
900
901         def err(msg: str, **kwargs: Any) -> None:
902             err_lines.append(msg)
903
904         with patch("black.output._out", out), patch("black.output._err", err):
905             with self.assertRaises(AssertionError):
906                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
907
908         out_str = "".join(out_lines)
909         self.assertIn("Expected tree:", out_str)
910         self.assertIn("Actual tree:", out_str)
911         self.assertEqual("".join(err_lines), "")
912
913     @event_loop()
914     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
915     def test_works_in_mono_process_only_environment(self) -> None:
916         with cache_dir() as workspace:
917             for f in [
918                 (workspace / "one.py").resolve(),
919                 (workspace / "two.py").resolve(),
920             ]:
921                 f.write_text('print("hello")\n')
922             self.invokeBlack([str(workspace)])
923
924     @event_loop()
925     def test_check_diff_use_together(self) -> None:
926         with cache_dir():
927             # Files which will be reformatted.
928             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
929             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
930             # Files which will not be reformatted.
931             src2 = (THIS_DIR / "data" / "composition.py").resolve()
932             self.invokeBlack([str(src2), "--diff", "--check"])
933             # Multi file command.
934             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
935
936     def test_no_src_fails(self) -> None:
937         with cache_dir():
938             self.invokeBlack([], exit_code=1)
939
940     def test_src_and_code_fails(self) -> None:
941         with cache_dir():
942             self.invokeBlack([".", "-c", "0"], exit_code=1)
943
944     def test_broken_symlink(self) -> None:
945         with cache_dir() as workspace:
946             symlink = workspace / "broken_link.py"
947             try:
948                 symlink.symlink_to("nonexistent.py")
949             except (OSError, NotImplementedError) as e:
950                 self.skipTest(f"Can't create symlinks: {e}")
951             self.invokeBlack([str(workspace.resolve())])
952
953     def test_single_file_force_pyi(self) -> None:
954         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
955         contents, expected = read_data("force_pyi")
956         with cache_dir() as workspace:
957             path = (workspace / "file.py").resolve()
958             with open(path, "w") as fh:
959                 fh.write(contents)
960             self.invokeBlack([str(path), "--pyi"])
961             with open(path, "r") as fh:
962                 actual = fh.read()
963             # verify cache with --pyi is separate
964             pyi_cache = black.read_cache(pyi_mode)
965             self.assertIn(str(path), pyi_cache)
966             normal_cache = black.read_cache(DEFAULT_MODE)
967             self.assertNotIn(str(path), normal_cache)
968         self.assertFormatEqual(expected, actual)
969         black.assert_equivalent(contents, actual)
970         black.assert_stable(contents, actual, pyi_mode)
971
972     @event_loop()
973     def test_multi_file_force_pyi(self) -> None:
974         reg_mode = DEFAULT_MODE
975         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
976         contents, expected = read_data("force_pyi")
977         with cache_dir() as workspace:
978             paths = [
979                 (workspace / "file1.py").resolve(),
980                 (workspace / "file2.py").resolve(),
981             ]
982             for path in paths:
983                 with open(path, "w") as fh:
984                     fh.write(contents)
985             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
986             for path in paths:
987                 with open(path, "r") as fh:
988                     actual = fh.read()
989                 self.assertEqual(actual, expected)
990             # verify cache with --pyi is separate
991             pyi_cache = black.read_cache(pyi_mode)
992             normal_cache = black.read_cache(reg_mode)
993             for path in paths:
994                 self.assertIn(str(path), pyi_cache)
995                 self.assertNotIn(str(path), normal_cache)
996
997     def test_pipe_force_pyi(self) -> None:
998         source, expected = read_data("force_pyi")
999         result = CliRunner().invoke(
1000             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1001         )
1002         self.assertEqual(result.exit_code, 0)
1003         actual = result.output
1004         self.assertFormatEqual(actual, expected)
1005
1006     def test_single_file_force_py36(self) -> None:
1007         reg_mode = DEFAULT_MODE
1008         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1009         source, expected = read_data("force_py36")
1010         with cache_dir() as workspace:
1011             path = (workspace / "file.py").resolve()
1012             with open(path, "w") as fh:
1013                 fh.write(source)
1014             self.invokeBlack([str(path), *PY36_ARGS])
1015             with open(path, "r") as fh:
1016                 actual = fh.read()
1017             # verify cache with --target-version is separate
1018             py36_cache = black.read_cache(py36_mode)
1019             self.assertIn(str(path), py36_cache)
1020             normal_cache = black.read_cache(reg_mode)
1021             self.assertNotIn(str(path), normal_cache)
1022         self.assertEqual(actual, expected)
1023
1024     @event_loop()
1025     def test_multi_file_force_py36(self) -> None:
1026         reg_mode = DEFAULT_MODE
1027         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1028         source, expected = read_data("force_py36")
1029         with cache_dir() as workspace:
1030             paths = [
1031                 (workspace / "file1.py").resolve(),
1032                 (workspace / "file2.py").resolve(),
1033             ]
1034             for path in paths:
1035                 with open(path, "w") as fh:
1036                     fh.write(source)
1037             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1038             for path in paths:
1039                 with open(path, "r") as fh:
1040                     actual = fh.read()
1041                 self.assertEqual(actual, expected)
1042             # verify cache with --target-version is separate
1043             pyi_cache = black.read_cache(py36_mode)
1044             normal_cache = black.read_cache(reg_mode)
1045             for path in paths:
1046                 self.assertIn(str(path), pyi_cache)
1047                 self.assertNotIn(str(path), normal_cache)
1048
1049     def test_pipe_force_py36(self) -> None:
1050         source, expected = read_data("force_py36")
1051         result = CliRunner().invoke(
1052             black.main,
1053             ["-", "-q", "--target-version=py36"],
1054             input=BytesIO(source.encode("utf8")),
1055         )
1056         self.assertEqual(result.exit_code, 0)
1057         actual = result.output
1058         self.assertFormatEqual(actual, expected)
1059
1060     @pytest.mark.incompatible_with_mypyc
1061     def test_reformat_one_with_stdin(self) -> None:
1062         with patch(
1063             "black.format_stdin_to_stdout",
1064             return_value=lambda *args, **kwargs: black.Changed.YES,
1065         ) as fsts:
1066             report = MagicMock()
1067             path = Path("-")
1068             black.reformat_one(
1069                 path,
1070                 fast=True,
1071                 write_back=black.WriteBack.YES,
1072                 mode=DEFAULT_MODE,
1073                 report=report,
1074             )
1075             fsts.assert_called_once()
1076             report.done.assert_called_with(path, black.Changed.YES)
1077
1078     @pytest.mark.incompatible_with_mypyc
1079     def test_reformat_one_with_stdin_filename(self) -> None:
1080         with patch(
1081             "black.format_stdin_to_stdout",
1082             return_value=lambda *args, **kwargs: black.Changed.YES,
1083         ) as fsts:
1084             report = MagicMock()
1085             p = "foo.py"
1086             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1087             expected = Path(p)
1088             black.reformat_one(
1089                 path,
1090                 fast=True,
1091                 write_back=black.WriteBack.YES,
1092                 mode=DEFAULT_MODE,
1093                 report=report,
1094             )
1095             fsts.assert_called_once_with(
1096                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1097             )
1098             # __BLACK_STDIN_FILENAME__ should have been stripped
1099             report.done.assert_called_with(expected, black.Changed.YES)
1100
1101     @pytest.mark.incompatible_with_mypyc
1102     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1103         with patch(
1104             "black.format_stdin_to_stdout",
1105             return_value=lambda *args, **kwargs: black.Changed.YES,
1106         ) as fsts:
1107             report = MagicMock()
1108             p = "foo.pyi"
1109             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1110             expected = Path(p)
1111             black.reformat_one(
1112                 path,
1113                 fast=True,
1114                 write_back=black.WriteBack.YES,
1115                 mode=DEFAULT_MODE,
1116                 report=report,
1117             )
1118             fsts.assert_called_once_with(
1119                 fast=True,
1120                 write_back=black.WriteBack.YES,
1121                 mode=replace(DEFAULT_MODE, is_pyi=True),
1122             )
1123             # __BLACK_STDIN_FILENAME__ should have been stripped
1124             report.done.assert_called_with(expected, black.Changed.YES)
1125
1126     @pytest.mark.incompatible_with_mypyc
1127     def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
1128         with patch(
1129             "black.format_stdin_to_stdout",
1130             return_value=lambda *args, **kwargs: black.Changed.YES,
1131         ) as fsts:
1132             report = MagicMock()
1133             p = "foo.ipynb"
1134             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1135             expected = Path(p)
1136             black.reformat_one(
1137                 path,
1138                 fast=True,
1139                 write_back=black.WriteBack.YES,
1140                 mode=DEFAULT_MODE,
1141                 report=report,
1142             )
1143             fsts.assert_called_once_with(
1144                 fast=True,
1145                 write_back=black.WriteBack.YES,
1146                 mode=replace(DEFAULT_MODE, is_ipynb=True),
1147             )
1148             # __BLACK_STDIN_FILENAME__ should have been stripped
1149             report.done.assert_called_with(expected, black.Changed.YES)
1150
1151     @pytest.mark.incompatible_with_mypyc
1152     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1153         with patch(
1154             "black.format_stdin_to_stdout",
1155             return_value=lambda *args, **kwargs: black.Changed.YES,
1156         ) as fsts:
1157             report = MagicMock()
1158             # Even with an existing file, since we are forcing stdin, black
1159             # should output to stdout and not modify the file inplace
1160             p = Path(str(THIS_DIR / "data/collections.py"))
1161             # Make sure is_file actually returns True
1162             self.assertTrue(p.is_file())
1163             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1164             expected = Path(p)
1165             black.reformat_one(
1166                 path,
1167                 fast=True,
1168                 write_back=black.WriteBack.YES,
1169                 mode=DEFAULT_MODE,
1170                 report=report,
1171             )
1172             fsts.assert_called_once()
1173             # __BLACK_STDIN_FILENAME__ should have been stripped
1174             report.done.assert_called_with(expected, black.Changed.YES)
1175
1176     def test_reformat_one_with_stdin_empty(self) -> None:
1177         output = io.StringIO()
1178         with patch("io.TextIOWrapper", lambda *args, **kwargs: output):
1179             try:
1180                 black.format_stdin_to_stdout(
1181                     fast=True,
1182                     content="",
1183                     write_back=black.WriteBack.YES,
1184                     mode=DEFAULT_MODE,
1185                 )
1186             except io.UnsupportedOperation:
1187                 pass  # StringIO does not support detach
1188             assert output.getvalue() == ""
1189
1190     def test_invalid_cli_regex(self) -> None:
1191         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1192             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1193
1194     def test_required_version_matches_version(self) -> None:
1195         self.invokeBlack(
1196             ["--required-version", black.__version__, "-c", "0"],
1197             exit_code=0,
1198             ignore_config=True,
1199         )
1200
1201     def test_required_version_does_not_match_version(self) -> None:
1202         result = BlackRunner().invoke(
1203             black.main,
1204             ["--required-version", "20.99b", "-c", "0"],
1205         )
1206         self.assertEqual(result.exit_code, 1)
1207         self.assertIn("required version", result.stderr)
1208
1209     def test_preserves_line_endings(self) -> None:
1210         with TemporaryDirectory() as workspace:
1211             test_file = Path(workspace) / "test.py"
1212             for nl in ["\n", "\r\n"]:
1213                 contents = nl.join(["def f(  ):", "    pass"])
1214                 test_file.write_bytes(contents.encode())
1215                 ff(test_file, write_back=black.WriteBack.YES)
1216                 updated_contents: bytes = test_file.read_bytes()
1217                 self.assertIn(nl.encode(), updated_contents)
1218                 if nl == "\n":
1219                     self.assertNotIn(b"\r\n", updated_contents)
1220
1221     def test_preserves_line_endings_via_stdin(self) -> None:
1222         for nl in ["\n", "\r\n"]:
1223             contents = nl.join(["def f(  ):", "    pass"])
1224             runner = BlackRunner()
1225             result = runner.invoke(
1226                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1227             )
1228             self.assertEqual(result.exit_code, 0)
1229             output = result.stdout_bytes
1230             self.assertIn(nl.encode("utf8"), output)
1231             if nl == "\n":
1232                 self.assertNotIn(b"\r\n", output)
1233
1234     def test_assert_equivalent_different_asts(self) -> None:
1235         with self.assertRaises(AssertionError):
1236             black.assert_equivalent("{}", "None")
1237
1238     def test_shhh_click(self) -> None:
1239         try:
1240             from click import _unicodefun
1241         except ModuleNotFoundError:
1242             self.skipTest("Incompatible Click version")
1243         if not hasattr(_unicodefun, "_verify_python3_env"):
1244             self.skipTest("Incompatible Click version")
1245         # First, let's see if Click is crashing with a preferred ASCII charset.
1246         with patch("locale.getpreferredencoding") as gpe:
1247             gpe.return_value = "ASCII"
1248             with self.assertRaises(RuntimeError):
1249                 _unicodefun._verify_python3_env()  # type: ignore
1250         # Now, let's silence Click...
1251         black.patch_click()
1252         # ...and confirm it's silent.
1253         with patch("locale.getpreferredencoding") as gpe:
1254             gpe.return_value = "ASCII"
1255             try:
1256                 _unicodefun._verify_python3_env()  # type: ignore
1257             except RuntimeError as re:
1258                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1259
1260     def test_root_logger_not_used_directly(self) -> None:
1261         def fail(*args: Any, **kwargs: Any) -> None:
1262             self.fail("Record created with root logger")
1263
1264         with patch.multiple(
1265             logging.root,
1266             debug=fail,
1267             info=fail,
1268             warning=fail,
1269             error=fail,
1270             critical=fail,
1271             log=fail,
1272         ):
1273             ff(THIS_DIR / "util.py")
1274
1275     def test_invalid_config_return_code(self) -> None:
1276         tmp_file = Path(black.dump_to_file())
1277         try:
1278             tmp_config = Path(black.dump_to_file())
1279             tmp_config.unlink()
1280             args = ["--config", str(tmp_config), str(tmp_file)]
1281             self.invokeBlack(args, exit_code=2, ignore_config=False)
1282         finally:
1283             tmp_file.unlink()
1284
1285     def test_parse_pyproject_toml(self) -> None:
1286         test_toml_file = THIS_DIR / "test.toml"
1287         config = black.parse_pyproject_toml(str(test_toml_file))
1288         self.assertEqual(config["verbose"], 1)
1289         self.assertEqual(config["check"], "no")
1290         self.assertEqual(config["diff"], "y")
1291         self.assertEqual(config["color"], True)
1292         self.assertEqual(config["line_length"], 79)
1293         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1294         self.assertEqual(config["python_cell_magics"], ["custom1", "custom2"])
1295         self.assertEqual(config["exclude"], r"\.pyi?$")
1296         self.assertEqual(config["include"], r"\.py?$")
1297
1298     def test_read_pyproject_toml(self) -> None:
1299         test_toml_file = THIS_DIR / "test.toml"
1300         fake_ctx = FakeContext()
1301         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1302         config = fake_ctx.default_map
1303         self.assertEqual(config["verbose"], "1")
1304         self.assertEqual(config["check"], "no")
1305         self.assertEqual(config["diff"], "y")
1306         self.assertEqual(config["color"], "True")
1307         self.assertEqual(config["line_length"], "79")
1308         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1309         self.assertEqual(config["exclude"], r"\.pyi?$")
1310         self.assertEqual(config["include"], r"\.py?$")
1311
1312     @pytest.mark.incompatible_with_mypyc
1313     def test_find_project_root(self) -> None:
1314         with TemporaryDirectory() as workspace:
1315             root = Path(workspace)
1316             test_dir = root / "test"
1317             test_dir.mkdir()
1318
1319             src_dir = root / "src"
1320             src_dir.mkdir()
1321
1322             root_pyproject = root / "pyproject.toml"
1323             root_pyproject.touch()
1324             src_pyproject = src_dir / "pyproject.toml"
1325             src_pyproject.touch()
1326             src_python = src_dir / "foo.py"
1327             src_python.touch()
1328
1329             self.assertEqual(
1330                 black.find_project_root((src_dir, test_dir)),
1331                 (root.resolve(), "pyproject.toml"),
1332             )
1333             self.assertEqual(
1334                 black.find_project_root((src_dir,)),
1335                 (src_dir.resolve(), "pyproject.toml"),
1336             )
1337             self.assertEqual(
1338                 black.find_project_root((src_python,)),
1339                 (src_dir.resolve(), "pyproject.toml"),
1340             )
1341
1342     @patch(
1343         "black.files.find_user_pyproject_toml",
1344         black.files.find_user_pyproject_toml.__wrapped__,
1345     )
1346     def test_find_user_pyproject_toml_linux(self) -> None:
1347         if system() == "Windows":
1348             return
1349
1350         # Test if XDG_CONFIG_HOME is checked
1351         with TemporaryDirectory() as workspace:
1352             tmp_user_config = Path(workspace) / "black"
1353             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1354                 self.assertEqual(
1355                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
1356                 )
1357
1358         # Test fallback for XDG_CONFIG_HOME
1359         with patch.dict("os.environ"):
1360             os.environ.pop("XDG_CONFIG_HOME", None)
1361             fallback_user_config = Path("~/.config").expanduser() / "black"
1362             self.assertEqual(
1363                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
1364             )
1365
1366     def test_find_user_pyproject_toml_windows(self) -> None:
1367         if system() != "Windows":
1368             return
1369
1370         user_config_path = Path.home() / ".black"
1371         self.assertEqual(
1372             black.files.find_user_pyproject_toml(), user_config_path.resolve()
1373         )
1374
1375     def test_bpo_33660_workaround(self) -> None:
1376         if system() == "Windows":
1377             return
1378
1379         # https://bugs.python.org/issue33660
1380         root = Path("/")
1381         with change_directory(root):
1382             path = Path("workspace") / "project"
1383             report = black.Report(verbose=True)
1384             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
1385             self.assertEqual(normalized_path, "workspace/project")
1386
1387     def test_newline_comment_interaction(self) -> None:
1388         source = "class A:\\\r\n# type: ignore\n pass\n"
1389         output = black.format_str(source, mode=DEFAULT_MODE)
1390         black.assert_stable(source, output, mode=DEFAULT_MODE)
1391
1392     def test_bpo_2142_workaround(self) -> None:
1393
1394         # https://bugs.python.org/issue2142
1395
1396         source, _ = read_data("missing_final_newline.py")
1397         # read_data adds a trailing newline
1398         source = source.rstrip()
1399         expected, _ = read_data("missing_final_newline.diff")
1400         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
1401         diff_header = re.compile(
1402             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
1403             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
1404         )
1405         try:
1406             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
1407             self.assertEqual(result.exit_code, 0)
1408         finally:
1409             os.unlink(tmp_file)
1410         actual = result.output
1411         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
1412         self.assertEqual(actual, expected)
1413
1414     @staticmethod
1415     def compare_results(
1416         result: click.testing.Result, expected_value: str, expected_exit_code: int
1417     ) -> None:
1418         """Helper method to test the value and exit code of a click Result."""
1419         assert (
1420             result.output == expected_value
1421         ), "The output did not match the expected value."
1422         assert result.exit_code == expected_exit_code, "The exit code is incorrect."
1423
1424     def test_code_option(self) -> None:
1425         """Test the code option with no changes."""
1426         code = 'print("Hello world")\n'
1427         args = ["--code", code]
1428         result = CliRunner().invoke(black.main, args)
1429
1430         self.compare_results(result, code, 0)
1431
1432     def test_code_option_changed(self) -> None:
1433         """Test the code option when changes are required."""
1434         code = "print('hello world')"
1435         formatted = black.format_str(code, mode=DEFAULT_MODE)
1436
1437         args = ["--code", code]
1438         result = CliRunner().invoke(black.main, args)
1439
1440         self.compare_results(result, formatted, 0)
1441
1442     def test_code_option_check(self) -> None:
1443         """Test the code option when check is passed."""
1444         args = ["--check", "--code", 'print("Hello world")\n']
1445         result = CliRunner().invoke(black.main, args)
1446         self.compare_results(result, "", 0)
1447
1448     def test_code_option_check_changed(self) -> None:
1449         """Test the code option when changes are required, and check is passed."""
1450         args = ["--check", "--code", "print('hello world')"]
1451         result = CliRunner().invoke(black.main, args)
1452         self.compare_results(result, "", 1)
1453
1454     def test_code_option_diff(self) -> None:
1455         """Test the code option when diff is passed."""
1456         code = "print('hello world')"
1457         formatted = black.format_str(code, mode=DEFAULT_MODE)
1458         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1459
1460         args = ["--diff", "--code", code]
1461         result = CliRunner().invoke(black.main, args)
1462
1463         # Remove time from diff
1464         output = DIFF_TIME.sub("", result.output)
1465
1466         assert output == result_diff, "The output did not match the expected value."
1467         assert result.exit_code == 0, "The exit code is incorrect."
1468
1469     def test_code_option_color_diff(self) -> None:
1470         """Test the code option when color and diff are passed."""
1471         code = "print('hello world')"
1472         formatted = black.format_str(code, mode=DEFAULT_MODE)
1473
1474         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1475         result_diff = color_diff(result_diff)
1476
1477         args = ["--diff", "--color", "--code", code]
1478         result = CliRunner().invoke(black.main, args)
1479
1480         # Remove time from diff
1481         output = DIFF_TIME.sub("", result.output)
1482
1483         assert output == result_diff, "The output did not match the expected value."
1484         assert result.exit_code == 0, "The exit code is incorrect."
1485
1486     @pytest.mark.incompatible_with_mypyc
1487     def test_code_option_safe(self) -> None:
1488         """Test that the code option throws an error when the sanity checks fail."""
1489         # Patch black.assert_equivalent to ensure the sanity checks fail
1490         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1491             code = 'print("Hello world")'
1492             error_msg = f"{code}\nerror: cannot format <string>: \n"
1493
1494             args = ["--safe", "--code", code]
1495             result = CliRunner().invoke(black.main, args)
1496
1497             self.compare_results(result, error_msg, 123)
1498
1499     def test_code_option_fast(self) -> None:
1500         """Test that the code option ignores errors when the sanity checks fail."""
1501         # Patch black.assert_equivalent to ensure the sanity checks fail
1502         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1503             code = 'print("Hello world")'
1504             formatted = black.format_str(code, mode=DEFAULT_MODE)
1505
1506             args = ["--fast", "--code", code]
1507             result = CliRunner().invoke(black.main, args)
1508
1509             self.compare_results(result, formatted, 0)
1510
1511     @pytest.mark.incompatible_with_mypyc
1512     def test_code_option_config(self) -> None:
1513         """
1514         Test that the code option finds the pyproject.toml in the current directory.
1515         """
1516         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1517             args = ["--code", "print"]
1518             # This is the only directory known to contain a pyproject.toml
1519             with change_directory(PROJECT_ROOT):
1520                 CliRunner().invoke(black.main, args)
1521                 pyproject_path = Path(Path.cwd(), "pyproject.toml").resolve()
1522
1523             assert (
1524                 len(parse.mock_calls) >= 1
1525             ), "Expected config parse to be called with the current directory."
1526
1527             _, call_args, _ = parse.mock_calls[0]
1528             assert (
1529                 call_args[0].lower() == str(pyproject_path).lower()
1530             ), "Incorrect config loaded."
1531
1532     @pytest.mark.incompatible_with_mypyc
1533     def test_code_option_parent_config(self) -> None:
1534         """
1535         Test that the code option finds the pyproject.toml in the parent directory.
1536         """
1537         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1538             with change_directory(THIS_DIR):
1539                 args = ["--code", "print"]
1540                 CliRunner().invoke(black.main, args)
1541
1542                 pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
1543                 assert (
1544                     len(parse.mock_calls) >= 1
1545                 ), "Expected config parse to be called with the current directory."
1546
1547                 _, call_args, _ = parse.mock_calls[0]
1548                 assert (
1549                     call_args[0].lower() == str(pyproject_path).lower()
1550                 ), "Incorrect config loaded."
1551
1552     def test_for_handled_unexpected_eof_error(self) -> None:
1553         """
1554         Test that an unexpected EOF SyntaxError is nicely presented.
1555         """
1556         with pytest.raises(black.parsing.InvalidInput) as exc_info:
1557             black.lib2to3_parse("print(", {})
1558
1559         exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")
1560
1561     def test_equivalency_ast_parse_failure_includes_error(self) -> None:
1562         with pytest.raises(AssertionError) as err:
1563             black.assert_equivalent("a«»a  = 1", "a«»a  = 1")
1564
1565         err.match("--safe")
1566         # Unfortunately the SyntaxError message has changed in newer versions so we
1567         # can't match it directly.
1568         err.match("invalid character")
1569         err.match(r"\(<unknown>, line 1\)")
1570
1571
class TestCaching:
    """Tests for black's on-disk formatting cache (location, reading, writing)."""

    def test_get_cache_dir(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        # Create multiple cache directories
        workspace1 = tmp_path / "ws1"
        workspace1.mkdir()
        workspace2 = tmp_path / "ws2"
        workspace2.mkdir()

        # Force user_cache_dir to use the temporary directory for easier assertions
        patch_user_cache_dir = patch(
            target="black.cache.user_cache_dir",
            autospec=True,
            return_value=str(workspace1),
        )

        # If BLACK_CACHE_DIR is not set, use user_cache_dir
        monkeypatch.delenv("BLACK_CACHE_DIR", raising=False)
        with patch_user_cache_dir:
            assert get_cache_dir() == workspace1

        # If it is set, use the path provided in the env var.
        monkeypatch.setenv("BLACK_CACHE_DIR", str(workspace2))
        assert get_cache_dir() == workspace2

    def test_cache_broken_file(self) -> None:
        # A corrupt cache file reads back as an empty cache, and a later run
        # repopulates it.
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            cache_file = get_cache_file(mode)
            cache_file.write_text("this is not a pickle")
            assert black.read_cache(mode) == {}
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            invokeBlack([str(src)])
            cache = black.read_cache(mode)
            assert str(src) in cache

    def test_cache_single_file_already_cached(self) -> None:
        # A file already recorded in the cache is skipped, so it stays unformatted.
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            black.write_cache({}, [src], mode)
            invokeBlack([str(src)])
            assert src.read_text() == "print('hello')"

    @event_loop()
    def test_cache_multiple_files(self) -> None:
        # Only the uncached file is reformatted; afterwards both are cached.
        mode = DEFAULT_MODE
        with cache_dir() as workspace, patch(
            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            one = (workspace / "one.py").resolve()
            one.write_text("print('hello')")
            two = (workspace / "two.py").resolve()
            two.write_text("print('hello')")
            black.write_cache({}, [one], mode)
            invokeBlack([str(workspace)])
            assert one.read_text() == "print('hello')"
            assert two.read_text() == 'print("hello")\n'
            cache = black.read_cache(mode)
            assert str(one) in cache
            assert str(two) in cache

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    def test_no_cache_when_writeback_diff(self, color: bool) -> None:
        # With --diff the cache must be neither read nor written.
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            with patch("black.read_cache") as read_cache, patch(
                "black.write_cache"
            ) as write_cache:
                cmd = [str(src), "--diff"]
                if color:
                    cmd.append("--color")
                invokeBlack(cmd)
                cache_file = get_cache_file(mode)
                assert not cache_file.exists()
                write_cache.assert_not_called()
                read_cache.assert_not_called()

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    @event_loop()
    def test_output_locking_when_writeback_diff(self, color: bool) -> None:
        with cache_dir() as workspace:
            for tag in range(4):
                src = (workspace / f"test{tag}.py").resolve()
                src.write_text("print('hello')")
            with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
                cmd = ["--diff", str(workspace)]
                if color:
                    cmd.append("--color")
                invokeBlack(cmd, exit_code=0)
                # this isn't quite doing what we want, but if it _isn't_
                # called then we cannot be using the lock it provides
                mgr.assert_called()

    def test_no_cache_when_stdin(self) -> None:
        # Formatting from stdin must not create a cache file.
        mode = DEFAULT_MODE
        with cache_dir():
            result = CliRunner().invoke(
                black.main, ["-"], input=BytesIO(b"print('hello')")
            )
            assert not result.exit_code
            cache_file = get_cache_file(mode)
            assert not cache_file.exists()

    def test_read_cache_no_cachefile(self) -> None:
        mode = DEFAULT_MODE
        with cache_dir():
            assert black.read_cache(mode) == {}

    def test_write_cache_read_cache(self) -> None:
        # A written entry round-trips and matches the file's current cache info.
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.touch()
            black.write_cache({}, [src], mode)
            cache = black.read_cache(mode)
            assert str(src) in cache
            assert cache[str(src)] == black.get_cache_info(src)

    def test_filter_cached(self) -> None:
        # Files whose cache info differs from the recorded entry must be redone.
        with TemporaryDirectory() as workspace:
            path = Path(workspace)
            uncached = (path / "uncached").resolve()
            cached = (path / "cached").resolve()
            cached_but_changed = (path / "changed").resolve()
            uncached.touch()
            cached.touch()
            cached_but_changed.touch()
            cache = {
                str(cached): black.get_cache_info(cached),
                str(cached_but_changed): (0.0, 0),
            }
            todo, done = black.filter_cached(
                cache, {uncached, cached, cached_but_changed}
            )
            assert todo == {uncached, cached_but_changed}
            assert done == {cached}

    def test_write_cache_creates_directory_if_needed(self) -> None:
        mode = DEFAULT_MODE
        with cache_dir(exists=False) as workspace:
            assert not workspace.exists()
            black.write_cache({}, [], mode)
            assert workspace.exists()

    @event_loop()
    def test_failed_formatting_does_not_get_cached(self) -> None:
        mode = DEFAULT_MODE
        with cache_dir() as workspace, patch(
            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            failing = (workspace / "failing.py").resolve()
            failing.write_text("not actually python")
            clean = (workspace / "clean.py").resolve()
            clean.write_text('print("hello")\n')
            invokeBlack([str(workspace)], exit_code=123)
            cache = black.read_cache(mode)
            assert str(failing) not in cache
            assert str(clean) in cache

    def test_write_cache_write_fail(self) -> None:
        # Passes as long as write_cache does not propagate the OSError.
        mode = DEFAULT_MODE
        with cache_dir(), patch.object(Path, "open") as mock:
            mock.side_effect = OSError
            black.write_cache({}, [], mode)

    def test_read_cache_line_lengths(self) -> None:
        # Entries written under one mode are not visible under another
        # (here: a different line length).
        mode = DEFAULT_MODE
        short_mode = replace(DEFAULT_MODE, line_length=1)
        with cache_dir() as workspace:
            path = (workspace / "file.py").resolve()
            path.touch()
            black.write_cache({}, [path], mode)
            one = black.read_cache(mode)
            assert str(path) in one
            two = black.read_cache(short_mode)
            assert str(path) not in two
1764
1765
def assert_collected_sources(
    src: Sequence[Union[str, Path]],
    expected: Sequence[Union[str, Path]],
    *,
    ctx: Optional[FakeContext] = None,
    exclude: Optional[str] = None,
    include: Optional[str] = None,
    extend_exclude: Optional[str] = None,
    force_exclude: Optional[str] = None,
    stdin_filename: Optional[str] = None,
) -> None:
    """Assert that ``black.get_sources`` collects exactly ``expected`` from ``src``.

    Pattern arguments are regex source strings. ``None`` leaves a filter unset,
    except ``include``, which falls back to ``DEFAULT_INCLUDE``. Ordering is
    ignored because both sides are sorted before comparison.
    """

    def maybe_compile(pattern: Optional[str]) -> Any:
        # None means "option not given"; anything else is a regex string.
        return None if pattern is None else compile_pattern(pattern)

    src_strings = tuple(str(Path(item)) for item in src)
    expected_paths = [Path(item) for item in expected]
    exclude_re = maybe_compile(exclude)
    include_re = DEFAULT_INCLUDE if include is None else compile_pattern(include)
    extend_exclude_re = maybe_compile(extend_exclude)
    force_exclude_re = maybe_compile(force_exclude)

    collected = black.get_sources(
        ctx=ctx or FakeContext(),
        src=src_strings,
        quiet=False,
        verbose=False,
        include=include_re,
        exclude=exclude_re,
        extend_exclude=extend_exclude_re,
        force_exclude=force_exclude_re,
        report=black.Report(),
        stdin_filename=stdin_filename,
    )
    assert sorted(collected) == sorted(expected_paths)
1798
1799
class TestFileCollection:
    """Tests for source discovery: include/exclude, .gitignore and stdin handling."""

    @staticmethod
    def _collect_python_files(
        paths: Any,
        root: Path,
        include: Any,
        exclude: Any,
        report: Report,
        gitignore: PathSpec,
    ) -> List[Path]:
        """Materialize ``black.gen_python_files`` for the given filters.

        Passes ``None`` for the two optional pattern arguments between
        ``exclude`` and ``report`` (presumably extend/force exclude), and runs
        non-verbose, non-quiet.
        """
        return list(
            black.gen_python_files(
                paths,
                root,
                include,
                exclude,
                None,
                None,
                report,
                gitignore,
                verbose=False,
                quiet=False,
            )
        )

    @staticmethod
    def _assert_gitignore_parse_failure(path: Path, gitignore: Path) -> None:
        """Run black on ``path``; expect exit code 1 naming ``gitignore``."""
        # The directory's own pyproject.toml is passed explicitly via --config.
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_include_exclude(self) -> None:
        path = THIS_DIR / "data" / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        assert_collected_sources(
            src,
            expected,
            include=r"\.pyi?$",
            exclude=r"/exclude/|/\.definitely_exclude/",
        )

    def test_gitignore_used_as_default(self) -> None:
        base = DATA_DIR / "include_exclude_tests"
        expected = [
            base / "b/.definitely_exclude/a.py",
            base / "b/.definitely_exclude/a.pyi",
        ]
        src = [base / "b/"]
        ctx = FakeContext()
        ctx.obj["root"] = base
        assert_collected_sources(src, expected, ctx=ctx, extend_exclude=r"/exclude/")

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_exclude_for_issue_1572(self) -> None:
        # Exclude shouldn't touch files that were explicitly given to Black through the
        # CLI. Exclude is supposed to only apply to the recursive discovery of files.
        # https://github.com/psf/black/issues/1572
        path = DATA_DIR / "include_exclude_tests"
        src = [path / "b/exclude/a.py"]
        expected = [path / "b/exclude/a.py"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    def test_gitignore_exclude(self) -> None:
        path = THIS_DIR / "data" / "include_exclude_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        report = black.Report()
        gitignore = PathSpec.from_lines(
            "gitwildmatch", ["exclude/", ".definitely_exclude"]
        )
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        this_abs = THIS_DIR.resolve()
        sources = self._collect_python_files(
            path.iterdir(), this_abs, include, exclude, report, gitignore
        )
        assert sorted(expected) == sorted(sources)

    def test_nested_gitignore(self) -> None:
        path = THIS_DIR / "data" / "nested_gitignore_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        root_gitignore = black.files.get_gitignore(path)
        report = black.Report()
        expected: List[Path] = [
            path / "x.py",
            path / "root/b.py",
            path / "root/c.py",
            path / "root/child/c.py",
        ]
        this_abs = THIS_DIR.resolve()
        sources = self._collect_python_files(
            path.iterdir(), this_abs, include, exclude, report, root_gitignore
        )
        assert sorted(expected) == sorted(sources)

    def test_invalid_gitignore(self) -> None:
        path = THIS_DIR / "data" / "invalid_gitignore_tests"
        self._assert_gitignore_parse_failure(path, path / ".gitignore")

    def test_invalid_nested_gitignore(self) -> None:
        path = THIS_DIR / "data" / "invalid_nested_gitignore_tests"
        self._assert_gitignore_parse_failure(path, path / "a" / ".gitignore")

    def test_empty_include(self) -> None:
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.pie",
            path / "b/exclude/a.py",
            path / "b/exclude/a.pyi",
            path / "b/dont_exclude/a.pie",
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
            path / "b/.definitely_exclude/a.pie",
            path / "b/.definitely_exclude/a.py",
            path / "b/.definitely_exclude/a.pyi",
            path / ".gitignore",
            path / "pyproject.toml",
        ]
        # Setting exclude explicitly to an empty string to block .gitignore usage.
        assert_collected_sources(src, expected, include="", exclude="")

    def test_extend_exclude(self) -> None:
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.py",
            path / "b/dont_exclude/a.py",
        ]
        assert_collected_sources(
            src, expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude"
        )

    @pytest.mark.incompatible_with_mypyc
    def test_symlink_out_of_root_directory(self) -> None:
        path = MagicMock()
        root = THIS_DIR.resolve()
        child = MagicMock()
        include = re.compile(black.DEFAULT_INCLUDES)
        exclude = re.compile(black.DEFAULT_EXCLUDES)
        report = black.Report()
        gitignore = PathSpec.from_lines("gitwildmatch", [])
        # `child` should behave like a symlink whose resolved path is clearly
        # outside of the `root` directory: it must be skipped without error.
        path.iterdir.return_value = [child]
        child.resolve.return_value = Path("/a/b/c")
        child.as_posix.return_value = "/a/b/c"
        child.is_symlink.return_value = True
        try:
            self._collect_python_files(
                path.iterdir(), root, include, exclude, report, gitignore
            )
        except ValueError as ve:
            pytest.fail(f"`gen_python_files()` failed: {ve}")
        path.iterdir.assert_called_once()
        child.resolve.assert_called_once()
        child.is_symlink.assert_called_once()
        # `child` should behave like a strange file whose resolved path is clearly
        # outside of the `root` directory: this time a ValueError is expected.
        child.is_symlink.return_value = False
        with pytest.raises(ValueError):
            self._collect_python_files(
                path.iterdir(), root, include, exclude, report, gitignore
            )
        path.iterdir.assert_called()
        assert path.iterdir.call_count == 2
        child.resolve.assert_called()
        assert child.resolve.call_count == 2
        child.is_symlink.assert_called()
        assert child.is_symlink.call_count == 2

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin(self) -> None:
        src = ["-"]
        expected = ["-"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename(self) -> None:
        src = ["-"]
        stdin_filename = str(THIS_DIR / "data/collections.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
        # Exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        path = DATA_DIR / "include_exclude_tests"
        src = ["-"]
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
        # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        src = ["-"]
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            extend_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
        # Force exclude should exclude the file when passing it through
        # stdin_filename
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        assert_collected_sources(
            src=["-"],
            expected=[],
            force_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )
2071
2072
# Source lines of black/__init__.py, used by tracefunc() to print the signature
# of each called function. When black is compiled (black.COMPILED), reading its
# __file__ as UTF-8 can fail (presumably because it is a binary extension); bind
# an empty list then so the module-level name always exists for tracefunc().
try:
    with open(black.__file__, "r", encoding="utf-8") as _bf:
        black_source_lines = _bf.readlines()
except UnicodeDecodeError:
    if not black.COMPILED:
        raise
    black_source_lines = []
2079
2080
def tracefunc(
    frame: types.FrameType, event: str, arg: Any
) -> Callable[[types.FrameType, str, Any], Any]:
    """Show function calls `from black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    Only "call" events whose code lives in black/__init__.py are printed, as
    ``<indent><lineno>:<source line>``. Decorator lines above the ``def`` are
    skipped so the signature line is shown. The function always returns itself,
    so it keeps tracing nested scopes.
    """
    if event != "call":
        return tracefunc

    filename = frame.f_code.co_filename
    # Filter before touching black_source_lines: line numbers from any other
    # file are meaningless there (and may be out of range). as_posix() makes the
    # check work with Windows path separators as well.
    if "black/__init__.py" not in Path(filename).as_posix():
        return tracefunc

    # Indent by stack depth; 19 is an offset for the frames of the surrounding
    # test harness (NOTE(review): empirical value, confirm for your runner).
    stack = len(inspect.stack()) - 19
    stack *= 2
    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1
    funcname = black_source_lines[func_sig_lineno].strip()
    # The reported line may point at a decorator; advance to the signature.
    while funcname.startswith("@"):
        func_sig_lineno += 1
        funcname = black_source_lines[func_sig_lineno].strip()
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc