git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

correct Vim integration code (#2853)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import inspect
5 import io
6 import logging
7 import multiprocessing
8 import os
9 import sys
10 import types
11 import unittest
12 from concurrent.futures import ThreadPoolExecutor
13 from contextlib import contextmanager, redirect_stderr
14 from dataclasses import replace
15 from io import BytesIO
16 from pathlib import Path
17 from platform import system
18 from tempfile import TemporaryDirectory
19 from typing import (
20     Any,
21     Callable,
22     Dict,
23     Iterator,
24     List,
25     Optional,
26     Sequence,
27     TypeVar,
28     Union,
29 )
30 from unittest.mock import MagicMock, patch
31
32 import click
33 import pytest
34 import re
35 from click import unstyle
36 from click.testing import CliRunner
37 from pathspec import PathSpec
38
39 import black
40 import black.files
41 from black import Feature, TargetVersion
42 from black import re_compile_maybe_verbose as compile_pattern
43 from black.cache import get_cache_dir, get_cache_file
44 from black.debug import DebugVisitor
45 from black.output import color_diff, diff
46 from black.report import Report
47
48 # Import other test classes
49 from tests.util import (
50     DATA_DIR,
51     DEFAULT_MODE,
52     DETERMINISTIC_HEADER,
53     PROJECT_ROOT,
54     PY36_VERSIONS,
55     THIS_DIR,
56     BlackBaseTestCase,
57     assert_format,
58     change_directory,
59     dump_to_stderr,
60     ff,
61     fs,
62     read_data,
63 )
64
# Paths, CLI fragments and compiled patterns shared across the tests below.
THIS_FILE = Path(__file__)
EMPTY_CONFIG = THIS_DIR.joinpath("data", "empty_pyproject.toml")
PY36_ARGS = [
    "--target-version=" + version.name.lower() for version in PY36_VERSIONS
]
DEFAULT_EXCLUDE, DEFAULT_INCLUDE = (
    black.re_compile_maybe_verbose(pattern)
    for pattern in (black.const.DEFAULT_EXCLUDES, black.const.DEFAULT_INCLUDES)
)
T = TypeVar("T")
R = TypeVar("R")

# Matches the tab-prefixed timestamp in a diff header line, and nothing else.
DIFF_TIME = re.compile(r"\t[\d\-:+\. ]+")
75
76
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point ``black.cache.CACHE_DIR`` at a throwaway directory for the block.

    The directory lives inside a fresh ``TemporaryDirectory`` that is removed
    on exit. With ``exists=False`` the yielded path is a ``new`` subdirectory
    that has not been created yet.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
85
86
@contextmanager
def event_loop() -> Iterator[None]:
    """Run the block with a brand-new asyncio event loop installed as current.

    The loop is created from the active event loop policy. On exit it is
    uninstalled (as ``asyncio.run`` does) and closed, so later code never
    picks up a closed loop through ``asyncio.get_event_loop()``.
    """
    policy = asyncio.get_event_loop_policy()
    loop = policy.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        yield
    finally:
        # Uninstall first so the closed loop never remains the current one.
        asyncio.set_event_loop(None)
        loop.close()
97
98
class FakeContext(click.Context):
    """Stand-in ``click.Context`` for calling helpers that expect a context.

    ``click.Context.__init__`` is deliberately not invoked; only the two
    attributes below are populated.
    """

    def __init__(self) -> None:
        # Placeholder project root; most tests ignore it.
        self.obj: Dict[str, Any] = dict(root=PROJECT_ROOT)
        self.default_map: Dict[str, Any] = dict()
106
107
class FakeParameter(click.Parameter):
    """Stand-in ``click.Parameter`` for calling helpers that expect a parameter."""

    def __init__(self) -> None:
        """Skip ``click.Parameter.__init__`` entirely; no state is needed."""
113
114
class BlackRunner(CliRunner):
    """``CliRunner`` that captures Black's stdout and stderr separately."""

    def __init__(self) -> None:
        # mix_stderr=False keeps stderr out of the captured stdout stream.
        super(BlackRunner, self).__init__(mix_stderr=False)
120
121
def invokeBlack(
    args: List[str], exit_code: int = 0, ignore_config: bool = True
) -> None:
    """Run Black's CLI in-process and assert it exits with ``exit_code``.

    When ``ignore_config`` is true, ``--verbose`` and ``--config`` pointing at
    ``THIS_DIR / "empty.toml"`` are prepended to ``args``. On a mismatch, the
    assertion message includes the arguments, both captured streams and the
    exception (if any) recorded by the runner.
    """
    if ignore_config:
        args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
    result = BlackRunner().invoke(black.main, args, catch_exceptions=False)
    stdout, stderr = result.stdout_bytes, result.stderr_bytes
    assert stdout is not None
    assert stderr is not None
    details = "\n".join(
        [
            f"Failed with args: {args}",
            f"stdout: {stdout.decode()!r}",
            f"stderr: {stderr.decode()!r}",
            f"exception: {result.exception}",
        ]
    )
    assert result.exit_code == exit_code, details
138
139
140 class BlackTestCase(BlackBaseTestCase):
141     invokeBlack = staticmethod(invokeBlack)
142
143     def test_empty_ff(self) -> None:
144         expected = ""
145         tmp_file = Path(black.dump_to_file())
146         try:
147             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
148             with open(tmp_file, encoding="utf8") as f:
149                 actual = f.read()
150         finally:
151             os.unlink(tmp_file)
152         self.assertFormatEqual(expected, actual)
153
154     def test_experimental_string_processing_warns(self) -> None:
155         self.assertWarns(
156             black.mode.Deprecated, black.Mode, experimental_string_processing=True
157         )
158
159     def test_piping(self) -> None:
160         source, expected = read_data("src/black/__init__", data=False)
161         result = BlackRunner().invoke(
162             black.main,
163             [
164                 "-",
165                 "--fast",
166                 f"--line-length={black.DEFAULT_LINE_LENGTH}",
167                 f"--config={EMPTY_CONFIG}",
168             ],
169             input=BytesIO(source.encode("utf8")),
170         )
171         self.assertEqual(result.exit_code, 0)
172         self.assertFormatEqual(expected, result.output)
173         if source != result.output:
174             black.assert_equivalent(source, result.output)
175             black.assert_stable(source, result.output, DEFAULT_MODE)
176
177     def test_piping_diff(self) -> None:
178         diff_header = re.compile(
179             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
180             r"\+\d\d\d\d"
181         )
182         source, _ = read_data("expression.py")
183         expected, _ = read_data("expression.diff")
184         args = [
185             "-",
186             "--fast",
187             f"--line-length={black.DEFAULT_LINE_LENGTH}",
188             "--diff",
189             f"--config={EMPTY_CONFIG}",
190         ]
191         result = BlackRunner().invoke(
192             black.main, args, input=BytesIO(source.encode("utf8"))
193         )
194         self.assertEqual(result.exit_code, 0)
195         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
196         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
197         self.assertEqual(expected, actual)
198
199     def test_piping_diff_with_color(self) -> None:
200         source, _ = read_data("expression.py")
201         args = [
202             "-",
203             "--fast",
204             f"--line-length={black.DEFAULT_LINE_LENGTH}",
205             "--diff",
206             "--color",
207             f"--config={EMPTY_CONFIG}",
208         ]
209         result = BlackRunner().invoke(
210             black.main, args, input=BytesIO(source.encode("utf8"))
211         )
212         actual = result.output
213         # Again, the contents are checked in a different test, so only look for colors.
214         self.assertIn("\033[1m", actual)
215         self.assertIn("\033[36m", actual)
216         self.assertIn("\033[32m", actual)
217         self.assertIn("\033[31m", actual)
218         self.assertIn("\033[0m", actual)
219
220     @patch("black.dump_to_file", dump_to_stderr)
221     def _test_wip(self) -> None:
222         source, expected = read_data("wip")
223         sys.settrace(tracefunc)
224         mode = replace(
225             DEFAULT_MODE,
226             experimental_string_processing=False,
227             target_versions={black.TargetVersion.PY38},
228         )
229         actual = fs(source, mode=mode)
230         sys.settrace(None)
231         self.assertFormatEqual(expected, actual)
232         black.assert_equivalent(source, actual)
233         black.assert_stable(source, actual, black.FileMode())
234
235     def test_pep_572_version_detection(self) -> None:
236         source, _ = read_data("pep_572")
237         root = black.lib2to3_parse(source)
238         features = black.get_features_used(root)
239         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
240         versions = black.detect_target_versions(root)
241         self.assertIn(black.TargetVersion.PY38, versions)
242
243     def test_expression_ff(self) -> None:
244         source, expected = read_data("expression")
245         tmp_file = Path(black.dump_to_file(source))
246         try:
247             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
248             with open(tmp_file, encoding="utf8") as f:
249                 actual = f.read()
250         finally:
251             os.unlink(tmp_file)
252         self.assertFormatEqual(expected, actual)
253         with patch("black.dump_to_file", dump_to_stderr):
254             black.assert_equivalent(source, actual)
255             black.assert_stable(source, actual, DEFAULT_MODE)
256
257     def test_expression_diff(self) -> None:
258         source, _ = read_data("expression.py")
259         expected, _ = read_data("expression.diff")
260         tmp_file = Path(black.dump_to_file(source))
261         diff_header = re.compile(
262             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
263             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
264         )
265         try:
266             result = BlackRunner().invoke(
267                 black.main, ["--diff", str(tmp_file), f"--config={EMPTY_CONFIG}"]
268             )
269             self.assertEqual(result.exit_code, 0)
270         finally:
271             os.unlink(tmp_file)
272         actual = result.output
273         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
274         if expected != actual:
275             dump = black.dump_to_file(actual)
276             msg = (
277                 "Expected diff isn't equal to the actual. If you made changes to"
278                 " expression.py and this is an anticipated difference, overwrite"
279                 f" tests/data/expression.diff with {dump}"
280             )
281             self.assertEqual(expected, actual, msg)
282
283     def test_expression_diff_with_color(self) -> None:
284         source, _ = read_data("expression.py")
285         expected, _ = read_data("expression.diff")
286         tmp_file = Path(black.dump_to_file(source))
287         try:
288             result = BlackRunner().invoke(
289                 black.main,
290                 ["--diff", "--color", str(tmp_file), f"--config={EMPTY_CONFIG}"],
291             )
292         finally:
293             os.unlink(tmp_file)
294         actual = result.output
295         # We check the contents of the diff in `test_expression_diff`. All
296         # we need to check here is that color codes exist in the result.
297         self.assertIn("\033[1m", actual)
298         self.assertIn("\033[36m", actual)
299         self.assertIn("\033[32m", actual)
300         self.assertIn("\033[31m", actual)
301         self.assertIn("\033[0m", actual)
302
303     def test_detect_pos_only_arguments(self) -> None:
304         source, _ = read_data("pep_570")
305         root = black.lib2to3_parse(source)
306         features = black.get_features_used(root)
307         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
308         versions = black.detect_target_versions(root)
309         self.assertIn(black.TargetVersion.PY38, versions)
310
311     @patch("black.dump_to_file", dump_to_stderr)
312     def test_string_quotes(self) -> None:
313         source, expected = read_data("string_quotes")
314         mode = black.Mode(preview=True)
315         assert_format(source, expected, mode)
316         mode = replace(mode, string_normalization=False)
317         not_normalized = fs(source, mode=mode)
318         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
319         black.assert_equivalent(source, not_normalized)
320         black.assert_stable(source, not_normalized, mode=mode)
321
322     def test_skip_magic_trailing_comma(self) -> None:
323         source, _ = read_data("expression.py")
324         expected, _ = read_data("expression_skip_magic_trailing_comma.diff")
325         tmp_file = Path(black.dump_to_file(source))
326         diff_header = re.compile(
327             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
328             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
329         )
330         try:
331             result = BlackRunner().invoke(
332                 black.main, ["-C", "--diff", str(tmp_file), f"--config={EMPTY_CONFIG}"]
333             )
334             self.assertEqual(result.exit_code, 0)
335         finally:
336             os.unlink(tmp_file)
337         actual = result.output
338         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
339         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
340         if expected != actual:
341             dump = black.dump_to_file(actual)
342             msg = (
343                 "Expected diff isn't equal to the actual. If you made changes to"
344                 " expression.py and this is an anticipated difference, overwrite"
345                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
346             )
347             self.assertEqual(expected, actual, msg)
348
349     @patch("black.dump_to_file", dump_to_stderr)
350     def test_async_as_identifier(self) -> None:
351         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
352         source, expected = read_data("async_as_identifier")
353         actual = fs(source)
354         self.assertFormatEqual(expected, actual)
355         major, minor = sys.version_info[:2]
356         if major < 3 or (major <= 3 and minor < 7):
357             black.assert_equivalent(source, actual)
358         black.assert_stable(source, actual, DEFAULT_MODE)
359         # ensure black can parse this when the target is 3.6
360         self.invokeBlack([str(source_path), "--target-version", "py36"])
361         # but not on 3.7, because async/await is no longer an identifier
362         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
363
364     @patch("black.dump_to_file", dump_to_stderr)
365     def test_python37(self) -> None:
366         source_path = (THIS_DIR / "data" / "python37.py").resolve()
367         source, expected = read_data("python37")
368         actual = fs(source)
369         self.assertFormatEqual(expected, actual)
370         major, minor = sys.version_info[:2]
371         if major > 3 or (major == 3 and minor >= 7):
372             black.assert_equivalent(source, actual)
373         black.assert_stable(source, actual, DEFAULT_MODE)
374         # ensure black can parse this when the target is 3.7
375         self.invokeBlack([str(source_path), "--target-version", "py37"])
376         # but not on 3.6, because we use async as a reserved keyword
377         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
378
379     def test_tab_comment_indentation(self) -> None:
380         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
381         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
382         self.assertFormatEqual(contents_spc, fs(contents_spc))
383         self.assertFormatEqual(contents_spc, fs(contents_tab))
384
385         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
386         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
387         self.assertFormatEqual(contents_spc, fs(contents_spc))
388         self.assertFormatEqual(contents_spc, fs(contents_tab))
389
390         # mixed tabs and spaces (valid Python 2 code)
391         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
392         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
393         self.assertFormatEqual(contents_spc, fs(contents_spc))
394         self.assertFormatEqual(contents_spc, fs(contents_tab))
395
396         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
397         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
398         self.assertFormatEqual(contents_spc, fs(contents_spc))
399         self.assertFormatEqual(contents_spc, fs(contents_tab))
400
401     def test_report_verbose(self) -> None:
402         report = Report(verbose=True)
403         out_lines = []
404         err_lines = []
405
406         def out(msg: str, **kwargs: Any) -> None:
407             out_lines.append(msg)
408
409         def err(msg: str, **kwargs: Any) -> None:
410             err_lines.append(msg)
411
412         with patch("black.output._out", out), patch("black.output._err", err):
413             report.done(Path("f1"), black.Changed.NO)
414             self.assertEqual(len(out_lines), 1)
415             self.assertEqual(len(err_lines), 0)
416             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
417             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
418             self.assertEqual(report.return_code, 0)
419             report.done(Path("f2"), black.Changed.YES)
420             self.assertEqual(len(out_lines), 2)
421             self.assertEqual(len(err_lines), 0)
422             self.assertEqual(out_lines[-1], "reformatted f2")
423             self.assertEqual(
424                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
425             )
426             report.done(Path("f3"), black.Changed.CACHED)
427             self.assertEqual(len(out_lines), 3)
428             self.assertEqual(len(err_lines), 0)
429             self.assertEqual(
430                 out_lines[-1], "f3 wasn't modified on disk since last run."
431             )
432             self.assertEqual(
433                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
434             )
435             self.assertEqual(report.return_code, 0)
436             report.check = True
437             self.assertEqual(report.return_code, 1)
438             report.check = False
439             report.failed(Path("e1"), "boom")
440             self.assertEqual(len(out_lines), 3)
441             self.assertEqual(len(err_lines), 1)
442             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
443             self.assertEqual(
444                 unstyle(str(report)),
445                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
446                 " reformat.",
447             )
448             self.assertEqual(report.return_code, 123)
449             report.done(Path("f3"), black.Changed.YES)
450             self.assertEqual(len(out_lines), 4)
451             self.assertEqual(len(err_lines), 1)
452             self.assertEqual(out_lines[-1], "reformatted f3")
453             self.assertEqual(
454                 unstyle(str(report)),
455                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
456                 " reformat.",
457             )
458             self.assertEqual(report.return_code, 123)
459             report.failed(Path("e2"), "boom")
460             self.assertEqual(len(out_lines), 4)
461             self.assertEqual(len(err_lines), 2)
462             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
463             self.assertEqual(
464                 unstyle(str(report)),
465                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
466                 " reformat.",
467             )
468             self.assertEqual(report.return_code, 123)
469             report.path_ignored(Path("wat"), "no match")
470             self.assertEqual(len(out_lines), 5)
471             self.assertEqual(len(err_lines), 2)
472             self.assertEqual(out_lines[-1], "wat ignored: no match")
473             self.assertEqual(
474                 unstyle(str(report)),
475                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
476                 " reformat.",
477             )
478             self.assertEqual(report.return_code, 123)
479             report.done(Path("f4"), black.Changed.NO)
480             self.assertEqual(len(out_lines), 6)
481             self.assertEqual(len(err_lines), 2)
482             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
483             self.assertEqual(
484                 unstyle(str(report)),
485                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
486                 " reformat.",
487             )
488             self.assertEqual(report.return_code, 123)
489             report.check = True
490             self.assertEqual(
491                 unstyle(str(report)),
492                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
493                 " would fail to reformat.",
494             )
495             report.check = False
496             report.diff = True
497             self.assertEqual(
498                 unstyle(str(report)),
499                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
500                 " would fail to reformat.",
501             )
502
503     def test_report_quiet(self) -> None:
504         report = Report(quiet=True)
505         out_lines = []
506         err_lines = []
507
508         def out(msg: str, **kwargs: Any) -> None:
509             out_lines.append(msg)
510
511         def err(msg: str, **kwargs: Any) -> None:
512             err_lines.append(msg)
513
514         with patch("black.output._out", out), patch("black.output._err", err):
515             report.done(Path("f1"), black.Changed.NO)
516             self.assertEqual(len(out_lines), 0)
517             self.assertEqual(len(err_lines), 0)
518             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
519             self.assertEqual(report.return_code, 0)
520             report.done(Path("f2"), black.Changed.YES)
521             self.assertEqual(len(out_lines), 0)
522             self.assertEqual(len(err_lines), 0)
523             self.assertEqual(
524                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
525             )
526             report.done(Path("f3"), black.Changed.CACHED)
527             self.assertEqual(len(out_lines), 0)
528             self.assertEqual(len(err_lines), 0)
529             self.assertEqual(
530                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
531             )
532             self.assertEqual(report.return_code, 0)
533             report.check = True
534             self.assertEqual(report.return_code, 1)
535             report.check = False
536             report.failed(Path("e1"), "boom")
537             self.assertEqual(len(out_lines), 0)
538             self.assertEqual(len(err_lines), 1)
539             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
540             self.assertEqual(
541                 unstyle(str(report)),
542                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
543                 " reformat.",
544             )
545             self.assertEqual(report.return_code, 123)
546             report.done(Path("f3"), black.Changed.YES)
547             self.assertEqual(len(out_lines), 0)
548             self.assertEqual(len(err_lines), 1)
549             self.assertEqual(
550                 unstyle(str(report)),
551                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
552                 " reformat.",
553             )
554             self.assertEqual(report.return_code, 123)
555             report.failed(Path("e2"), "boom")
556             self.assertEqual(len(out_lines), 0)
557             self.assertEqual(len(err_lines), 2)
558             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
559             self.assertEqual(
560                 unstyle(str(report)),
561                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
562                 " reformat.",
563             )
564             self.assertEqual(report.return_code, 123)
565             report.path_ignored(Path("wat"), "no match")
566             self.assertEqual(len(out_lines), 0)
567             self.assertEqual(len(err_lines), 2)
568             self.assertEqual(
569                 unstyle(str(report)),
570                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
571                 " reformat.",
572             )
573             self.assertEqual(report.return_code, 123)
574             report.done(Path("f4"), black.Changed.NO)
575             self.assertEqual(len(out_lines), 0)
576             self.assertEqual(len(err_lines), 2)
577             self.assertEqual(
578                 unstyle(str(report)),
579                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
580                 " reformat.",
581             )
582             self.assertEqual(report.return_code, 123)
583             report.check = True
584             self.assertEqual(
585                 unstyle(str(report)),
586                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
587                 " would fail to reformat.",
588             )
589             report.check = False
590             report.diff = True
591             self.assertEqual(
592                 unstyle(str(report)),
593                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
594                 " would fail to reformat.",
595             )
596
597     def test_report_normal(self) -> None:
598         report = black.Report()
599         out_lines = []
600         err_lines = []
601
602         def out(msg: str, **kwargs: Any) -> None:
603             out_lines.append(msg)
604
605         def err(msg: str, **kwargs: Any) -> None:
606             err_lines.append(msg)
607
608         with patch("black.output._out", out), patch("black.output._err", err):
609             report.done(Path("f1"), black.Changed.NO)
610             self.assertEqual(len(out_lines), 0)
611             self.assertEqual(len(err_lines), 0)
612             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
613             self.assertEqual(report.return_code, 0)
614             report.done(Path("f2"), black.Changed.YES)
615             self.assertEqual(len(out_lines), 1)
616             self.assertEqual(len(err_lines), 0)
617             self.assertEqual(out_lines[-1], "reformatted f2")
618             self.assertEqual(
619                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
620             )
621             report.done(Path("f3"), black.Changed.CACHED)
622             self.assertEqual(len(out_lines), 1)
623             self.assertEqual(len(err_lines), 0)
624             self.assertEqual(out_lines[-1], "reformatted f2")
625             self.assertEqual(
626                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
627             )
628             self.assertEqual(report.return_code, 0)
629             report.check = True
630             self.assertEqual(report.return_code, 1)
631             report.check = False
632             report.failed(Path("e1"), "boom")
633             self.assertEqual(len(out_lines), 1)
634             self.assertEqual(len(err_lines), 1)
635             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
636             self.assertEqual(
637                 unstyle(str(report)),
638                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
639                 " reformat.",
640             )
641             self.assertEqual(report.return_code, 123)
642             report.done(Path("f3"), black.Changed.YES)
643             self.assertEqual(len(out_lines), 2)
644             self.assertEqual(len(err_lines), 1)
645             self.assertEqual(out_lines[-1], "reformatted f3")
646             self.assertEqual(
647                 unstyle(str(report)),
648                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
649                 " reformat.",
650             )
651             self.assertEqual(report.return_code, 123)
652             report.failed(Path("e2"), "boom")
653             self.assertEqual(len(out_lines), 2)
654             self.assertEqual(len(err_lines), 2)
655             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
656             self.assertEqual(
657                 unstyle(str(report)),
658                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
659                 " reformat.",
660             )
661             self.assertEqual(report.return_code, 123)
662             report.path_ignored(Path("wat"), "no match")
663             self.assertEqual(len(out_lines), 2)
664             self.assertEqual(len(err_lines), 2)
665             self.assertEqual(
666                 unstyle(str(report)),
667                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
668                 " reformat.",
669             )
670             self.assertEqual(report.return_code, 123)
671             report.done(Path("f4"), black.Changed.NO)
672             self.assertEqual(len(out_lines), 2)
673             self.assertEqual(len(err_lines), 2)
674             self.assertEqual(
675                 unstyle(str(report)),
676                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
677                 " reformat.",
678             )
679             self.assertEqual(report.return_code, 123)
680             report.check = True
681             self.assertEqual(
682                 unstyle(str(report)),
683                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
684                 " would fail to reformat.",
685             )
686             report.check = False
687             report.diff = True
688             self.assertEqual(
689                 unstyle(str(report)),
690                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
691                 " would fail to reformat.",
692             )
693
694     def test_lib2to3_parse(self) -> None:
695         with self.assertRaises(black.InvalidInput):
696             black.lib2to3_parse("invalid syntax")
697
698         straddling = "x + y"
699         black.lib2to3_parse(straddling)
700         black.lib2to3_parse(straddling, {TargetVersion.PY36})
701
702         py2_only = "print x"
703         with self.assertRaises(black.InvalidInput):
704             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
705
706         py3_only = "exec(x, end=y)"
707         black.lib2to3_parse(py3_only)
708         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
709
710     def test_get_features_used_decorator(self) -> None:
711         # Test the feature detection of new decorator syntax
712         # since this makes some test cases of test_get_features_used()
713         # fails if it fails, this is tested first so that a useful case
714         # is identified
715         simples, relaxed = read_data("decorators")
716         # skip explanation comments at the top of the file
717         for simple_test in simples.split("##")[1:]:
718             node = black.lib2to3_parse(simple_test)
719             decorator = str(node.children[0].children[0]).strip()
720             self.assertNotIn(
721                 Feature.RELAXED_DECORATORS,
722                 black.get_features_used(node),
723                 msg=(
724                     f"decorator '{decorator}' follows python<=3.8 syntax"
725                     "but is detected as 3.9+"
726                     # f"The full node is\n{node!r}"
727                 ),
728             )
729         # skip the '# output' comment at the top of the output part
730         for relaxed_test in relaxed.split("##")[1:]:
731             node = black.lib2to3_parse(relaxed_test)
732             decorator = str(node.children[0].children[0]).strip()
733             self.assertIn(
734                 Feature.RELAXED_DECORATORS,
735                 black.get_features_used(node),
736                 msg=(
737                     f"decorator '{decorator}' uses python3.9+ syntax"
738                     "but is detected as python<=3.8"
739                     # f"The full node is\n{node!r}"
740                 ),
741             )
742
743     def test_get_features_used(self) -> None:
744         node = black.lib2to3_parse("def f(*, arg): ...\n")
745         self.assertEqual(black.get_features_used(node), set())
746         node = black.lib2to3_parse("def f(*, arg,): ...\n")
747         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
748         node = black.lib2to3_parse("f(*arg,)\n")
749         self.assertEqual(
750             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
751         )
752         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
753         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
754         node = black.lib2to3_parse("123_456\n")
755         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
756         node = black.lib2to3_parse("123456\n")
757         self.assertEqual(black.get_features_used(node), set())
758         source, expected = read_data("function")
759         node = black.lib2to3_parse(source)
760         expected_features = {
761             Feature.TRAILING_COMMA_IN_CALL,
762             Feature.TRAILING_COMMA_IN_DEF,
763             Feature.F_STRINGS,
764         }
765         self.assertEqual(black.get_features_used(node), expected_features)
766         node = black.lib2to3_parse(expected)
767         self.assertEqual(black.get_features_used(node), expected_features)
768         source, expected = read_data("expression")
769         node = black.lib2to3_parse(source)
770         self.assertEqual(black.get_features_used(node), set())
771         node = black.lib2to3_parse(expected)
772         self.assertEqual(black.get_features_used(node), set())
773         node = black.lib2to3_parse("lambda a, /, b: ...")
774         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
775         node = black.lib2to3_parse("def fn(a, /, b): ...")
776         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
777         node = black.lib2to3_parse("def fn(): yield a, b")
778         self.assertEqual(black.get_features_used(node), set())
779         node = black.lib2to3_parse("def fn(): return a, b")
780         self.assertEqual(black.get_features_used(node), set())
781         node = black.lib2to3_parse("def fn(): yield *b, c")
782         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
783         node = black.lib2to3_parse("def fn(): return a, *b, c")
784         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
785         node = black.lib2to3_parse("x = a, *b, c")
786         self.assertEqual(black.get_features_used(node), set())
787         node = black.lib2to3_parse("x: Any = regular")
788         self.assertEqual(black.get_features_used(node), set())
789         node = black.lib2to3_parse("x: Any = (regular, regular)")
790         self.assertEqual(black.get_features_used(node), set())
791         node = black.lib2to3_parse("x: Any = Complex(Type(1))[something]")
792         self.assertEqual(black.get_features_used(node), set())
793         node = black.lib2to3_parse("x: Tuple[int, ...] = a, b, c")
794         self.assertEqual(
795             black.get_features_used(node), {Feature.ANN_ASSIGN_EXTENDED_RHS}
796         )
797
798     def test_get_features_used_for_future_flags(self) -> None:
799         for src, features in [
800             ("from __future__ import annotations", {Feature.FUTURE_ANNOTATIONS}),
801             (
802                 "from __future__ import (other, annotations)",
803                 {Feature.FUTURE_ANNOTATIONS},
804             ),
805             ("a = 1 + 2\nfrom something import annotations", set()),
806             ("from __future__ import x, y", set()),
807         ]:
808             with self.subTest(src=src, features=features):
809                 node = black.lib2to3_parse(src)
810                 future_imports = black.get_future_imports(node)
811                 self.assertEqual(
812                     black.get_features_used(node, future_imports=future_imports),
813                     features,
814                 )
815
816     def test_get_future_imports(self) -> None:
817         node = black.lib2to3_parse("\n")
818         self.assertEqual(set(), black.get_future_imports(node))
819         node = black.lib2to3_parse("from __future__ import black\n")
820         self.assertEqual({"black"}, black.get_future_imports(node))
821         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
822         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
823         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
824         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
825         node = black.lib2to3_parse(
826             "from __future__ import multiple\nfrom __future__ import imports\n"
827         )
828         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
829         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
830         self.assertEqual({"black"}, black.get_future_imports(node))
831         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
832         self.assertEqual({"black"}, black.get_future_imports(node))
833         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
834         self.assertEqual(set(), black.get_future_imports(node))
835         node = black.lib2to3_parse("from some.module import black\n")
836         self.assertEqual(set(), black.get_future_imports(node))
837         node = black.lib2to3_parse(
838             "from __future__ import unicode_literals as _unicode_literals"
839         )
840         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
841         node = black.lib2to3_parse(
842             "from __future__ import unicode_literals as _lol, print"
843         )
844         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
845
846     @pytest.mark.incompatible_with_mypyc
847     def test_debug_visitor(self) -> None:
848         source, _ = read_data("debug_visitor.py")
849         expected, _ = read_data("debug_visitor.out")
850         out_lines = []
851         err_lines = []
852
853         def out(msg: str, **kwargs: Any) -> None:
854             out_lines.append(msg)
855
856         def err(msg: str, **kwargs: Any) -> None:
857             err_lines.append(msg)
858
859         with patch("black.debug.out", out):
860             DebugVisitor.show(source)
861         actual = "\n".join(out_lines) + "\n"
862         log_name = ""
863         if expected != actual:
864             log_name = black.dump_to_file(*out_lines)
865         self.assertEqual(
866             expected,
867             actual,
868             f"AST print out is different. Actual version dumped to {log_name}",
869         )
870
871     def test_format_file_contents(self) -> None:
872         empty = ""
873         mode = DEFAULT_MODE
874         with self.assertRaises(black.NothingChanged):
875             black.format_file_contents(empty, mode=mode, fast=False)
876         just_nl = "\n"
877         with self.assertRaises(black.NothingChanged):
878             black.format_file_contents(just_nl, mode=mode, fast=False)
879         same = "j = [1, 2, 3]\n"
880         with self.assertRaises(black.NothingChanged):
881             black.format_file_contents(same, mode=mode, fast=False)
882         different = "j = [1,2,3]"
883         expected = same
884         actual = black.format_file_contents(different, mode=mode, fast=False)
885         self.assertEqual(expected, actual)
886         invalid = "return if you can"
887         with self.assertRaises(black.InvalidInput) as e:
888             black.format_file_contents(invalid, mode=mode, fast=False)
889         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
890
891     def test_endmarker(self) -> None:
892         n = black.lib2to3_parse("\n")
893         self.assertEqual(n.type, black.syms.file_input)
894         self.assertEqual(len(n.children), 1)
895         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
896
897     @pytest.mark.incompatible_with_mypyc
898     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
899     def test_assertFormatEqual(self) -> None:
900         out_lines = []
901         err_lines = []
902
903         def out(msg: str, **kwargs: Any) -> None:
904             out_lines.append(msg)
905
906         def err(msg: str, **kwargs: Any) -> None:
907             err_lines.append(msg)
908
909         with patch("black.output._out", out), patch("black.output._err", err):
910             with self.assertRaises(AssertionError):
911                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
912
913         out_str = "".join(out_lines)
914         self.assertIn("Expected tree:", out_str)
915         self.assertIn("Actual tree:", out_str)
916         self.assertEqual("".join(err_lines), "")
917
918     @event_loop()
919     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
920     def test_works_in_mono_process_only_environment(self) -> None:
921         with cache_dir() as workspace:
922             for f in [
923                 (workspace / "one.py").resolve(),
924                 (workspace / "two.py").resolve(),
925             ]:
926                 f.write_text('print("hello")\n')
927             self.invokeBlack([str(workspace)])
928
929     @event_loop()
930     def test_check_diff_use_together(self) -> None:
931         with cache_dir():
932             # Files which will be reformatted.
933             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
934             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
935             # Files which will not be reformatted.
936             src2 = (THIS_DIR / "data" / "composition.py").resolve()
937             self.invokeBlack([str(src2), "--diff", "--check"])
938             # Multi file command.
939             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
940
941     def test_no_src_fails(self) -> None:
942         with cache_dir():
943             self.invokeBlack([], exit_code=1)
944
945     def test_src_and_code_fails(self) -> None:
946         with cache_dir():
947             self.invokeBlack([".", "-c", "0"], exit_code=1)
948
949     def test_broken_symlink(self) -> None:
950         with cache_dir() as workspace:
951             symlink = workspace / "broken_link.py"
952             try:
953                 symlink.symlink_to("nonexistent.py")
954             except (OSError, NotImplementedError) as e:
955                 self.skipTest(f"Can't create symlinks: {e}")
956             self.invokeBlack([str(workspace.resolve())])
957
958     def test_single_file_force_pyi(self) -> None:
959         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
960         contents, expected = read_data("force_pyi")
961         with cache_dir() as workspace:
962             path = (workspace / "file.py").resolve()
963             with open(path, "w") as fh:
964                 fh.write(contents)
965             self.invokeBlack([str(path), "--pyi"])
966             with open(path, "r") as fh:
967                 actual = fh.read()
968             # verify cache with --pyi is separate
969             pyi_cache = black.read_cache(pyi_mode)
970             self.assertIn(str(path), pyi_cache)
971             normal_cache = black.read_cache(DEFAULT_MODE)
972             self.assertNotIn(str(path), normal_cache)
973         self.assertFormatEqual(expected, actual)
974         black.assert_equivalent(contents, actual)
975         black.assert_stable(contents, actual, pyi_mode)
976
977     @event_loop()
978     def test_multi_file_force_pyi(self) -> None:
979         reg_mode = DEFAULT_MODE
980         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
981         contents, expected = read_data("force_pyi")
982         with cache_dir() as workspace:
983             paths = [
984                 (workspace / "file1.py").resolve(),
985                 (workspace / "file2.py").resolve(),
986             ]
987             for path in paths:
988                 with open(path, "w") as fh:
989                     fh.write(contents)
990             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
991             for path in paths:
992                 with open(path, "r") as fh:
993                     actual = fh.read()
994                 self.assertEqual(actual, expected)
995             # verify cache with --pyi is separate
996             pyi_cache = black.read_cache(pyi_mode)
997             normal_cache = black.read_cache(reg_mode)
998             for path in paths:
999                 self.assertIn(str(path), pyi_cache)
1000                 self.assertNotIn(str(path), normal_cache)
1001
1002     def test_pipe_force_pyi(self) -> None:
1003         source, expected = read_data("force_pyi")
1004         result = CliRunner().invoke(
1005             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1006         )
1007         self.assertEqual(result.exit_code, 0)
1008         actual = result.output
1009         self.assertFormatEqual(actual, expected)
1010
1011     def test_single_file_force_py36(self) -> None:
1012         reg_mode = DEFAULT_MODE
1013         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1014         source, expected = read_data("force_py36")
1015         with cache_dir() as workspace:
1016             path = (workspace / "file.py").resolve()
1017             with open(path, "w") as fh:
1018                 fh.write(source)
1019             self.invokeBlack([str(path), *PY36_ARGS])
1020             with open(path, "r") as fh:
1021                 actual = fh.read()
1022             # verify cache with --target-version is separate
1023             py36_cache = black.read_cache(py36_mode)
1024             self.assertIn(str(path), py36_cache)
1025             normal_cache = black.read_cache(reg_mode)
1026             self.assertNotIn(str(path), normal_cache)
1027         self.assertEqual(actual, expected)
1028
1029     @event_loop()
1030     def test_multi_file_force_py36(self) -> None:
1031         reg_mode = DEFAULT_MODE
1032         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1033         source, expected = read_data("force_py36")
1034         with cache_dir() as workspace:
1035             paths = [
1036                 (workspace / "file1.py").resolve(),
1037                 (workspace / "file2.py").resolve(),
1038             ]
1039             for path in paths:
1040                 with open(path, "w") as fh:
1041                     fh.write(source)
1042             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1043             for path in paths:
1044                 with open(path, "r") as fh:
1045                     actual = fh.read()
1046                 self.assertEqual(actual, expected)
1047             # verify cache with --target-version is separate
1048             pyi_cache = black.read_cache(py36_mode)
1049             normal_cache = black.read_cache(reg_mode)
1050             for path in paths:
1051                 self.assertIn(str(path), pyi_cache)
1052                 self.assertNotIn(str(path), normal_cache)
1053
1054     def test_pipe_force_py36(self) -> None:
1055         source, expected = read_data("force_py36")
1056         result = CliRunner().invoke(
1057             black.main,
1058             ["-", "-q", "--target-version=py36"],
1059             input=BytesIO(source.encode("utf8")),
1060         )
1061         self.assertEqual(result.exit_code, 0)
1062         actual = result.output
1063         self.assertFormatEqual(actual, expected)
1064
1065     @pytest.mark.incompatible_with_mypyc
1066     def test_reformat_one_with_stdin(self) -> None:
1067         with patch(
1068             "black.format_stdin_to_stdout",
1069             return_value=lambda *args, **kwargs: black.Changed.YES,
1070         ) as fsts:
1071             report = MagicMock()
1072             path = Path("-")
1073             black.reformat_one(
1074                 path,
1075                 fast=True,
1076                 write_back=black.WriteBack.YES,
1077                 mode=DEFAULT_MODE,
1078                 report=report,
1079             )
1080             fsts.assert_called_once()
1081             report.done.assert_called_with(path, black.Changed.YES)
1082
1083     @pytest.mark.incompatible_with_mypyc
1084     def test_reformat_one_with_stdin_filename(self) -> None:
1085         with patch(
1086             "black.format_stdin_to_stdout",
1087             return_value=lambda *args, **kwargs: black.Changed.YES,
1088         ) as fsts:
1089             report = MagicMock()
1090             p = "foo.py"
1091             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1092             expected = Path(p)
1093             black.reformat_one(
1094                 path,
1095                 fast=True,
1096                 write_back=black.WriteBack.YES,
1097                 mode=DEFAULT_MODE,
1098                 report=report,
1099             )
1100             fsts.assert_called_once_with(
1101                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1102             )
1103             # __BLACK_STDIN_FILENAME__ should have been stripped
1104             report.done.assert_called_with(expected, black.Changed.YES)
1105
1106     @pytest.mark.incompatible_with_mypyc
1107     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1108         with patch(
1109             "black.format_stdin_to_stdout",
1110             return_value=lambda *args, **kwargs: black.Changed.YES,
1111         ) as fsts:
1112             report = MagicMock()
1113             p = "foo.pyi"
1114             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1115             expected = Path(p)
1116             black.reformat_one(
1117                 path,
1118                 fast=True,
1119                 write_back=black.WriteBack.YES,
1120                 mode=DEFAULT_MODE,
1121                 report=report,
1122             )
1123             fsts.assert_called_once_with(
1124                 fast=True,
1125                 write_back=black.WriteBack.YES,
1126                 mode=replace(DEFAULT_MODE, is_pyi=True),
1127             )
1128             # __BLACK_STDIN_FILENAME__ should have been stripped
1129             report.done.assert_called_with(expected, black.Changed.YES)
1130
1131     @pytest.mark.incompatible_with_mypyc
1132     def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
1133         with patch(
1134             "black.format_stdin_to_stdout",
1135             return_value=lambda *args, **kwargs: black.Changed.YES,
1136         ) as fsts:
1137             report = MagicMock()
1138             p = "foo.ipynb"
1139             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1140             expected = Path(p)
1141             black.reformat_one(
1142                 path,
1143                 fast=True,
1144                 write_back=black.WriteBack.YES,
1145                 mode=DEFAULT_MODE,
1146                 report=report,
1147             )
1148             fsts.assert_called_once_with(
1149                 fast=True,
1150                 write_back=black.WriteBack.YES,
1151                 mode=replace(DEFAULT_MODE, is_ipynb=True),
1152             )
1153             # __BLACK_STDIN_FILENAME__ should have been stripped
1154             report.done.assert_called_with(expected, black.Changed.YES)
1155
1156     @pytest.mark.incompatible_with_mypyc
1157     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1158         with patch(
1159             "black.format_stdin_to_stdout",
1160             return_value=lambda *args, **kwargs: black.Changed.YES,
1161         ) as fsts:
1162             report = MagicMock()
1163             # Even with an existing file, since we are forcing stdin, black
1164             # should output to stdout and not modify the file inplace
1165             p = Path(str(THIS_DIR / "data/collections.py"))
1166             # Make sure is_file actually returns True
1167             self.assertTrue(p.is_file())
1168             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1169             expected = Path(p)
1170             black.reformat_one(
1171                 path,
1172                 fast=True,
1173                 write_back=black.WriteBack.YES,
1174                 mode=DEFAULT_MODE,
1175                 report=report,
1176             )
1177             fsts.assert_called_once()
1178             # __BLACK_STDIN_FILENAME__ should have been stripped
1179             report.done.assert_called_with(expected, black.Changed.YES)
1180
1181     def test_reformat_one_with_stdin_empty(self) -> None:
1182         output = io.StringIO()
1183         with patch("io.TextIOWrapper", lambda *args, **kwargs: output):
1184             try:
1185                 black.format_stdin_to_stdout(
1186                     fast=True,
1187                     content="",
1188                     write_back=black.WriteBack.YES,
1189                     mode=DEFAULT_MODE,
1190                 )
1191             except io.UnsupportedOperation:
1192                 pass  # StringIO does not support detach
1193             assert output.getvalue() == ""
1194
1195     def test_invalid_cli_regex(self) -> None:
1196         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1197             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1198
1199     def test_required_version_matches_version(self) -> None:
1200         self.invokeBlack(
1201             ["--required-version", black.__version__, "-c", "0"],
1202             exit_code=0,
1203             ignore_config=True,
1204         )
1205
1206     def test_required_version_matches_partial_version(self) -> None:
1207         self.invokeBlack(
1208             ["--required-version", black.__version__.split(".")[0], "-c", "0"],
1209             exit_code=0,
1210             ignore_config=True,
1211         )
1212
1213     def test_required_version_does_not_match_on_minor_version(self) -> None:
1214         self.invokeBlack(
1215             ["--required-version", black.__version__.split(".")[0] + ".999", "-c", "0"],
1216             exit_code=1,
1217             ignore_config=True,
1218         )
1219
1220     def test_required_version_does_not_match_version(self) -> None:
1221         result = BlackRunner().invoke(
1222             black.main,
1223             ["--required-version", "20.99b", "-c", "0"],
1224         )
1225         self.assertEqual(result.exit_code, 1)
1226         self.assertIn("required version", result.stderr)
1227
1228     def test_preserves_line_endings(self) -> None:
1229         with TemporaryDirectory() as workspace:
1230             test_file = Path(workspace) / "test.py"
1231             for nl in ["\n", "\r\n"]:
1232                 contents = nl.join(["def f(  ):", "    pass"])
1233                 test_file.write_bytes(contents.encode())
1234                 ff(test_file, write_back=black.WriteBack.YES)
1235                 updated_contents: bytes = test_file.read_bytes()
1236                 self.assertIn(nl.encode(), updated_contents)
1237                 if nl == "\n":
1238                     self.assertNotIn(b"\r\n", updated_contents)
1239
1240     def test_preserves_line_endings_via_stdin(self) -> None:
1241         for nl in ["\n", "\r\n"]:
1242             contents = nl.join(["def f(  ):", "    pass"])
1243             runner = BlackRunner()
1244             result = runner.invoke(
1245                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1246             )
1247             self.assertEqual(result.exit_code, 0)
1248             output = result.stdout_bytes
1249             self.assertIn(nl.encode("utf8"), output)
1250             if nl == "\n":
1251                 self.assertNotIn(b"\r\n", output)
1252
1253     def test_assert_equivalent_different_asts(self) -> None:
1254         with self.assertRaises(AssertionError):
1255             black.assert_equivalent("{}", "None")
1256
1257     def test_shhh_click(self) -> None:
1258         try:
1259             from click import _unicodefun
1260         except ModuleNotFoundError:
1261             self.skipTest("Incompatible Click version")
1262         if not hasattr(_unicodefun, "_verify_python3_env"):
1263             self.skipTest("Incompatible Click version")
1264         # First, let's see if Click is crashing with a preferred ASCII charset.
1265         with patch("locale.getpreferredencoding") as gpe:
1266             gpe.return_value = "ASCII"
1267             with self.assertRaises(RuntimeError):
1268                 _unicodefun._verify_python3_env()  # type: ignore
1269         # Now, let's silence Click...
1270         black.patch_click()
1271         # ...and confirm it's silent.
1272         with patch("locale.getpreferredencoding") as gpe:
1273             gpe.return_value = "ASCII"
1274             try:
1275                 _unicodefun._verify_python3_env()  # type: ignore
1276             except RuntimeError as re:
1277                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1278
1279     def test_root_logger_not_used_directly(self) -> None:
1280         def fail(*args: Any, **kwargs: Any) -> None:
1281             self.fail("Record created with root logger")
1282
1283         with patch.multiple(
1284             logging.root,
1285             debug=fail,
1286             info=fail,
1287             warning=fail,
1288             error=fail,
1289             critical=fail,
1290             log=fail,
1291         ):
1292             ff(THIS_DIR / "util.py")
1293
1294     def test_invalid_config_return_code(self) -> None:
1295         tmp_file = Path(black.dump_to_file())
1296         try:
1297             tmp_config = Path(black.dump_to_file())
1298             tmp_config.unlink()
1299             args = ["--config", str(tmp_config), str(tmp_file)]
1300             self.invokeBlack(args, exit_code=2, ignore_config=False)
1301         finally:
1302             tmp_file.unlink()
1303
1304     def test_parse_pyproject_toml(self) -> None:
1305         test_toml_file = THIS_DIR / "test.toml"
1306         config = black.parse_pyproject_toml(str(test_toml_file))
1307         self.assertEqual(config["verbose"], 1)
1308         self.assertEqual(config["check"], "no")
1309         self.assertEqual(config["diff"], "y")
1310         self.assertEqual(config["color"], True)
1311         self.assertEqual(config["line_length"], 79)
1312         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1313         self.assertEqual(config["python_cell_magics"], ["custom1", "custom2"])
1314         self.assertEqual(config["exclude"], r"\.pyi?$")
1315         self.assertEqual(config["include"], r"\.py?$")
1316
1317     def test_read_pyproject_toml(self) -> None:
1318         test_toml_file = THIS_DIR / "test.toml"
1319         fake_ctx = FakeContext()
1320         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1321         config = fake_ctx.default_map
1322         self.assertEqual(config["verbose"], "1")
1323         self.assertEqual(config["check"], "no")
1324         self.assertEqual(config["diff"], "y")
1325         self.assertEqual(config["color"], "True")
1326         self.assertEqual(config["line_length"], "79")
1327         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1328         self.assertEqual(config["exclude"], r"\.pyi?$")
1329         self.assertEqual(config["include"], r"\.py?$")
1330
1331     @pytest.mark.incompatible_with_mypyc
1332     def test_find_project_root(self) -> None:
1333         with TemporaryDirectory() as workspace:
1334             root = Path(workspace)
1335             test_dir = root / "test"
1336             test_dir.mkdir()
1337
1338             src_dir = root / "src"
1339             src_dir.mkdir()
1340
1341             root_pyproject = root / "pyproject.toml"
1342             root_pyproject.touch()
1343             src_pyproject = src_dir / "pyproject.toml"
1344             src_pyproject.touch()
1345             src_python = src_dir / "foo.py"
1346             src_python.touch()
1347
1348             self.assertEqual(
1349                 black.find_project_root((src_dir, test_dir)),
1350                 (root.resolve(), "pyproject.toml"),
1351             )
1352             self.assertEqual(
1353                 black.find_project_root((src_dir,)),
1354                 (src_dir.resolve(), "pyproject.toml"),
1355             )
1356             self.assertEqual(
1357                 black.find_project_root((src_python,)),
1358                 (src_dir.resolve(), "pyproject.toml"),
1359             )
1360
1361     @patch(
1362         "black.files.find_user_pyproject_toml",
1363     )
1364     def test_find_pyproject_toml(self, find_user_pyproject_toml: MagicMock) -> None:
1365         find_user_pyproject_toml.side_effect = RuntimeError()
1366
1367         with redirect_stderr(io.StringIO()) as stderr:
1368             result = black.files.find_pyproject_toml(
1369                 path_search_start=(str(Path.cwd().root),)
1370             )
1371
1372         assert result is None
1373         err = stderr.getvalue()
1374         assert "Ignoring user configuration" in err
1375
1376     @patch(
1377         "black.files.find_user_pyproject_toml",
1378         black.files.find_user_pyproject_toml.__wrapped__,
1379     )
1380     def test_find_user_pyproject_toml_linux(self) -> None:
1381         if system() == "Windows":
1382             return
1383
1384         # Test if XDG_CONFIG_HOME is checked
1385         with TemporaryDirectory() as workspace:
1386             tmp_user_config = Path(workspace) / "black"
1387             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1388                 self.assertEqual(
1389                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
1390                 )
1391
1392         # Test fallback for XDG_CONFIG_HOME
1393         with patch.dict("os.environ"):
1394             os.environ.pop("XDG_CONFIG_HOME", None)
1395             fallback_user_config = Path("~/.config").expanduser() / "black"
1396             self.assertEqual(
1397                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
1398             )
1399
1400     def test_find_user_pyproject_toml_windows(self) -> None:
1401         if system() != "Windows":
1402             return
1403
1404         user_config_path = Path.home() / ".black"
1405         self.assertEqual(
1406             black.files.find_user_pyproject_toml(), user_config_path.resolve()
1407         )
1408
1409     def test_bpo_33660_workaround(self) -> None:
1410         if system() == "Windows":
1411             return
1412
1413         # https://bugs.python.org/issue33660
1414         root = Path("/")
1415         with change_directory(root):
1416             path = Path("workspace") / "project"
1417             report = black.Report(verbose=True)
1418             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
1419             self.assertEqual(normalized_path, "workspace/project")
1420
1421     def test_newline_comment_interaction(self) -> None:
1422         source = "class A:\\\r\n# type: ignore\n pass\n"
1423         output = black.format_str(source, mode=DEFAULT_MODE)
1424         black.assert_stable(source, output, mode=DEFAULT_MODE)
1425
1426     def test_bpo_2142_workaround(self) -> None:
1427
1428         # https://bugs.python.org/issue2142
1429
1430         source, _ = read_data("missing_final_newline.py")
1431         # read_data adds a trailing newline
1432         source = source.rstrip()
1433         expected, _ = read_data("missing_final_newline.diff")
1434         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
1435         diff_header = re.compile(
1436             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
1437             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
1438         )
1439         try:
1440             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
1441             self.assertEqual(result.exit_code, 0)
1442         finally:
1443             os.unlink(tmp_file)
1444         actual = result.output
1445         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
1446         self.assertEqual(actual, expected)
1447
1448     @staticmethod
1449     def compare_results(
1450         result: click.testing.Result, expected_value: str, expected_exit_code: int
1451     ) -> None:
1452         """Helper method to test the value and exit code of a click Result."""
1453         assert (
1454             result.output == expected_value
1455         ), "The output did not match the expected value."
1456         assert result.exit_code == expected_exit_code, "The exit code is incorrect."
1457
1458     def test_code_option(self) -> None:
1459         """Test the code option with no changes."""
1460         code = 'print("Hello world")\n'
1461         args = ["--code", code]
1462         result = CliRunner().invoke(black.main, args)
1463
1464         self.compare_results(result, code, 0)
1465
1466     def test_code_option_changed(self) -> None:
1467         """Test the code option when changes are required."""
1468         code = "print('hello world')"
1469         formatted = black.format_str(code, mode=DEFAULT_MODE)
1470
1471         args = ["--code", code]
1472         result = CliRunner().invoke(black.main, args)
1473
1474         self.compare_results(result, formatted, 0)
1475
1476     def test_code_option_check(self) -> None:
1477         """Test the code option when check is passed."""
1478         args = ["--check", "--code", 'print("Hello world")\n']
1479         result = CliRunner().invoke(black.main, args)
1480         self.compare_results(result, "", 0)
1481
1482     def test_code_option_check_changed(self) -> None:
1483         """Test the code option when changes are required, and check is passed."""
1484         args = ["--check", "--code", "print('hello world')"]
1485         result = CliRunner().invoke(black.main, args)
1486         self.compare_results(result, "", 1)
1487
1488     def test_code_option_diff(self) -> None:
1489         """Test the code option when diff is passed."""
1490         code = "print('hello world')"
1491         formatted = black.format_str(code, mode=DEFAULT_MODE)
1492         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1493
1494         args = ["--diff", "--code", code]
1495         result = CliRunner().invoke(black.main, args)
1496
1497         # Remove time from diff
1498         output = DIFF_TIME.sub("", result.output)
1499
1500         assert output == result_diff, "The output did not match the expected value."
1501         assert result.exit_code == 0, "The exit code is incorrect."
1502
1503     def test_code_option_color_diff(self) -> None:
1504         """Test the code option when color and diff are passed."""
1505         code = "print('hello world')"
1506         formatted = black.format_str(code, mode=DEFAULT_MODE)
1507
1508         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1509         result_diff = color_diff(result_diff)
1510
1511         args = ["--diff", "--color", "--code", code]
1512         result = CliRunner().invoke(black.main, args)
1513
1514         # Remove time from diff
1515         output = DIFF_TIME.sub("", result.output)
1516
1517         assert output == result_diff, "The output did not match the expected value."
1518         assert result.exit_code == 0, "The exit code is incorrect."
1519
1520     @pytest.mark.incompatible_with_mypyc
1521     def test_code_option_safe(self) -> None:
1522         """Test that the code option throws an error when the sanity checks fail."""
1523         # Patch black.assert_equivalent to ensure the sanity checks fail
1524         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1525             code = 'print("Hello world")'
1526             error_msg = f"{code}\nerror: cannot format <string>: \n"
1527
1528             args = ["--safe", "--code", code]
1529             result = CliRunner().invoke(black.main, args)
1530
1531             self.compare_results(result, error_msg, 123)
1532
1533     def test_code_option_fast(self) -> None:
1534         """Test that the code option ignores errors when the sanity checks fail."""
1535         # Patch black.assert_equivalent to ensure the sanity checks fail
1536         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1537             code = 'print("Hello world")'
1538             formatted = black.format_str(code, mode=DEFAULT_MODE)
1539
1540             args = ["--fast", "--code", code]
1541             result = CliRunner().invoke(black.main, args)
1542
1543             self.compare_results(result, formatted, 0)
1544
1545     @pytest.mark.incompatible_with_mypyc
1546     def test_code_option_config(self) -> None:
1547         """
1548         Test that the code option finds the pyproject.toml in the current directory.
1549         """
1550         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1551             args = ["--code", "print"]
1552             # This is the only directory known to contain a pyproject.toml
1553             with change_directory(PROJECT_ROOT):
1554                 CliRunner().invoke(black.main, args)
1555                 pyproject_path = Path(Path.cwd(), "pyproject.toml").resolve()
1556
1557             assert (
1558                 len(parse.mock_calls) >= 1
1559             ), "Expected config parse to be called with the current directory."
1560
1561             _, call_args, _ = parse.mock_calls[0]
1562             assert (
1563                 call_args[0].lower() == str(pyproject_path).lower()
1564             ), "Incorrect config loaded."
1565
1566     @pytest.mark.incompatible_with_mypyc
1567     def test_code_option_parent_config(self) -> None:
1568         """
1569         Test that the code option finds the pyproject.toml in the parent directory.
1570         """
1571         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1572             with change_directory(THIS_DIR):
1573                 args = ["--code", "print"]
1574                 CliRunner().invoke(black.main, args)
1575
1576                 pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
1577                 assert (
1578                     len(parse.mock_calls) >= 1
1579                 ), "Expected config parse to be called with the current directory."
1580
1581                 _, call_args, _ = parse.mock_calls[0]
1582                 assert (
1583                     call_args[0].lower() == str(pyproject_path).lower()
1584                 ), "Incorrect config loaded."
1585
1586     def test_for_handled_unexpected_eof_error(self) -> None:
1587         """
1588         Test that an unexpected EOF SyntaxError is nicely presented.
1589         """
1590         with pytest.raises(black.parsing.InvalidInput) as exc_info:
1591             black.lib2to3_parse("print(", {})
1592
1593         exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")
1594
1595     def test_equivalency_ast_parse_failure_includes_error(self) -> None:
1596         with pytest.raises(AssertionError) as err:
1597             black.assert_equivalent("a«»a  = 1", "a«»a  = 1")
1598
1599         err.match("--safe")
1600         # Unfortunately the SyntaxError message has changed in newer versions so we
1601         # can't match it directly.
1602         err.match("invalid character")
1603         err.match(r"\(<unknown>, line 1\)")
1604
1605
class TestCaching:
    """Tests for black's on-disk formatting cache."""

    def test_get_cache_dir(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """``BLACK_CACHE_DIR`` overrides the (patched) ``user_cache_dir``."""
        default_dir = tmp_path / "ws1"
        default_dir.mkdir()
        override_dir = tmp_path / "ws2"
        override_dir.mkdir()

        # Point user_cache_dir at a known directory so the result is predictable.
        fake_user_cache_dir = patch(
            target="black.cache.user_cache_dir",
            autospec=True,
            return_value=str(default_dir),
        )

        # Without BLACK_CACHE_DIR, user_cache_dir decides.
        monkeypatch.delenv("BLACK_CACHE_DIR", raising=False)
        with fake_user_cache_dir:
            assert get_cache_dir() == default_dir

        # With BLACK_CACHE_DIR set, the environment variable wins.
        monkeypatch.setenv("BLACK_CACHE_DIR", str(override_dir))
        assert get_cache_dir() == override_dir

    def test_cache_broken_file(self) -> None:
        """A corrupt cache file reads as empty and is repopulated by a run."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            get_cache_file(mode).write_text("this is not a pickle")
            assert black.read_cache(mode) == {}

            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            invokeBlack([str(src)])
            assert str(src) in black.read_cache(mode)

    def test_cache_single_file_already_cached(self) -> None:
        """A file already recorded in the cache is left untouched."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            unformatted = "print('hello')"
            src.write_text(unformatted)
            black.write_cache({}, [src], mode)
            invokeBlack([str(src)])
            assert src.read_text() == unformatted

    @event_loop()
    def test_cache_multiple_files(self) -> None:
        """Only uncached files are formatted; afterwards both are cached."""
        mode = DEFAULT_MODE
        # Use threads instead of processes so formatting runs in this process.
        with cache_dir() as workspace, patch(
            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            cached = (workspace / "one.py").resolve()
            cached.write_text("print('hello')")
            uncached = (workspace / "two.py").resolve()
            uncached.write_text("print('hello')")
            black.write_cache({}, [cached], mode)

            invokeBlack([str(workspace)])

            assert cached.read_text() == "print('hello')"
            assert uncached.read_text() == 'print("hello")\n'
            cache = black.read_cache(mode)
            assert str(cached) in cache
            assert str(uncached) in cache

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    def test_no_cache_when_writeback_diff(self, color: bool) -> None:
        """``--diff`` runs never read or write the cache."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            cmd = [str(src), "--diff"] + (["--color"] if color else [])
            with patch("black.read_cache") as read_cache, patch(
                "black.write_cache"
            ) as write_cache:
                invokeBlack(cmd)
                assert not get_cache_file(mode).exists()
                write_cache.assert_not_called()
                read_cache.assert_not_called()

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    @event_loop()
    def test_output_locking_when_writeback_diff(self, color: bool) -> None:
        """Diffing several files creates a multiprocessing ``Manager``."""
        with cache_dir() as workspace:
            for index in range(4):
                (workspace / f"test{index}.py").resolve().write_text("print('hello')")
            cmd = ["--diff", str(workspace)]
            if color:
                cmd.append("--color")
            with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
                invokeBlack(cmd, exit_code=0)
                # Only proves Manager() was created; if it were not, the lock
                # it provides could not be in use.
                mgr.assert_called()

    def test_no_cache_when_stdin(self) -> None:
        """Formatting from stdin does not create a cache file."""
        mode = DEFAULT_MODE
        with cache_dir():
            result = CliRunner().invoke(
                black.main, ["-"], input=BytesIO(b"print('hello')")
            )
            assert not result.exit_code
            assert not get_cache_file(mode).exists()

    def test_read_cache_no_cachefile(self) -> None:
        """A missing cache file reads as an empty cache."""
        with cache_dir():
            assert black.read_cache(DEFAULT_MODE) == {}

    def test_write_cache_read_cache(self) -> None:
        """A written entry round-trips with the file's current cache info."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.touch()
            black.write_cache({}, [src], mode)
            cache = black.read_cache(mode)
            key = str(src)
            assert key in cache
            assert cache[key] == black.get_cache_info(src)

    def test_filter_cached(self) -> None:
        """``filter_cached`` splits sources into (todo, done) by cache state."""
        with TemporaryDirectory() as workspace:
            base = Path(workspace)
            uncached, cached, changed = (
                (base / name).resolve() for name in ("uncached", "cached", "changed")
            )
            for path in (uncached, cached, changed):
                path.touch()
            cache = {
                str(cached): black.get_cache_info(cached),
                # An entry that doesn't match the file's current info.
                str(changed): (0.0, 0),
            }
            todo, done = black.filter_cached(cache, {uncached, cached, changed})
            assert todo == {uncached, changed}
            assert done == {cached}

    def test_write_cache_creates_directory_if_needed(self) -> None:
        """Writing the cache creates a missing cache directory."""
        with cache_dir(exists=False) as workspace:
            assert not workspace.exists()
            black.write_cache({}, [], DEFAULT_MODE)
            assert workspace.exists()

    @event_loop()
    def test_failed_formatting_does_not_get_cached(self) -> None:
        """Files that fail to format are kept out of the cache."""
        mode = DEFAULT_MODE
        # Use threads instead of processes so formatting runs in this process.
        with cache_dir() as workspace, patch(
            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            failing = (workspace / "failing.py").resolve()
            failing.write_text("not actually python")
            clean = (workspace / "clean.py").resolve()
            clean.write_text('print("hello")\n')

            invokeBlack([str(workspace)], exit_code=123)

            cache = black.read_cache(mode)
            assert str(failing) not in cache
            assert str(clean) in cache

    def test_write_cache_write_fail(self) -> None:
        """``write_cache`` does not raise when ``Path.open`` raises OSError."""
        with cache_dir(), patch.object(Path, "open", side_effect=OSError):
            black.write_cache({}, [], DEFAULT_MODE)

    def test_read_cache_line_lengths(self) -> None:
        """Caches are per mode: another line length sees a separate cache."""
        mode = DEFAULT_MODE
        short_mode = replace(DEFAULT_MODE, line_length=1)
        with cache_dir() as workspace:
            path = (workspace / "file.py").resolve()
            path.touch()
            black.write_cache({}, [path], mode)
            assert str(path) in black.read_cache(mode)
            assert str(path) not in black.read_cache(short_mode)
1798
1799
def assert_collected_sources(
    src: Sequence[Union[str, Path]],
    expected: Sequence[Union[str, Path]],
    *,
    ctx: Optional[FakeContext] = None,
    exclude: Optional[str] = None,
    include: Optional[str] = None,
    extend_exclude: Optional[str] = None,
    force_exclude: Optional[str] = None,
    stdin_filename: Optional[str] = None,
) -> None:
    """Assert that ``black.get_sources`` collects exactly *expected* from *src*.

    Regex arguments are strings compiled with ``compile_pattern``; ``None``
    leaves the corresponding filter unset, except *include*, which falls back
    to ``DEFAULT_INCLUDE``. The comparison ignores ordering.
    """

    def _compile(pattern: Optional[str]) -> Any:
        """Compile *pattern*, passing ``None`` through unchanged."""
        return None if pattern is None else compile_pattern(pattern)

    src_strings = tuple(str(Path(item)) for item in src)
    expected_paths = [Path(item) for item in expected]
    exclude_re = _compile(exclude)
    include_re = DEFAULT_INCLUDE if include is None else compile_pattern(include)
    extend_exclude_re = _compile(extend_exclude)
    force_exclude_re = _compile(force_exclude)

    collected = black.get_sources(
        ctx=ctx or FakeContext(),
        src=src_strings,
        quiet=False,
        verbose=False,
        include=include_re,
        exclude=exclude_re,
        extend_exclude=extend_exclude_re,
        force_exclude=force_exclude_re,
        report=black.Report(),
        stdin_filename=stdin_filename,
    )
    assert sorted(collected) == sorted(expected_paths)
1832
1833
class TestFileCollection:
    """Tests for source discovery: include/exclude regexes, gitignore, stdin."""

    def test_include_exclude(self) -> None:
        """Only files matching include and not matching exclude are collected."""
        path = THIS_DIR / "data" / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        assert_collected_sources(
            src,
            expected,
            include=r"\.pyi?$",
            exclude=r"/exclude/|/\.definitely_exclude/",
        )

    def test_gitignore_used_as_default(self) -> None:
        """With no explicit exclude, the project's .gitignore is applied."""
        base = DATA_DIR / "include_exclude_tests"
        expected = [
            base / "b/.definitely_exclude/a.py",
            base / "b/.definitely_exclude/a.pyi",
        ]
        src = [base / "b/"]
        ctx = FakeContext()
        ctx.obj["root"] = base
        assert_collected_sources(src, expected, ctx=ctx, extend_exclude=r"/exclude/")

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_exclude_for_issue_1572(self) -> None:
        # Exclude shouldn't touch files that were explicitly given to Black through the
        # CLI. Exclude is supposed to only apply to the recursive discovery of files.
        # https://github.com/psf/black/issues/1572
        path = DATA_DIR / "include_exclude_tests"
        src = [path / "b/exclude/a.py"]
        expected = [path / "b/exclude/a.py"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    def test_gitignore_exclude(self) -> None:
        """``gen_python_files`` honours an explicitly supplied gitignore spec."""
        path = THIS_DIR / "data" / "include_exclude_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        report = black.Report()
        gitignore = PathSpec.from_lines(
            "gitwildmatch", ["exclude/", ".definitely_exclude"]
        )
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        this_abs = THIS_DIR.resolve()
        sources = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                gitignore,
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_nested_gitignore(self) -> None:
        """``.gitignore`` files in subdirectories are honoured as well."""
        path = THIS_DIR / "data" / "nested_gitignore_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        root_gitignore = black.files.get_gitignore(path)
        report = black.Report()
        expected: List[Path] = [
            path / "x.py",
            path / "root/b.py",
            path / "root/c.py",
            path / "root/child/c.py",
        ]
        this_abs = THIS_DIR.resolve()
        sources = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                root_gitignore,
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_invalid_gitignore(self) -> None:
        """An unparsable root .gitignore is reported and exits with 1."""
        path = THIS_DIR / "data" / "invalid_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_invalid_nested_gitignore(self) -> None:
        """An unparsable nested .gitignore is reported and exits with 1."""
        path = THIS_DIR / "data" / "invalid_nested_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / "a" / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_empty_include(self) -> None:
        """An empty include regex collects every file."""
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.pie",
            path / "b/exclude/a.py",
            path / "b/exclude/a.pyi",
            path / "b/dont_exclude/a.pie",
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
            path / "b/.definitely_exclude/a.pie",
            path / "b/.definitely_exclude/a.py",
            path / "b/.definitely_exclude/a.pyi",
            path / ".gitignore",
            path / "pyproject.toml",
        ]
        # Setting exclude explicitly to an empty string to block .gitignore usage.
        assert_collected_sources(src, expected, include="", exclude="")

    def test_extend_exclude(self) -> None:
        """extend_exclude removes files in addition to exclude."""
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.py",
            path / "b/dont_exclude/a.py",
        ]
        assert_collected_sources(
            src, expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude"
        )

    @pytest.mark.incompatible_with_mypyc
    def test_symlink_out_of_root_directory(self) -> None:
        """Out-of-root symlinks are tolerated; other out-of-root paths raise."""
        path = MagicMock()
        root = THIS_DIR.resolve()
        child = MagicMock()
        include = re.compile(black.DEFAULT_INCLUDES)
        exclude = re.compile(black.DEFAULT_EXCLUDES)
        report = black.Report()
        gitignore = PathSpec.from_lines("gitwildmatch", [])

        def collect() -> List[Path]:
            # Exhaust the generator so every child path is actually examined.
            return list(
                black.gen_python_files(
                    path.iterdir(),
                    root,
                    include,
                    exclude,
                    None,
                    None,
                    report,
                    gitignore,
                    verbose=False,
                    quiet=False,
                )
            )

        # `child` should behave like a symlink whose resolved path is clearly
        # outside of the `root` directory; this must not raise.
        path.iterdir.return_value = [child]
        child.resolve.return_value = Path("/a/b/c")
        child.as_posix.return_value = "/a/b/c"
        child.is_symlink.return_value = True
        try:
            collect()
        except ValueError as ve:
            pytest.fail(f"`gen_python_files()` failed: {ve}")
        path.iterdir.assert_called_once()
        child.resolve.assert_called_once()
        child.is_symlink.assert_called_once()

        # `child` should behave like a strange file whose resolved path is
        # clearly outside of the `root` directory; this must raise.
        child.is_symlink.return_value = False
        with pytest.raises(ValueError):
            collect()
        path.iterdir.assert_called()
        assert path.iterdir.call_count == 2
        child.resolve.assert_called()
        assert child.resolve.call_count == 2
        child.is_symlink.assert_called()
        assert child.is_symlink.call_count == 2

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin(self) -> None:
        """``-`` is collected as-is regardless of include/exclude."""
        src = ["-"]
        expected = ["-"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename(self) -> None:
        """``--stdin-filename`` is encoded into the collected stdin source."""
        src = ["-"]
        stdin_filename = str(THIS_DIR / "data/collections.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
        # Exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        path = DATA_DIR / "include_exclude_tests"
        src = ["-"]
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
        # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        src = ["-"]
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            extend_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
        # Force exclude should exclude the file when passing it through
        # stdin_filename
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        assert_collected_sources(
            src=["-"],
            expected=[],
            force_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )
2105
2106
# Source lines of black/__init__.py, used by `tracefunc` below to print the
# signature line of each traced function. When black is compiled
# (black.COMPILED), black.__file__ may not be decodable as UTF-8 text; the
# name then stays bound to an empty list instead of being left undefined,
# which would otherwise surface later as a NameError.
black_source_lines: List[str] = []
try:
    with open(black.__file__, "r", encoding="utf-8") as _bf:
        black_source_lines = _bf.readlines()
except UnicodeDecodeError:
    if not black.COMPILED:
        raise
2113
2114
def tracefunc(
    frame: types.FrameType, event: str, arg: Any
) -> Callable[[types.FrameType, str, Any], Any]:
    """Show function calls `from black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    Only "call" events whose code lives in ``black/__init__.py`` are printed,
    as ``<indent><lineno>:<signature line>``. Decorator lines above the
    ``def`` are skipped using the module-level ``black_source_lines``.

    Args:
        frame: The frame being entered.
        event: The trace event name (only "call" is acted on).
        arg: Event-specific argument; unused.

    Returns:
        ``tracefunc`` itself, so it keeps receiving events.
    """
    if event != "call":
        return tracefunc

    filename = frame.f_code.co_filename
    # Check the file before touching black_source_lines: looking up another
    # module's line number in black's source can raise IndexError, and an
    # exception inside a trace function disables tracing. Separators are
    # normalized so the check also matches Windows paths.
    if "black/__init__.py" not in filename.replace("\\", "/"):
        return tracefunc

    # Indent by call depth; 19 is a fixed offset, presumably the frames
    # contributed by the test harness itself — TODO confirm.
    stack = (len(inspect.stack()) - 19) * 2
    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1  # 0-based index into black_source_lines
    funcname = ""
    # Walk past decorator lines to the signature, staying within bounds.
    while 0 <= func_sig_lineno < len(black_source_lines):
        funcname = black_source_lines[func_sig_lineno].strip()
        if not funcname.startswith("@"):
            break
        func_sig_lineno += 1
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc