git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Add support for named exprs inside function calls as gen-exps (#3327)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import inspect
5 import io
6 import logging
7 import multiprocessing
8 import os
9 import re
10 import sys
11 import types
12 import unittest
13 from concurrent.futures import ThreadPoolExecutor
14 from contextlib import contextmanager, redirect_stderr
15 from dataclasses import replace
16 from io import BytesIO
17 from pathlib import Path
18 from platform import system
19 from tempfile import TemporaryDirectory
20 from typing import (
21     Any,
22     Callable,
23     Dict,
24     Iterator,
25     List,
26     Optional,
27     Sequence,
28     TypeVar,
29     Union,
30 )
31 from unittest.mock import MagicMock, patch
32
33 import click
34 import pytest
35 from click import unstyle
36 from click.testing import CliRunner
37 from pathspec import PathSpec
38
39 import black
40 import black.files
41 from black import Feature, TargetVersion
42 from black import re_compile_maybe_verbose as compile_pattern
43 from black.cache import get_cache_dir, get_cache_file
44 from black.debug import DebugVisitor
45 from black.output import color_diff, diff
46 from black.report import Report
47
48 # Import other test classes
49 from tests.util import (
50     DATA_DIR,
51     DEFAULT_MODE,
52     DETERMINISTIC_HEADER,
53     PROJECT_ROOT,
54     PY36_VERSIONS,
55     THIS_DIR,
56     BlackBaseTestCase,
57     assert_format,
58     change_directory,
59     dump_to_stderr,
60     ff,
61     fs,
62     get_case_path,
63     read_data,
64     read_data_from_file,
65 )
66
THIS_FILE = Path(__file__)
# Configuration file handed to the CLI via --config in the tests below.
EMPTY_CONFIG = THIS_DIR / "data" / "empty_pyproject.toml"
# One "--target-version=<name>" flag per version in PY36_VERSIONS.
PY36_ARGS = ["--target-version=" + ver.name.lower() for ver in PY36_VERSIONS]
# compile_pattern is black.re_compile_maybe_verbose (aliased in the imports).
DEFAULT_EXCLUDE = compile_pattern(black.const.DEFAULT_EXCLUDES)
DEFAULT_INCLUDE = compile_pattern(black.const.DEFAULT_INCLUDES)
T = TypeVar("T")
R = TypeVar("R")

# Match the time output in a diff, but nothing else
DIFF_TIME = re.compile(r"\t[\d\-:+\. ]+")
77
78
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch ``black.cache.CACHE_DIR`` to a path inside a temporary directory.

    Yields the patched path. With ``exists=False`` the yielded path is a
    ``new`` subdirectory that has not been created on disk.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
87
88
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a brand-new event loop as the current loop for the ``with`` body.

    On exit the loop is closed and the current-loop slot is cleared, so no
    later code is handed the closed loop by ``get_event_loop()``.
    """
    policy = asyncio.get_event_loop_policy()
    loop = policy.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        yield
    finally:
        loop.close()
        # Don't leave a closed loop registered as the thread's current loop.
        asyncio.set_event_loop(None)
99
100
class FakeContext(click.Context):
    """Stand-in ``click.Context`` for calling functions that require one.

    ``click.Context.__init__`` is not called; only ``default_map`` and ``obj``
    are populated.
    """

    def __init__(self) -> None:
        # No option defaults are supplied.
        self.default_map: Dict[str, Any] = {}
        # Dummy root, since most of the tests don't care about it
        self.obj: Dict[str, Any] = dict(root=PROJECT_ROOT)
108
109
class FakeParameter(click.Parameter):
    """Stand-in ``click.Parameter`` for calling functions that require one."""

    def __init__(self) -> None:
        """Deliberately skip ``click.Parameter.__init__``; no state is set up."""
115
116
class BlackRunner(CliRunner):
    """``CliRunner`` that captures Black's STDOUT and STDERR as separate streams."""

    def __init__(self) -> None:
        # NOTE(review): newer Click releases dropped the `mix_stderr` argument;
        # confirm the Click version used by the test suite still accepts it.
        super().__init__(mix_stderr=False)
122
123
def invokeBlack(
    args: List[str], exit_code: int = 0, ignore_config: bool = True
) -> None:
    """Run Black's CLI in-process and assert it exits with ``exit_code``.

    When ``ignore_config`` is true, ``--verbose`` and ``--config`` pointing at
    ``tests/empty.toml`` are prepended to ``args``. Exceptions raised by the
    CLI propagate (``catch_exceptions=False``).
    """
    cli_args = (
        ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
        if ignore_config
        else args
    )
    result = BlackRunner().invoke(black.main, cli_args, catch_exceptions=False)
    out_bytes = result.stdout_bytes
    err_bytes = result.stderr_bytes
    assert out_bytes is not None
    assert err_bytes is not None
    failure_msg = (
        f"Failed with args: {cli_args}\n"
        f"stdout: {out_bytes.decode()!r}\n"
        f"stderr: {err_bytes.decode()!r}\n"
        f"exception: {result.exception}"
    )
    assert result.exit_code == exit_code, failure_msg
140
141
142 class BlackTestCase(BlackBaseTestCase):
143     invokeBlack = staticmethod(invokeBlack)
144
145     def test_empty_ff(self) -> None:
146         expected = ""
147         tmp_file = Path(black.dump_to_file())
148         try:
149             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
150             with open(tmp_file, encoding="utf8") as f:
151                 actual = f.read()
152         finally:
153             os.unlink(tmp_file)
154         self.assertFormatEqual(expected, actual)
155
156     def test_experimental_string_processing_warns(self) -> None:
157         self.assertWarns(
158             black.mode.Deprecated, black.Mode, experimental_string_processing=True
159         )
160
161     def test_piping(self) -> None:
162         source, expected = read_data_from_file(PROJECT_ROOT / "src/black/__init__.py")
163         result = BlackRunner().invoke(
164             black.main,
165             [
166                 "-",
167                 "--fast",
168                 f"--line-length={black.DEFAULT_LINE_LENGTH}",
169                 f"--config={EMPTY_CONFIG}",
170             ],
171             input=BytesIO(source.encode("utf8")),
172         )
173         self.assertEqual(result.exit_code, 0)
174         self.assertFormatEqual(expected, result.output)
175         if source != result.output:
176             black.assert_equivalent(source, result.output)
177             black.assert_stable(source, result.output, DEFAULT_MODE)
178
179     def test_piping_diff(self) -> None:
180         diff_header = re.compile(
181             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
182             r"\+\d\d\d\d"
183         )
184         source, _ = read_data("simple_cases", "expression.py")
185         expected, _ = read_data("simple_cases", "expression.diff")
186         args = [
187             "-",
188             "--fast",
189             f"--line-length={black.DEFAULT_LINE_LENGTH}",
190             "--diff",
191             f"--config={EMPTY_CONFIG}",
192         ]
193         result = BlackRunner().invoke(
194             black.main, args, input=BytesIO(source.encode("utf8"))
195         )
196         self.assertEqual(result.exit_code, 0)
197         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
198         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
199         self.assertEqual(expected, actual)
200
201     def test_piping_diff_with_color(self) -> None:
202         source, _ = read_data("simple_cases", "expression.py")
203         args = [
204             "-",
205             "--fast",
206             f"--line-length={black.DEFAULT_LINE_LENGTH}",
207             "--diff",
208             "--color",
209             f"--config={EMPTY_CONFIG}",
210         ]
211         result = BlackRunner().invoke(
212             black.main, args, input=BytesIO(source.encode("utf8"))
213         )
214         actual = result.output
215         # Again, the contents are checked in a different test, so only look for colors.
216         self.assertIn("\033[1m", actual)
217         self.assertIn("\033[36m", actual)
218         self.assertIn("\033[32m", actual)
219         self.assertIn("\033[31m", actual)
220         self.assertIn("\033[0m", actual)
221
222     @patch("black.dump_to_file", dump_to_stderr)
223     def _test_wip(self) -> None:
224         source, expected = read_data("miscellaneous", "wip")
225         sys.settrace(tracefunc)
226         mode = replace(
227             DEFAULT_MODE,
228             experimental_string_processing=False,
229             target_versions={black.TargetVersion.PY38},
230         )
231         actual = fs(source, mode=mode)
232         sys.settrace(None)
233         self.assertFormatEqual(expected, actual)
234         black.assert_equivalent(source, actual)
235         black.assert_stable(source, actual, black.FileMode())
236
237     def test_pep_572_version_detection(self) -> None:
238         source, _ = read_data("py_38", "pep_572")
239         root = black.lib2to3_parse(source)
240         features = black.get_features_used(root)
241         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
242         versions = black.detect_target_versions(root)
243         self.assertIn(black.TargetVersion.PY38, versions)
244
245     def test_expression_ff(self) -> None:
246         source, expected = read_data("simple_cases", "expression.py")
247         tmp_file = Path(black.dump_to_file(source))
248         try:
249             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
250             with open(tmp_file, encoding="utf8") as f:
251                 actual = f.read()
252         finally:
253             os.unlink(tmp_file)
254         self.assertFormatEqual(expected, actual)
255         with patch("black.dump_to_file", dump_to_stderr):
256             black.assert_equivalent(source, actual)
257             black.assert_stable(source, actual, DEFAULT_MODE)
258
259     def test_expression_diff(self) -> None:
260         source, _ = read_data("simple_cases", "expression.py")
261         expected, _ = read_data("simple_cases", "expression.diff")
262         tmp_file = Path(black.dump_to_file(source))
263         diff_header = re.compile(
264             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
265             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
266         )
267         try:
268             result = BlackRunner().invoke(
269                 black.main, ["--diff", str(tmp_file), f"--config={EMPTY_CONFIG}"]
270             )
271             self.assertEqual(result.exit_code, 0)
272         finally:
273             os.unlink(tmp_file)
274         actual = result.output
275         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
276         if expected != actual:
277             dump = black.dump_to_file(actual)
278             msg = (
279                 "Expected diff isn't equal to the actual. If you made changes to"
280                 " expression.py and this is an anticipated difference, overwrite"
281                 f" tests/data/expression.diff with {dump}"
282             )
283             self.assertEqual(expected, actual, msg)
284
285     def test_expression_diff_with_color(self) -> None:
286         source, _ = read_data("simple_cases", "expression.py")
287         expected, _ = read_data("simple_cases", "expression.diff")
288         tmp_file = Path(black.dump_to_file(source))
289         try:
290             result = BlackRunner().invoke(
291                 black.main,
292                 ["--diff", "--color", str(tmp_file), f"--config={EMPTY_CONFIG}"],
293             )
294         finally:
295             os.unlink(tmp_file)
296         actual = result.output
297         # We check the contents of the diff in `test_expression_diff`. All
298         # we need to check here is that color codes exist in the result.
299         self.assertIn("\033[1m", actual)
300         self.assertIn("\033[36m", actual)
301         self.assertIn("\033[32m", actual)
302         self.assertIn("\033[31m", actual)
303         self.assertIn("\033[0m", actual)
304
305     def test_detect_pos_only_arguments(self) -> None:
306         source, _ = read_data("py_38", "pep_570")
307         root = black.lib2to3_parse(source)
308         features = black.get_features_used(root)
309         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
310         versions = black.detect_target_versions(root)
311         self.assertIn(black.TargetVersion.PY38, versions)
312
313     def test_detect_debug_f_strings(self) -> None:
314         root = black.lib2to3_parse("""f"{x=}" """)
315         features = black.get_features_used(root)
316         self.assertIn(black.Feature.DEBUG_F_STRINGS, features)
317         versions = black.detect_target_versions(root)
318         self.assertIn(black.TargetVersion.PY38, versions)
319
320         root = black.lib2to3_parse(
321             """f"{x}"\nf'{"="}'\nf'{(x:=5)}'\nf'{f(a="3=")}'\nf'{x:=10}'\n"""
322         )
323         features = black.get_features_used(root)
324         self.assertNotIn(black.Feature.DEBUG_F_STRINGS, features)
325
326         # We don't yet support feature version detection in nested f-strings
327         root = black.lib2to3_parse(
328             """f"heard a rumour that { f'{1+1=}' } ... seems like it could be true" """
329         )
330         features = black.get_features_used(root)
331         self.assertNotIn(black.Feature.DEBUG_F_STRINGS, features)
332
333     @patch("black.dump_to_file", dump_to_stderr)
334     def test_string_quotes(self) -> None:
335         source, expected = read_data("miscellaneous", "string_quotes")
336         mode = black.Mode(preview=True)
337         assert_format(source, expected, mode)
338         mode = replace(mode, string_normalization=False)
339         not_normalized = fs(source, mode=mode)
340         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
341         black.assert_equivalent(source, not_normalized)
342         black.assert_stable(source, not_normalized, mode=mode)
343
344     def test_skip_source_first_line(self) -> None:
345         source, _ = read_data("miscellaneous", "invalid_header")
346         tmp_file = Path(black.dump_to_file(source))
347         # Full source should fail (invalid syntax at header)
348         self.invokeBlack([str(tmp_file), "--diff", "--check"], exit_code=123)
349         # So, skipping the first line should work
350         result = BlackRunner().invoke(
351             black.main, [str(tmp_file), "-x", f"--config={EMPTY_CONFIG}"]
352         )
353         self.assertEqual(result.exit_code, 0)
354         with open(tmp_file, encoding="utf8") as f:
355             actual = f.read()
356         self.assertFormatEqual(source, actual)
357
358     def test_skip_source_first_line_when_mixing_newlines(self) -> None:
359         code_mixing_newlines = b"Header will be skipped\r\ni = [1,2,3]\nj = [1,2,3]\n"
360         expected = b"Header will be skipped\r\ni = [1, 2, 3]\nj = [1, 2, 3]\n"
361         with TemporaryDirectory() as workspace:
362             test_file = Path(workspace) / "skip_header.py"
363             test_file.write_bytes(code_mixing_newlines)
364             mode = replace(DEFAULT_MODE, skip_source_first_line=True)
365             ff(test_file, mode=mode, write_back=black.WriteBack.YES)
366             self.assertEqual(test_file.read_bytes(), expected)
367
368     def test_skip_magic_trailing_comma(self) -> None:
369         source, _ = read_data("simple_cases", "expression")
370         expected, _ = read_data(
371             "miscellaneous", "expression_skip_magic_trailing_comma.diff"
372         )
373         tmp_file = Path(black.dump_to_file(source))
374         diff_header = re.compile(
375             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
376             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
377         )
378         try:
379             result = BlackRunner().invoke(
380                 black.main, ["-C", "--diff", str(tmp_file), f"--config={EMPTY_CONFIG}"]
381             )
382             self.assertEqual(result.exit_code, 0)
383         finally:
384             os.unlink(tmp_file)
385         actual = result.output
386         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
387         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
388         if expected != actual:
389             dump = black.dump_to_file(actual)
390             msg = (
391                 "Expected diff isn't equal to the actual. If you made changes to"
392                 " expression.py and this is an anticipated difference, overwrite"
393                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
394             )
395             self.assertEqual(expected, actual, msg)
396
397     @patch("black.dump_to_file", dump_to_stderr)
398     def test_async_as_identifier(self) -> None:
399         source_path = get_case_path("miscellaneous", "async_as_identifier")
400         source, expected = read_data_from_file(source_path)
401         actual = fs(source)
402         self.assertFormatEqual(expected, actual)
403         major, minor = sys.version_info[:2]
404         if major < 3 or (major <= 3 and minor < 7):
405             black.assert_equivalent(source, actual)
406         black.assert_stable(source, actual, DEFAULT_MODE)
407         # ensure black can parse this when the target is 3.6
408         self.invokeBlack([str(source_path), "--target-version", "py36"])
409         # but not on 3.7, because async/await is no longer an identifier
410         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
411
412     @patch("black.dump_to_file", dump_to_stderr)
413     def test_python37(self) -> None:
414         source_path = get_case_path("py_37", "python37")
415         source, expected = read_data_from_file(source_path)
416         actual = fs(source)
417         self.assertFormatEqual(expected, actual)
418         major, minor = sys.version_info[:2]
419         if major > 3 or (major == 3 and minor >= 7):
420             black.assert_equivalent(source, actual)
421         black.assert_stable(source, actual, DEFAULT_MODE)
422         # ensure black can parse this when the target is 3.7
423         self.invokeBlack([str(source_path), "--target-version", "py37"])
424         # but not on 3.6, because we use async as a reserved keyword
425         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
426
427     def test_tab_comment_indentation(self) -> None:
428         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
429         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
430         self.assertFormatEqual(contents_spc, fs(contents_spc))
431         self.assertFormatEqual(contents_spc, fs(contents_tab))
432
433         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
434         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
435         self.assertFormatEqual(contents_spc, fs(contents_spc))
436         self.assertFormatEqual(contents_spc, fs(contents_tab))
437
438         # mixed tabs and spaces (valid Python 2 code)
439         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
440         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
441         self.assertFormatEqual(contents_spc, fs(contents_spc))
442         self.assertFormatEqual(contents_spc, fs(contents_tab))
443
444         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
445         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
446         self.assertFormatEqual(contents_spc, fs(contents_spc))
447         self.assertFormatEqual(contents_spc, fs(contents_tab))
448
449     def test_report_verbose(self) -> None:
450         report = Report(verbose=True)
451         out_lines = []
452         err_lines = []
453
454         def out(msg: str, **kwargs: Any) -> None:
455             out_lines.append(msg)
456
457         def err(msg: str, **kwargs: Any) -> None:
458             err_lines.append(msg)
459
460         with patch("black.output._out", out), patch("black.output._err", err):
461             report.done(Path("f1"), black.Changed.NO)
462             self.assertEqual(len(out_lines), 1)
463             self.assertEqual(len(err_lines), 0)
464             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
465             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
466             self.assertEqual(report.return_code, 0)
467             report.done(Path("f2"), black.Changed.YES)
468             self.assertEqual(len(out_lines), 2)
469             self.assertEqual(len(err_lines), 0)
470             self.assertEqual(out_lines[-1], "reformatted f2")
471             self.assertEqual(
472                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
473             )
474             report.done(Path("f3"), black.Changed.CACHED)
475             self.assertEqual(len(out_lines), 3)
476             self.assertEqual(len(err_lines), 0)
477             self.assertEqual(
478                 out_lines[-1], "f3 wasn't modified on disk since last run."
479             )
480             self.assertEqual(
481                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
482             )
483             self.assertEqual(report.return_code, 0)
484             report.check = True
485             self.assertEqual(report.return_code, 1)
486             report.check = False
487             report.failed(Path("e1"), "boom")
488             self.assertEqual(len(out_lines), 3)
489             self.assertEqual(len(err_lines), 1)
490             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
491             self.assertEqual(
492                 unstyle(str(report)),
493                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
494                 " reformat.",
495             )
496             self.assertEqual(report.return_code, 123)
497             report.done(Path("f3"), black.Changed.YES)
498             self.assertEqual(len(out_lines), 4)
499             self.assertEqual(len(err_lines), 1)
500             self.assertEqual(out_lines[-1], "reformatted f3")
501             self.assertEqual(
502                 unstyle(str(report)),
503                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
504                 " reformat.",
505             )
506             self.assertEqual(report.return_code, 123)
507             report.failed(Path("e2"), "boom")
508             self.assertEqual(len(out_lines), 4)
509             self.assertEqual(len(err_lines), 2)
510             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
511             self.assertEqual(
512                 unstyle(str(report)),
513                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
514                 " reformat.",
515             )
516             self.assertEqual(report.return_code, 123)
517             report.path_ignored(Path("wat"), "no match")
518             self.assertEqual(len(out_lines), 5)
519             self.assertEqual(len(err_lines), 2)
520             self.assertEqual(out_lines[-1], "wat ignored: no match")
521             self.assertEqual(
522                 unstyle(str(report)),
523                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
524                 " reformat.",
525             )
526             self.assertEqual(report.return_code, 123)
527             report.done(Path("f4"), black.Changed.NO)
528             self.assertEqual(len(out_lines), 6)
529             self.assertEqual(len(err_lines), 2)
530             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
531             self.assertEqual(
532                 unstyle(str(report)),
533                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
534                 " reformat.",
535             )
536             self.assertEqual(report.return_code, 123)
537             report.check = True
538             self.assertEqual(
539                 unstyle(str(report)),
540                 "2 files would be reformatted, 3 files would be left unchanged, 2"
541                 " files would fail to reformat.",
542             )
543             report.check = False
544             report.diff = True
545             self.assertEqual(
546                 unstyle(str(report)),
547                 "2 files would be reformatted, 3 files would be left unchanged, 2"
548                 " files would fail to reformat.",
549             )
550
551     def test_report_quiet(self) -> None:
552         report = Report(quiet=True)
553         out_lines = []
554         err_lines = []
555
556         def out(msg: str, **kwargs: Any) -> None:
557             out_lines.append(msg)
558
559         def err(msg: str, **kwargs: Any) -> None:
560             err_lines.append(msg)
561
562         with patch("black.output._out", out), patch("black.output._err", err):
563             report.done(Path("f1"), black.Changed.NO)
564             self.assertEqual(len(out_lines), 0)
565             self.assertEqual(len(err_lines), 0)
566             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
567             self.assertEqual(report.return_code, 0)
568             report.done(Path("f2"), black.Changed.YES)
569             self.assertEqual(len(out_lines), 0)
570             self.assertEqual(len(err_lines), 0)
571             self.assertEqual(
572                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
573             )
574             report.done(Path("f3"), black.Changed.CACHED)
575             self.assertEqual(len(out_lines), 0)
576             self.assertEqual(len(err_lines), 0)
577             self.assertEqual(
578                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
579             )
580             self.assertEqual(report.return_code, 0)
581             report.check = True
582             self.assertEqual(report.return_code, 1)
583             report.check = False
584             report.failed(Path("e1"), "boom")
585             self.assertEqual(len(out_lines), 0)
586             self.assertEqual(len(err_lines), 1)
587             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
588             self.assertEqual(
589                 unstyle(str(report)),
590                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
591                 " reformat.",
592             )
593             self.assertEqual(report.return_code, 123)
594             report.done(Path("f3"), black.Changed.YES)
595             self.assertEqual(len(out_lines), 0)
596             self.assertEqual(len(err_lines), 1)
597             self.assertEqual(
598                 unstyle(str(report)),
599                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
600                 " reformat.",
601             )
602             self.assertEqual(report.return_code, 123)
603             report.failed(Path("e2"), "boom")
604             self.assertEqual(len(out_lines), 0)
605             self.assertEqual(len(err_lines), 2)
606             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
607             self.assertEqual(
608                 unstyle(str(report)),
609                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
610                 " reformat.",
611             )
612             self.assertEqual(report.return_code, 123)
613             report.path_ignored(Path("wat"), "no match")
614             self.assertEqual(len(out_lines), 0)
615             self.assertEqual(len(err_lines), 2)
616             self.assertEqual(
617                 unstyle(str(report)),
618                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
619                 " reformat.",
620             )
621             self.assertEqual(report.return_code, 123)
622             report.done(Path("f4"), black.Changed.NO)
623             self.assertEqual(len(out_lines), 0)
624             self.assertEqual(len(err_lines), 2)
625             self.assertEqual(
626                 unstyle(str(report)),
627                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
628                 " reformat.",
629             )
630             self.assertEqual(report.return_code, 123)
631             report.check = True
632             self.assertEqual(
633                 unstyle(str(report)),
634                 "2 files would be reformatted, 3 files would be left unchanged, 2"
635                 " files would fail to reformat.",
636             )
637             report.check = False
638             report.diff = True
639             self.assertEqual(
640                 unstyle(str(report)),
641                 "2 files would be reformatted, 3 files would be left unchanged, 2"
642                 " files would fail to reformat.",
643             )
644
645     def test_report_normal(self) -> None:
646         report = black.Report()
647         out_lines = []
648         err_lines = []
649
650         def out(msg: str, **kwargs: Any) -> None:
651             out_lines.append(msg)
652
653         def err(msg: str, **kwargs: Any) -> None:
654             err_lines.append(msg)
655
656         with patch("black.output._out", out), patch("black.output._err", err):
657             report.done(Path("f1"), black.Changed.NO)
658             self.assertEqual(len(out_lines), 0)
659             self.assertEqual(len(err_lines), 0)
660             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
661             self.assertEqual(report.return_code, 0)
662             report.done(Path("f2"), black.Changed.YES)
663             self.assertEqual(len(out_lines), 1)
664             self.assertEqual(len(err_lines), 0)
665             self.assertEqual(out_lines[-1], "reformatted f2")
666             self.assertEqual(
667                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
668             )
669             report.done(Path("f3"), black.Changed.CACHED)
670             self.assertEqual(len(out_lines), 1)
671             self.assertEqual(len(err_lines), 0)
672             self.assertEqual(out_lines[-1], "reformatted f2")
673             self.assertEqual(
674                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
675             )
676             self.assertEqual(report.return_code, 0)
677             report.check = True
678             self.assertEqual(report.return_code, 1)
679             report.check = False
680             report.failed(Path("e1"), "boom")
681             self.assertEqual(len(out_lines), 1)
682             self.assertEqual(len(err_lines), 1)
683             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
684             self.assertEqual(
685                 unstyle(str(report)),
686                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
687                 " reformat.",
688             )
689             self.assertEqual(report.return_code, 123)
690             report.done(Path("f3"), black.Changed.YES)
691             self.assertEqual(len(out_lines), 2)
692             self.assertEqual(len(err_lines), 1)
693             self.assertEqual(out_lines[-1], "reformatted f3")
694             self.assertEqual(
695                 unstyle(str(report)),
696                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
697                 " reformat.",
698             )
699             self.assertEqual(report.return_code, 123)
700             report.failed(Path("e2"), "boom")
701             self.assertEqual(len(out_lines), 2)
702             self.assertEqual(len(err_lines), 2)
703             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
704             self.assertEqual(
705                 unstyle(str(report)),
706                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
707                 " reformat.",
708             )
709             self.assertEqual(report.return_code, 123)
710             report.path_ignored(Path("wat"), "no match")
711             self.assertEqual(len(out_lines), 2)
712             self.assertEqual(len(err_lines), 2)
713             self.assertEqual(
714                 unstyle(str(report)),
715                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
716                 " reformat.",
717             )
718             self.assertEqual(report.return_code, 123)
719             report.done(Path("f4"), black.Changed.NO)
720             self.assertEqual(len(out_lines), 2)
721             self.assertEqual(len(err_lines), 2)
722             self.assertEqual(
723                 unstyle(str(report)),
724                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
725                 " reformat.",
726             )
727             self.assertEqual(report.return_code, 123)
728             report.check = True
729             self.assertEqual(
730                 unstyle(str(report)),
731                 "2 files would be reformatted, 3 files would be left unchanged, 2"
732                 " files would fail to reformat.",
733             )
734             report.check = False
735             report.diff = True
736             self.assertEqual(
737                 unstyle(str(report)),
738                 "2 files would be reformatted, 3 files would be left unchanged, 2"
739                 " files would fail to reformat.",
740             )
741
742     def test_lib2to3_parse(self) -> None:
743         with self.assertRaises(black.InvalidInput):
744             black.lib2to3_parse("invalid syntax")
745
746         straddling = "x + y"
747         black.lib2to3_parse(straddling)
748         black.lib2to3_parse(straddling, {TargetVersion.PY36})
749
750         py2_only = "print x"
751         with self.assertRaises(black.InvalidInput):
752             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
753
754         py3_only = "exec(x, end=y)"
755         black.lib2to3_parse(py3_only)
756         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
757
758     def test_get_features_used_decorator(self) -> None:
759         # Test the feature detection of new decorator syntax
760         # since this makes some test cases of test_get_features_used()
761         # fails if it fails, this is tested first so that a useful case
762         # is identified
763         simples, relaxed = read_data("miscellaneous", "decorators")
764         # skip explanation comments at the top of the file
765         for simple_test in simples.split("##")[1:]:
766             node = black.lib2to3_parse(simple_test)
767             decorator = str(node.children[0].children[0]).strip()
768             self.assertNotIn(
769                 Feature.RELAXED_DECORATORS,
770                 black.get_features_used(node),
771                 msg=(
772                     f"decorator '{decorator}' follows python<=3.8 syntax"
773                     "but is detected as 3.9+"
774                     # f"The full node is\n{node!r}"
775                 ),
776             )
777         # skip the '# output' comment at the top of the output part
778         for relaxed_test in relaxed.split("##")[1:]:
779             node = black.lib2to3_parse(relaxed_test)
780             decorator = str(node.children[0].children[0]).strip()
781             self.assertIn(
782                 Feature.RELAXED_DECORATORS,
783                 black.get_features_used(node),
784                 msg=(
785                     f"decorator '{decorator}' uses python3.9+ syntax"
786                     "but is detected as python<=3.8"
787                     # f"The full node is\n{node!r}"
788                 ),
789             )
790
791     def test_get_features_used(self) -> None:
792         node = black.lib2to3_parse("def f(*, arg): ...\n")
793         self.assertEqual(black.get_features_used(node), set())
794         node = black.lib2to3_parse("def f(*, arg,): ...\n")
795         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
796         node = black.lib2to3_parse("f(*arg,)\n")
797         self.assertEqual(
798             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
799         )
800         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
801         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
802         node = black.lib2to3_parse("123_456\n")
803         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
804         node = black.lib2to3_parse("123456\n")
805         self.assertEqual(black.get_features_used(node), set())
806         source, expected = read_data("simple_cases", "function")
807         node = black.lib2to3_parse(source)
808         expected_features = {
809             Feature.TRAILING_COMMA_IN_CALL,
810             Feature.TRAILING_COMMA_IN_DEF,
811             Feature.F_STRINGS,
812         }
813         self.assertEqual(black.get_features_used(node), expected_features)
814         node = black.lib2to3_parse(expected)
815         self.assertEqual(black.get_features_used(node), expected_features)
816         source, expected = read_data("simple_cases", "expression")
817         node = black.lib2to3_parse(source)
818         self.assertEqual(black.get_features_used(node), set())
819         node = black.lib2to3_parse(expected)
820         self.assertEqual(black.get_features_used(node), set())
821         node = black.lib2to3_parse("lambda a, /, b: ...")
822         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
823         node = black.lib2to3_parse("def fn(a, /, b): ...")
824         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
825         node = black.lib2to3_parse("def fn(): yield a, b")
826         self.assertEqual(black.get_features_used(node), set())
827         node = black.lib2to3_parse("def fn(): return a, b")
828         self.assertEqual(black.get_features_used(node), set())
829         node = black.lib2to3_parse("def fn(): yield *b, c")
830         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
831         node = black.lib2to3_parse("def fn(): return a, *b, c")
832         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
833         node = black.lib2to3_parse("x = a, *b, c")
834         self.assertEqual(black.get_features_used(node), set())
835         node = black.lib2to3_parse("x: Any = regular")
836         self.assertEqual(black.get_features_used(node), set())
837         node = black.lib2to3_parse("x: Any = (regular, regular)")
838         self.assertEqual(black.get_features_used(node), set())
839         node = black.lib2to3_parse("x: Any = Complex(Type(1))[something]")
840         self.assertEqual(black.get_features_used(node), set())
841         node = black.lib2to3_parse("x: Tuple[int, ...] = a, b, c")
842         self.assertEqual(
843             black.get_features_used(node), {Feature.ANN_ASSIGN_EXTENDED_RHS}
844         )
845         node = black.lib2to3_parse("try: pass\nexcept Something: pass")
846         self.assertEqual(black.get_features_used(node), set())
847         node = black.lib2to3_parse("try: pass\nexcept (*Something,): pass")
848         self.assertEqual(black.get_features_used(node), set())
849         node = black.lib2to3_parse("try: pass\nexcept *Group: pass")
850         self.assertEqual(black.get_features_used(node), {Feature.EXCEPT_STAR})
851         node = black.lib2to3_parse("a[*b]")
852         self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS})
853         node = black.lib2to3_parse("a[x, *y(), z] = t")
854         self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS})
855         node = black.lib2to3_parse("def fn(*args: *T): pass")
856         self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS})
857
858     def test_get_features_used_for_future_flags(self) -> None:
859         for src, features in [
860             ("from __future__ import annotations", {Feature.FUTURE_ANNOTATIONS}),
861             (
862                 "from __future__ import (other, annotations)",
863                 {Feature.FUTURE_ANNOTATIONS},
864             ),
865             ("a = 1 + 2\nfrom something import annotations", set()),
866             ("from __future__ import x, y", set()),
867         ]:
868             with self.subTest(src=src, features=features):
869                 node = black.lib2to3_parse(src)
870                 future_imports = black.get_future_imports(node)
871                 self.assertEqual(
872                     black.get_features_used(node, future_imports=future_imports),
873                     features,
874                 )
875
876     def test_get_future_imports(self) -> None:
877         node = black.lib2to3_parse("\n")
878         self.assertEqual(set(), black.get_future_imports(node))
879         node = black.lib2to3_parse("from __future__ import black\n")
880         self.assertEqual({"black"}, black.get_future_imports(node))
881         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
882         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
883         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
884         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
885         node = black.lib2to3_parse(
886             "from __future__ import multiple\nfrom __future__ import imports\n"
887         )
888         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
889         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
890         self.assertEqual({"black"}, black.get_future_imports(node))
891         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
892         self.assertEqual({"black"}, black.get_future_imports(node))
893         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
894         self.assertEqual(set(), black.get_future_imports(node))
895         node = black.lib2to3_parse("from some.module import black\n")
896         self.assertEqual(set(), black.get_future_imports(node))
897         node = black.lib2to3_parse(
898             "from __future__ import unicode_literals as _unicode_literals"
899         )
900         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
901         node = black.lib2to3_parse(
902             "from __future__ import unicode_literals as _lol, print"
903         )
904         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
905
906     @pytest.mark.incompatible_with_mypyc
907     def test_debug_visitor(self) -> None:
908         source, _ = read_data("miscellaneous", "debug_visitor")
909         expected, _ = read_data("miscellaneous", "debug_visitor.out")
910         out_lines = []
911         err_lines = []
912
913         def out(msg: str, **kwargs: Any) -> None:
914             out_lines.append(msg)
915
916         def err(msg: str, **kwargs: Any) -> None:
917             err_lines.append(msg)
918
919         with patch("black.debug.out", out):
920             DebugVisitor.show(source)
921         actual = "\n".join(out_lines) + "\n"
922         log_name = ""
923         if expected != actual:
924             log_name = black.dump_to_file(*out_lines)
925         self.assertEqual(
926             expected,
927             actual,
928             f"AST print out is different. Actual version dumped to {log_name}",
929         )
930
931     def test_format_file_contents(self) -> None:
932         empty = ""
933         mode = DEFAULT_MODE
934         with self.assertRaises(black.NothingChanged):
935             black.format_file_contents(empty, mode=mode, fast=False)
936         just_nl = "\n"
937         with self.assertRaises(black.NothingChanged):
938             black.format_file_contents(just_nl, mode=mode, fast=False)
939         same = "j = [1, 2, 3]\n"
940         with self.assertRaises(black.NothingChanged):
941             black.format_file_contents(same, mode=mode, fast=False)
942         different = "j = [1,2,3]"
943         expected = same
944         actual = black.format_file_contents(different, mode=mode, fast=False)
945         self.assertEqual(expected, actual)
946         invalid = "return if you can"
947         with self.assertRaises(black.InvalidInput) as e:
948             black.format_file_contents(invalid, mode=mode, fast=False)
949         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
950
951     def test_endmarker(self) -> None:
952         n = black.lib2to3_parse("\n")
953         self.assertEqual(n.type, black.syms.file_input)
954         self.assertEqual(len(n.children), 1)
955         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
956
957     @pytest.mark.incompatible_with_mypyc
958     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
959     def test_assertFormatEqual(self) -> None:
960         out_lines = []
961         err_lines = []
962
963         def out(msg: str, **kwargs: Any) -> None:
964             out_lines.append(msg)
965
966         def err(msg: str, **kwargs: Any) -> None:
967             err_lines.append(msg)
968
969         with patch("black.output._out", out), patch("black.output._err", err):
970             with self.assertRaises(AssertionError):
971                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
972
973         out_str = "".join(out_lines)
974         self.assertIn("Expected tree:", out_str)
975         self.assertIn("Actual tree:", out_str)
976         self.assertEqual("".join(err_lines), "")
977
978     @event_loop()
979     @patch("concurrent.futures.ProcessPoolExecutor", MagicMock(side_effect=OSError))
980     def test_works_in_mono_process_only_environment(self) -> None:
981         with cache_dir() as workspace:
982             for f in [
983                 (workspace / "one.py").resolve(),
984                 (workspace / "two.py").resolve(),
985             ]:
986                 f.write_text('print("hello")\n')
987             self.invokeBlack([str(workspace)])
988
989     @event_loop()
990     def test_check_diff_use_together(self) -> None:
991         with cache_dir():
992             # Files which will be reformatted.
993             src1 = get_case_path("miscellaneous", "string_quotes")
994             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
995             # Files which will not be reformatted.
996             src2 = get_case_path("simple_cases", "composition")
997             self.invokeBlack([str(src2), "--diff", "--check"])
998             # Multi file command.
999             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1000
1001     def test_no_src_fails(self) -> None:
1002         with cache_dir():
1003             self.invokeBlack([], exit_code=1)
1004
1005     def test_src_and_code_fails(self) -> None:
1006         with cache_dir():
1007             self.invokeBlack([".", "-c", "0"], exit_code=1)
1008
1009     def test_broken_symlink(self) -> None:
1010         with cache_dir() as workspace:
1011             symlink = workspace / "broken_link.py"
1012             try:
1013                 symlink.symlink_to("nonexistent.py")
1014             except (OSError, NotImplementedError) as e:
1015                 self.skipTest(f"Can't create symlinks: {e}")
1016             self.invokeBlack([str(workspace.resolve())])
1017
1018     def test_single_file_force_pyi(self) -> None:
1019         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1020         contents, expected = read_data("miscellaneous", "force_pyi")
1021         with cache_dir() as workspace:
1022             path = (workspace / "file.py").resolve()
1023             with open(path, "w") as fh:
1024                 fh.write(contents)
1025             self.invokeBlack([str(path), "--pyi"])
1026             with open(path, "r") as fh:
1027                 actual = fh.read()
1028             # verify cache with --pyi is separate
1029             pyi_cache = black.read_cache(pyi_mode)
1030             self.assertIn(str(path), pyi_cache)
1031             normal_cache = black.read_cache(DEFAULT_MODE)
1032             self.assertNotIn(str(path), normal_cache)
1033         self.assertFormatEqual(expected, actual)
1034         black.assert_equivalent(contents, actual)
1035         black.assert_stable(contents, actual, pyi_mode)
1036
1037     @event_loop()
1038     def test_multi_file_force_pyi(self) -> None:
1039         reg_mode = DEFAULT_MODE
1040         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1041         contents, expected = read_data("miscellaneous", "force_pyi")
1042         with cache_dir() as workspace:
1043             paths = [
1044                 (workspace / "file1.py").resolve(),
1045                 (workspace / "file2.py").resolve(),
1046             ]
1047             for path in paths:
1048                 with open(path, "w") as fh:
1049                     fh.write(contents)
1050             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1051             for path in paths:
1052                 with open(path, "r") as fh:
1053                     actual = fh.read()
1054                 self.assertEqual(actual, expected)
1055             # verify cache with --pyi is separate
1056             pyi_cache = black.read_cache(pyi_mode)
1057             normal_cache = black.read_cache(reg_mode)
1058             for path in paths:
1059                 self.assertIn(str(path), pyi_cache)
1060                 self.assertNotIn(str(path), normal_cache)
1061
1062     def test_pipe_force_pyi(self) -> None:
1063         source, expected = read_data("miscellaneous", "force_pyi")
1064         result = CliRunner().invoke(
1065             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1066         )
1067         self.assertEqual(result.exit_code, 0)
1068         actual = result.output
1069         self.assertFormatEqual(actual, expected)
1070
1071     def test_single_file_force_py36(self) -> None:
1072         reg_mode = DEFAULT_MODE
1073         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1074         source, expected = read_data("miscellaneous", "force_py36")
1075         with cache_dir() as workspace:
1076             path = (workspace / "file.py").resolve()
1077             with open(path, "w") as fh:
1078                 fh.write(source)
1079             self.invokeBlack([str(path), *PY36_ARGS])
1080             with open(path, "r") as fh:
1081                 actual = fh.read()
1082             # verify cache with --target-version is separate
1083             py36_cache = black.read_cache(py36_mode)
1084             self.assertIn(str(path), py36_cache)
1085             normal_cache = black.read_cache(reg_mode)
1086             self.assertNotIn(str(path), normal_cache)
1087         self.assertEqual(actual, expected)
1088
1089     @event_loop()
1090     def test_multi_file_force_py36(self) -> None:
1091         reg_mode = DEFAULT_MODE
1092         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1093         source, expected = read_data("miscellaneous", "force_py36")
1094         with cache_dir() as workspace:
1095             paths = [
1096                 (workspace / "file1.py").resolve(),
1097                 (workspace / "file2.py").resolve(),
1098             ]
1099             for path in paths:
1100                 with open(path, "w") as fh:
1101                     fh.write(source)
1102             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1103             for path in paths:
1104                 with open(path, "r") as fh:
1105                     actual = fh.read()
1106                 self.assertEqual(actual, expected)
1107             # verify cache with --target-version is separate
1108             pyi_cache = black.read_cache(py36_mode)
1109             normal_cache = black.read_cache(reg_mode)
1110             for path in paths:
1111                 self.assertIn(str(path), pyi_cache)
1112                 self.assertNotIn(str(path), normal_cache)
1113
1114     def test_pipe_force_py36(self) -> None:
1115         source, expected = read_data("miscellaneous", "force_py36")
1116         result = CliRunner().invoke(
1117             black.main,
1118             ["-", "-q", "--target-version=py36"],
1119             input=BytesIO(source.encode("utf8")),
1120         )
1121         self.assertEqual(result.exit_code, 0)
1122         actual = result.output
1123         self.assertFormatEqual(actual, expected)
1124
1125     @pytest.mark.incompatible_with_mypyc
1126     def test_reformat_one_with_stdin(self) -> None:
1127         with patch(
1128             "black.format_stdin_to_stdout",
1129             return_value=lambda *args, **kwargs: black.Changed.YES,
1130         ) as fsts:
1131             report = MagicMock()
1132             path = Path("-")
1133             black.reformat_one(
1134                 path,
1135                 fast=True,
1136                 write_back=black.WriteBack.YES,
1137                 mode=DEFAULT_MODE,
1138                 report=report,
1139             )
1140             fsts.assert_called_once()
1141             report.done.assert_called_with(path, black.Changed.YES)
1142
1143     @pytest.mark.incompatible_with_mypyc
1144     def test_reformat_one_with_stdin_filename(self) -> None:
1145         with patch(
1146             "black.format_stdin_to_stdout",
1147             return_value=lambda *args, **kwargs: black.Changed.YES,
1148         ) as fsts:
1149             report = MagicMock()
1150             p = "foo.py"
1151             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1152             expected = Path(p)
1153             black.reformat_one(
1154                 path,
1155                 fast=True,
1156                 write_back=black.WriteBack.YES,
1157                 mode=DEFAULT_MODE,
1158                 report=report,
1159             )
1160             fsts.assert_called_once_with(
1161                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1162             )
1163             # __BLACK_STDIN_FILENAME__ should have been stripped
1164             report.done.assert_called_with(expected, black.Changed.YES)
1165
1166     @pytest.mark.incompatible_with_mypyc
1167     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1168         with patch(
1169             "black.format_stdin_to_stdout",
1170             return_value=lambda *args, **kwargs: black.Changed.YES,
1171         ) as fsts:
1172             report = MagicMock()
1173             p = "foo.pyi"
1174             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1175             expected = Path(p)
1176             black.reformat_one(
1177                 path,
1178                 fast=True,
1179                 write_back=black.WriteBack.YES,
1180                 mode=DEFAULT_MODE,
1181                 report=report,
1182             )
1183             fsts.assert_called_once_with(
1184                 fast=True,
1185                 write_back=black.WriteBack.YES,
1186                 mode=replace(DEFAULT_MODE, is_pyi=True),
1187             )
1188             # __BLACK_STDIN_FILENAME__ should have been stripped
1189             report.done.assert_called_with(expected, black.Changed.YES)
1190
1191     @pytest.mark.incompatible_with_mypyc
1192     def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
1193         with patch(
1194             "black.format_stdin_to_stdout",
1195             return_value=lambda *args, **kwargs: black.Changed.YES,
1196         ) as fsts:
1197             report = MagicMock()
1198             p = "foo.ipynb"
1199             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1200             expected = Path(p)
1201             black.reformat_one(
1202                 path,
1203                 fast=True,
1204                 write_back=black.WriteBack.YES,
1205                 mode=DEFAULT_MODE,
1206                 report=report,
1207             )
1208             fsts.assert_called_once_with(
1209                 fast=True,
1210                 write_back=black.WriteBack.YES,
1211                 mode=replace(DEFAULT_MODE, is_ipynb=True),
1212             )
1213             # __BLACK_STDIN_FILENAME__ should have been stripped
1214             report.done.assert_called_with(expected, black.Changed.YES)
1215
1216     @pytest.mark.incompatible_with_mypyc
1217     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1218         with patch(
1219             "black.format_stdin_to_stdout",
1220             return_value=lambda *args, **kwargs: black.Changed.YES,
1221         ) as fsts:
1222             report = MagicMock()
1223             # Even with an existing file, since we are forcing stdin, black
1224             # should output to stdout and not modify the file inplace
1225             p = THIS_DIR / "data" / "simple_cases" / "collections.py"
1226             # Make sure is_file actually returns True
1227             self.assertTrue(p.is_file())
1228             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1229             expected = Path(p)
1230             black.reformat_one(
1231                 path,
1232                 fast=True,
1233                 write_back=black.WriteBack.YES,
1234                 mode=DEFAULT_MODE,
1235                 report=report,
1236             )
1237             fsts.assert_called_once()
1238             # __BLACK_STDIN_FILENAME__ should have been stripped
1239             report.done.assert_called_with(expected, black.Changed.YES)
1240
1241     def test_reformat_one_with_stdin_empty(self) -> None:
1242         output = io.StringIO()
1243         with patch("io.TextIOWrapper", lambda *args, **kwargs: output):
1244             try:
1245                 black.format_stdin_to_stdout(
1246                     fast=True,
1247                     content="",
1248                     write_back=black.WriteBack.YES,
1249                     mode=DEFAULT_MODE,
1250                 )
1251             except io.UnsupportedOperation:
1252                 pass  # StringIO does not support detach
1253             assert output.getvalue() == ""
1254
1255     def test_invalid_cli_regex(self) -> None:
1256         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1257             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1258
1259     def test_required_version_matches_version(self) -> None:
1260         self.invokeBlack(
1261             ["--required-version", black.__version__, "-c", "0"],
1262             exit_code=0,
1263             ignore_config=True,
1264         )
1265
1266     def test_required_version_matches_partial_version(self) -> None:
1267         self.invokeBlack(
1268             ["--required-version", black.__version__.split(".")[0], "-c", "0"],
1269             exit_code=0,
1270             ignore_config=True,
1271         )
1272
1273     def test_required_version_does_not_match_on_minor_version(self) -> None:
1274         self.invokeBlack(
1275             ["--required-version", black.__version__.split(".")[0] + ".999", "-c", "0"],
1276             exit_code=1,
1277             ignore_config=True,
1278         )
1279
1280     def test_required_version_does_not_match_version(self) -> None:
1281         result = BlackRunner().invoke(
1282             black.main,
1283             ["--required-version", "20.99b", "-c", "0"],
1284         )
1285         self.assertEqual(result.exit_code, 1)
1286         self.assertIn("required version", result.stderr)
1287
1288     def test_preserves_line_endings(self) -> None:
1289         with TemporaryDirectory() as workspace:
1290             test_file = Path(workspace) / "test.py"
1291             for nl in ["\n", "\r\n"]:
1292                 contents = nl.join(["def f(  ):", "    pass"])
1293                 test_file.write_bytes(contents.encode())
1294                 ff(test_file, write_back=black.WriteBack.YES)
1295                 updated_contents: bytes = test_file.read_bytes()
1296                 self.assertIn(nl.encode(), updated_contents)
1297                 if nl == "\n":
1298                     self.assertNotIn(b"\r\n", updated_contents)
1299
1300     def test_preserves_line_endings_via_stdin(self) -> None:
1301         for nl in ["\n", "\r\n"]:
1302             contents = nl.join(["def f(  ):", "    pass"])
1303             runner = BlackRunner()
1304             result = runner.invoke(
1305                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1306             )
1307             self.assertEqual(result.exit_code, 0)
1308             output = result.stdout_bytes
1309             self.assertIn(nl.encode("utf8"), output)
1310             if nl == "\n":
1311                 self.assertNotIn(b"\r\n", output)
1312
1313     def test_normalize_line_endings(self) -> None:
1314         with TemporaryDirectory() as workspace:
1315             test_file = Path(workspace) / "test.py"
1316             for data, expected in (
1317                 (b"c\r\nc\n ", b"c\r\nc\r\n"),
1318                 (b"l\nl\r\n ", b"l\nl\n"),
1319             ):
1320                 test_file.write_bytes(data)
1321                 ff(test_file, write_back=black.WriteBack.YES)
1322                 self.assertEqual(test_file.read_bytes(), expected)
1323
1324     def test_assert_equivalent_different_asts(self) -> None:
1325         with self.assertRaises(AssertionError):
1326             black.assert_equivalent("{}", "None")
1327
1328     def test_shhh_click(self) -> None:
1329         try:
1330             from click import _unicodefun  # type: ignore
1331         except ImportError:
1332             self.skipTest("Incompatible Click version")
1333
1334         if not hasattr(_unicodefun, "_verify_python_env"):
1335             self.skipTest("Incompatible Click version")
1336
1337         # First, let's see if Click is crashing with a preferred ASCII charset.
1338         with patch("locale.getpreferredencoding") as gpe:
1339             gpe.return_value = "ASCII"
1340             with self.assertRaises(RuntimeError):
1341                 _unicodefun._verify_python_env()
1342         # Now, let's silence Click...
1343         black.patch_click()
1344         # ...and confirm it's silent.
1345         with patch("locale.getpreferredencoding") as gpe:
1346             gpe.return_value = "ASCII"
1347             try:
1348                 _unicodefun._verify_python_env()
1349             except RuntimeError as re:
1350                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1351
1352     def test_root_logger_not_used_directly(self) -> None:
1353         def fail(*args: Any, **kwargs: Any) -> None:
1354             self.fail("Record created with root logger")
1355
1356         with patch.multiple(
1357             logging.root,
1358             debug=fail,
1359             info=fail,
1360             warning=fail,
1361             error=fail,
1362             critical=fail,
1363             log=fail,
1364         ):
1365             ff(THIS_DIR / "util.py")
1366
1367     def test_invalid_config_return_code(self) -> None:
1368         tmp_file = Path(black.dump_to_file())
1369         try:
1370             tmp_config = Path(black.dump_to_file())
1371             tmp_config.unlink()
1372             args = ["--config", str(tmp_config), str(tmp_file)]
1373             self.invokeBlack(args, exit_code=2, ignore_config=False)
1374         finally:
1375             tmp_file.unlink()
1376
1377     def test_parse_pyproject_toml(self) -> None:
1378         test_toml_file = THIS_DIR / "test.toml"
1379         config = black.parse_pyproject_toml(str(test_toml_file))
1380         self.assertEqual(config["verbose"], 1)
1381         self.assertEqual(config["check"], "no")
1382         self.assertEqual(config["diff"], "y")
1383         self.assertEqual(config["color"], True)
1384         self.assertEqual(config["line_length"], 79)
1385         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1386         self.assertEqual(config["python_cell_magics"], ["custom1", "custom2"])
1387         self.assertEqual(config["exclude"], r"\.pyi?$")
1388         self.assertEqual(config["include"], r"\.py?$")
1389
1390     def test_read_pyproject_toml(self) -> None:
1391         test_toml_file = THIS_DIR / "test.toml"
1392         fake_ctx = FakeContext()
1393         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1394         config = fake_ctx.default_map
1395         self.assertEqual(config["verbose"], "1")
1396         self.assertEqual(config["check"], "no")
1397         self.assertEqual(config["diff"], "y")
1398         self.assertEqual(config["color"], "True")
1399         self.assertEqual(config["line_length"], "79")
1400         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1401         self.assertEqual(config["exclude"], r"\.pyi?$")
1402         self.assertEqual(config["include"], r"\.py?$")
1403
1404     @pytest.mark.incompatible_with_mypyc
1405     def test_find_project_root(self) -> None:
1406         with TemporaryDirectory() as workspace:
1407             root = Path(workspace)
1408             test_dir = root / "test"
1409             test_dir.mkdir()
1410
1411             src_dir = root / "src"
1412             src_dir.mkdir()
1413
1414             root_pyproject = root / "pyproject.toml"
1415             root_pyproject.touch()
1416             src_pyproject = src_dir / "pyproject.toml"
1417             src_pyproject.touch()
1418             src_python = src_dir / "foo.py"
1419             src_python.touch()
1420
1421             self.assertEqual(
1422                 black.find_project_root((src_dir, test_dir)),
1423                 (root.resolve(), "pyproject.toml"),
1424             )
1425             self.assertEqual(
1426                 black.find_project_root((src_dir,)),
1427                 (src_dir.resolve(), "pyproject.toml"),
1428             )
1429             self.assertEqual(
1430                 black.find_project_root((src_python,)),
1431                 (src_dir.resolve(), "pyproject.toml"),
1432             )
1433
1434             with change_directory(test_dir):
1435                 self.assertEqual(
1436                     black.find_project_root(("-",), stdin_filename="../src/a.py"),
1437                     (src_dir.resolve(), "pyproject.toml"),
1438                 )
1439
1440     @patch(
1441         "black.files.find_user_pyproject_toml",
1442     )
1443     def test_find_pyproject_toml(self, find_user_pyproject_toml: MagicMock) -> None:
1444         find_user_pyproject_toml.side_effect = RuntimeError()
1445
1446         with redirect_stderr(io.StringIO()) as stderr:
1447             result = black.files.find_pyproject_toml(
1448                 path_search_start=(str(Path.cwd().root),)
1449             )
1450
1451         assert result is None
1452         err = stderr.getvalue()
1453         assert "Ignoring user configuration" in err
1454
1455     @patch(
1456         "black.files.find_user_pyproject_toml",
1457         black.files.find_user_pyproject_toml.__wrapped__,
1458     )
1459     def test_find_user_pyproject_toml_linux(self) -> None:
1460         if system() == "Windows":
1461             return
1462
1463         # Test if XDG_CONFIG_HOME is checked
1464         with TemporaryDirectory() as workspace:
1465             tmp_user_config = Path(workspace) / "black"
1466             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1467                 self.assertEqual(
1468                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
1469                 )
1470
1471         # Test fallback for XDG_CONFIG_HOME
1472         with patch.dict("os.environ"):
1473             os.environ.pop("XDG_CONFIG_HOME", None)
1474             fallback_user_config = Path("~/.config").expanduser() / "black"
1475             self.assertEqual(
1476                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
1477             )
1478
1479     def test_find_user_pyproject_toml_windows(self) -> None:
1480         if system() != "Windows":
1481             return
1482
1483         user_config_path = Path.home() / ".black"
1484         self.assertEqual(
1485             black.files.find_user_pyproject_toml(), user_config_path.resolve()
1486         )
1487
1488     def test_bpo_33660_workaround(self) -> None:
1489         if system() == "Windows":
1490             return
1491
1492         # https://bugs.python.org/issue33660
1493         root = Path("/")
1494         with change_directory(root):
1495             path = Path("workspace") / "project"
1496             report = black.Report(verbose=True)
1497             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
1498             self.assertEqual(normalized_path, "workspace/project")
1499
1500     def test_normalize_path_ignore_windows_junctions_outside_of_root(self) -> None:
1501         if system() != "Windows":
1502             return
1503
1504         with TemporaryDirectory() as workspace:
1505             root = Path(workspace)
1506             junction_dir = root / "junction"
1507             junction_target_outside_of_root = root / ".."
1508             os.system(f"mklink /J {junction_dir} {junction_target_outside_of_root}")
1509
1510             report = black.Report(verbose=True)
1511             normalized_path = black.normalize_path_maybe_ignore(
1512                 junction_dir, root, report
1513             )
1514             # Manually delete for Python < 3.8
1515             os.system(f"rmdir {junction_dir}")
1516
1517             self.assertEqual(normalized_path, None)
1518
1519     def test_newline_comment_interaction(self) -> None:
1520         source = "class A:\\\r\n# type: ignore\n pass\n"
1521         output = black.format_str(source, mode=DEFAULT_MODE)
1522         black.assert_stable(source, output, mode=DEFAULT_MODE)
1523
1524     def test_bpo_2142_workaround(self) -> None:
1525         # https://bugs.python.org/issue2142
1526
1527         source, _ = read_data("miscellaneous", "missing_final_newline")
1528         # read_data adds a trailing newline
1529         source = source.rstrip()
1530         expected, _ = read_data("miscellaneous", "missing_final_newline.diff")
1531         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
1532         diff_header = re.compile(
1533             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
1534             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
1535         )
1536         try:
1537             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
1538             self.assertEqual(result.exit_code, 0)
1539         finally:
1540             os.unlink(tmp_file)
1541         actual = result.output
1542         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
1543         self.assertEqual(actual, expected)
1544
1545     @staticmethod
1546     def compare_results(
1547         result: click.testing.Result, expected_value: str, expected_exit_code: int
1548     ) -> None:
1549         """Helper method to test the value and exit code of a click Result."""
1550         assert (
1551             result.output == expected_value
1552         ), "The output did not match the expected value."
1553         assert result.exit_code == expected_exit_code, "The exit code is incorrect."
1554
1555     def test_code_option(self) -> None:
1556         """Test the code option with no changes."""
1557         code = 'print("Hello world")\n'
1558         args = ["--code", code]
1559         result = CliRunner().invoke(black.main, args)
1560
1561         self.compare_results(result, code, 0)
1562
1563     def test_code_option_changed(self) -> None:
1564         """Test the code option when changes are required."""
1565         code = "print('hello world')"
1566         formatted = black.format_str(code, mode=DEFAULT_MODE)
1567
1568         args = ["--code", code]
1569         result = CliRunner().invoke(black.main, args)
1570
1571         self.compare_results(result, formatted, 0)
1572
1573     def test_code_option_check(self) -> None:
1574         """Test the code option when check is passed."""
1575         args = ["--check", "--code", 'print("Hello world")\n']
1576         result = CliRunner().invoke(black.main, args)
1577         self.compare_results(result, "", 0)
1578
1579     def test_code_option_check_changed(self) -> None:
1580         """Test the code option when changes are required, and check is passed."""
1581         args = ["--check", "--code", "print('hello world')"]
1582         result = CliRunner().invoke(black.main, args)
1583         self.compare_results(result, "", 1)
1584
1585     def test_code_option_diff(self) -> None:
1586         """Test the code option when diff is passed."""
1587         code = "print('hello world')"
1588         formatted = black.format_str(code, mode=DEFAULT_MODE)
1589         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1590
1591         args = ["--diff", "--code", code]
1592         result = CliRunner().invoke(black.main, args)
1593
1594         # Remove time from diff
1595         output = DIFF_TIME.sub("", result.output)
1596
1597         assert output == result_diff, "The output did not match the expected value."
1598         assert result.exit_code == 0, "The exit code is incorrect."
1599
1600     def test_code_option_color_diff(self) -> None:
1601         """Test the code option when color and diff are passed."""
1602         code = "print('hello world')"
1603         formatted = black.format_str(code, mode=DEFAULT_MODE)
1604
1605         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1606         result_diff = color_diff(result_diff)
1607
1608         args = ["--diff", "--color", "--code", code]
1609         result = CliRunner().invoke(black.main, args)
1610
1611         # Remove time from diff
1612         output = DIFF_TIME.sub("", result.output)
1613
1614         assert output == result_diff, "The output did not match the expected value."
1615         assert result.exit_code == 0, "The exit code is incorrect."
1616
1617     @pytest.mark.incompatible_with_mypyc
1618     def test_code_option_safe(self) -> None:
1619         """Test that the code option throws an error when the sanity checks fail."""
1620         # Patch black.assert_equivalent to ensure the sanity checks fail
1621         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1622             code = 'print("Hello world")'
1623             error_msg = f"{code}\nerror: cannot format <string>: \n"
1624
1625             args = ["--safe", "--code", code]
1626             result = CliRunner().invoke(black.main, args)
1627
1628             self.compare_results(result, error_msg, 123)
1629
1630     def test_code_option_fast(self) -> None:
1631         """Test that the code option ignores errors when the sanity checks fail."""
1632         # Patch black.assert_equivalent to ensure the sanity checks fail
1633         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1634             code = 'print("Hello world")'
1635             formatted = black.format_str(code, mode=DEFAULT_MODE)
1636
1637             args = ["--fast", "--code", code]
1638             result = CliRunner().invoke(black.main, args)
1639
1640             self.compare_results(result, formatted, 0)
1641
1642     @pytest.mark.incompatible_with_mypyc
1643     def test_code_option_config(self) -> None:
1644         """
1645         Test that the code option finds the pyproject.toml in the current directory.
1646         """
1647         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1648             args = ["--code", "print"]
1649             # This is the only directory known to contain a pyproject.toml
1650             with change_directory(PROJECT_ROOT):
1651                 CliRunner().invoke(black.main, args)
1652                 pyproject_path = Path(Path.cwd(), "pyproject.toml").resolve()
1653
1654             assert (
1655                 len(parse.mock_calls) >= 1
1656             ), "Expected config parse to be called with the current directory."
1657
1658             _, call_args, _ = parse.mock_calls[0]
1659             assert (
1660                 call_args[0].lower() == str(pyproject_path).lower()
1661             ), "Incorrect config loaded."
1662
1663     @pytest.mark.incompatible_with_mypyc
1664     def test_code_option_parent_config(self) -> None:
1665         """
1666         Test that the code option finds the pyproject.toml in the parent directory.
1667         """
1668         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1669             with change_directory(THIS_DIR):
1670                 args = ["--code", "print"]
1671                 CliRunner().invoke(black.main, args)
1672
1673                 pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
1674                 assert (
1675                     len(parse.mock_calls) >= 1
1676                 ), "Expected config parse to be called with the current directory."
1677
1678                 _, call_args, _ = parse.mock_calls[0]
1679                 assert (
1680                     call_args[0].lower() == str(pyproject_path).lower()
1681                 ), "Incorrect config loaded."
1682
1683     def test_for_handled_unexpected_eof_error(self) -> None:
1684         """
1685         Test that an unexpected EOF SyntaxError is nicely presented.
1686         """
1687         with pytest.raises(black.parsing.InvalidInput) as exc_info:
1688             black.lib2to3_parse("print(", {})
1689
1690         exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")
1691
1692     def test_equivalency_ast_parse_failure_includes_error(self) -> None:
1693         with pytest.raises(AssertionError) as err:
1694             black.assert_equivalent("a«»a  = 1", "a«»a  = 1")
1695
1696         err.match("--safe")
1697         # Unfortunately the SyntaxError message has changed in newer versions so we
1698         # can't match it directly.
1699         err.match("invalid character")
1700         err.match(r"\(<unknown>, line 1\)")
1701
1702
class TestCaching:
    """Tests for Black's on-disk formatting cache."""

    def test_get_cache_dir(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """``BLACK_CACHE_DIR`` overrides the user cache directory."""
        default_dir = tmp_path / "ws1"
        override_dir = tmp_path / "ws2"
        for directory in (default_dir, override_dir):
            directory.mkdir()

        # Pin user_cache_dir to a known temporary directory for exact assertions.
        fake_user_cache_dir = patch(
            target="black.cache.user_cache_dir",
            autospec=True,
            return_value=str(default_dir),
        )

        # Without BLACK_CACHE_DIR, user_cache_dir decides the location.
        monkeypatch.delenv("BLACK_CACHE_DIR", raising=False)
        with fake_user_cache_dir:
            assert get_cache_dir() == default_dir

        # With BLACK_CACHE_DIR set, the path from the env var is used.
        monkeypatch.setenv("BLACK_CACHE_DIR", str(override_dir))
        assert get_cache_dir() == override_dir

    def test_cache_broken_file(self) -> None:
        """A corrupt cache file reads as empty and is repopulated by a run."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            get_cache_file(mode).write_text("this is not a pickle")
            assert black.read_cache(mode) == {}
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            invokeBlack([str(src)])
            assert str(src) in black.read_cache(mode)

    def test_cache_single_file_already_cached(self) -> None:
        """A file already recorded in the cache is left untouched."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            unformatted = "print('hello')"
            src.write_text(unformatted)
            black.write_cache({}, [src], mode)
            invokeBlack([str(src)])
            assert src.read_text() == unformatted

    @event_loop()
    def test_cache_multiple_files(self) -> None:
        """Only uncached files in a directory are reformatted; all end up cached."""
        mode = DEFAULT_MODE
        # Workers run as threads instead of processes (presumably so the
        # patched cache location is shared with them).
        with cache_dir() as workspace, patch(
            "concurrent.futures.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            one = (workspace / "one.py").resolve()
            two = (workspace / "two.py").resolve()
            for src in (one, two):
                src.write_text("print('hello')")
            black.write_cache({}, [one], mode)
            invokeBlack([str(workspace)])
            assert one.read_text() == "print('hello')"
            assert two.read_text() == 'print("hello")\n'
            cache = black.read_cache(mode)
            assert str(one) in cache
            assert str(two) in cache

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    def test_no_cache_when_writeback_diff(self, color: bool) -> None:
        """``--diff`` neither reads nor writes the cache."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            with patch("black.read_cache") as read_cache, patch(
                "black.write_cache"
            ) as write_cache:
                cmd = [str(src), "--diff"] + (["--color"] if color else [])
                invokeBlack(cmd)
                assert get_cache_file(mode).exists() is False
                write_cache.assert_not_called()
                read_cache.assert_not_called()

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    @event_loop()
    def test_output_locking_when_writeback_diff(self, color: bool) -> None:
        """Parallel ``--diff`` runs create a multiprocessing Manager for locking."""
        with cache_dir() as workspace:
            for tag in range(4):
                (workspace / f"test{tag}.py").resolve().write_text("print('hello')")
            with patch(
                "black.concurrency.Manager", wraps=multiprocessing.Manager
            ) as mgr:
                cmd = ["--diff", str(workspace)]
                if color:
                    cmd.append("--color")
                invokeBlack(cmd, exit_code=0)
                # this isn't quite doing what we want, but if it _isn't_
                # called then we cannot be using the lock it provides
                mgr.assert_called()

    def test_no_cache_when_stdin(self) -> None:
        """Formatting from stdin does not create a cache file."""
        mode = DEFAULT_MODE
        with cache_dir():
            result = CliRunner().invoke(
                black.main, ["-"], input=BytesIO(b"print('hello')")
            )
            assert not result.exit_code
            assert not get_cache_file(mode).exists()

    def test_read_cache_no_cachefile(self) -> None:
        """A missing cache file reads as an empty cache."""
        with cache_dir():
            assert black.read_cache(DEFAULT_MODE) == {}

    def test_write_cache_read_cache(self) -> None:
        """An entry written with ``write_cache`` round-trips via ``read_cache``."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.touch()
            black.write_cache({}, [src], mode)
            cache = black.read_cache(mode)
            key = str(src)
            assert key in cache
            assert cache[key] == black.get_cache_info(src)

    def test_filter_cached(self) -> None:
        """``filter_cached`` splits sources into (todo, done) using cached info."""
        with TemporaryDirectory() as workspace:
            path = Path(workspace)
            uncached, cached, cached_but_changed = (
                (path / name).resolve() for name in ("uncached", "cached", "changed")
            )
            for src in (uncached, cached, cached_but_changed):
                src.touch()
            cache = {
                str(cached): black.get_cache_info(cached),
                # Deliberately stale info, so this file must be redone.
                str(cached_but_changed): (0.0, 0),
            }
            todo, done = black.cache.filter_cached(
                cache, {uncached, cached, cached_but_changed}
            )
            assert todo == {uncached, cached_but_changed}
            assert done == {cached}

    def test_write_cache_creates_directory_if_needed(self) -> None:
        """``write_cache`` creates the cache directory when it is missing."""
        with cache_dir(exists=False) as workspace:
            assert not workspace.exists()
            black.write_cache({}, [], DEFAULT_MODE)
            assert workspace.exists()

    @event_loop()
    def test_failed_formatting_does_not_get_cached(self) -> None:
        """Files that fail to format are not recorded in the cache."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace, patch(
            "concurrent.futures.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            failing = (workspace / "failing.py").resolve()
            failing.write_text("not actually python")
            clean = (workspace / "clean.py").resolve()
            clean.write_text('print("hello")\n')
            invokeBlack([str(workspace)], exit_code=123)
            cache = black.read_cache(mode)
            assert str(failing) not in cache
            assert str(clean) in cache

    def test_write_cache_write_fail(self) -> None:
        """An OSError while writing the cache does not propagate."""
        with cache_dir(), patch.object(Path, "open") as mock:
            mock.side_effect = OSError
            # The test passes as long as this call does not raise.
            black.write_cache({}, [], DEFAULT_MODE)

    def test_read_cache_line_lengths(self) -> None:
        """Entries cached for one mode are not visible under another line length."""
        mode = DEFAULT_MODE
        short_mode = replace(DEFAULT_MODE, line_length=1)
        with cache_dir() as workspace:
            path = (workspace / "file.py").resolve()
            path.touch()
            black.write_cache({}, [path], mode)
            assert str(path) in black.read_cache(mode)
            assert str(path) not in black.read_cache(short_mode)
1897
1898
def assert_collected_sources(
    src: Sequence[Union[str, Path]],
    expected: Sequence[Union[str, Path]],
    *,
    ctx: Optional[FakeContext] = None,
    exclude: Optional[str] = None,
    include: Optional[str] = None,
    extend_exclude: Optional[str] = None,
    force_exclude: Optional[str] = None,
    stdin_filename: Optional[str] = None,
) -> None:
    """Assert that ``black.get_sources`` collects exactly ``expected`` from ``src``.

    Pattern arguments are regex strings compiled with ``compile_pattern``.
    ``None`` passes ``None`` for that filter, except ``include``, which falls
    back to ``DEFAULT_INCLUDE``. Order is ignored: both sides are sorted.
    """

    def _maybe_compile(pattern: Optional[str]) -> Any:
        return None if pattern is None else compile_pattern(pattern)

    gs_include = DEFAULT_INCLUDE if include is None else compile_pattern(include)
    collected = black.get_sources(
        ctx=ctx or FakeContext(),
        src=tuple(str(Path(s)) for s in src),
        quiet=False,
        verbose=False,
        include=gs_include,
        exclude=_maybe_compile(exclude),
        extend_exclude=_maybe_compile(extend_exclude),
        force_exclude=_maybe_compile(force_exclude),
        report=black.Report(),
        stdin_filename=stdin_filename,
    )
    assert sorted(collected) == sorted(Path(s) for s in expected)
1931
1932
1933 class TestFileCollection:
1934     def test_include_exclude(self) -> None:
1935         path = THIS_DIR / "data" / "include_exclude_tests"
1936         src = [path]
1937         expected = [
1938             Path(path / "b/dont_exclude/a.py"),
1939             Path(path / "b/dont_exclude/a.pyi"),
1940         ]
1941         assert_collected_sources(
1942             src,
1943             expected,
1944             include=r"\.pyi?$",
1945             exclude=r"/exclude/|/\.definitely_exclude/",
1946         )
1947
1948     def test_gitignore_used_as_default(self) -> None:
1949         base = Path(DATA_DIR / "include_exclude_tests")
1950         expected = [
1951             base / "b/.definitely_exclude/a.py",
1952             base / "b/.definitely_exclude/a.pyi",
1953         ]
1954         src = [base / "b/"]
1955         ctx = FakeContext()
1956         ctx.obj["root"] = base
1957         assert_collected_sources(src, expected, ctx=ctx, extend_exclude=r"/exclude/")
1958
1959     @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
1960     def test_exclude_for_issue_1572(self) -> None:
1961         # Exclude shouldn't touch files that were explicitly given to Black through the
1962         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
1963         # https://github.com/psf/black/issues/1572
1964         path = DATA_DIR / "include_exclude_tests"
1965         src = [path / "b/exclude/a.py"]
1966         expected = [path / "b/exclude/a.py"]
1967         assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")
1968
1969     def test_gitignore_exclude(self) -> None:
1970         path = THIS_DIR / "data" / "include_exclude_tests"
1971         include = re.compile(r"\.pyi?$")
1972         exclude = re.compile(r"")
1973         report = black.Report()
1974         gitignore = PathSpec.from_lines(
1975             "gitwildmatch", ["exclude/", ".definitely_exclude"]
1976         )
1977         sources: List[Path] = []
1978         expected = [
1979             Path(path / "b/dont_exclude/a.py"),
1980             Path(path / "b/dont_exclude/a.pyi"),
1981         ]
1982         this_abs = THIS_DIR.resolve()
1983         sources.extend(
1984             black.gen_python_files(
1985                 path.iterdir(),
1986                 this_abs,
1987                 include,
1988                 exclude,
1989                 None,
1990                 None,
1991                 report,
1992                 gitignore,
1993                 verbose=False,
1994                 quiet=False,
1995             )
1996         )
1997         assert sorted(expected) == sorted(sources)
1998
1999     def test_nested_gitignore(self) -> None:
2000         path = Path(THIS_DIR / "data" / "nested_gitignore_tests")
2001         include = re.compile(r"\.pyi?$")
2002         exclude = re.compile(r"")
2003         root_gitignore = black.files.get_gitignore(path)
2004         report = black.Report()
2005         expected: List[Path] = [
2006             Path(path / "x.py"),
2007             Path(path / "root/b.py"),
2008             Path(path / "root/c.py"),
2009             Path(path / "root/child/c.py"),
2010         ]
2011         this_abs = THIS_DIR.resolve()
2012         sources = list(
2013             black.gen_python_files(
2014                 path.iterdir(),
2015                 this_abs,
2016                 include,
2017                 exclude,
2018                 None,
2019                 None,
2020                 report,
2021                 root_gitignore,
2022                 verbose=False,
2023                 quiet=False,
2024             )
2025         )
2026         assert sorted(expected) == sorted(sources)
2027
2028     def test_nested_gitignore_directly_in_source_directory(self) -> None:
2029         # https://github.com/psf/black/issues/2598
2030         path = Path(DATA_DIR / "nested_gitignore_tests")
2031         src = Path(path / "root" / "child")
2032         expected = [src / "a.py", src / "c.py"]
2033         assert_collected_sources([src], expected)
2034
2035     def test_invalid_gitignore(self) -> None:
2036         path = THIS_DIR / "data" / "invalid_gitignore_tests"
2037         empty_config = path / "pyproject.toml"
2038         result = BlackRunner().invoke(
2039             black.main, ["--verbose", "--config", str(empty_config), str(path)]
2040         )
2041         assert result.exit_code == 1
2042         assert result.stderr_bytes is not None
2043
2044         gitignore = path / ".gitignore"
2045         assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()
2046
2047     def test_invalid_nested_gitignore(self) -> None:
2048         path = THIS_DIR / "data" / "invalid_nested_gitignore_tests"
2049         empty_config = path / "pyproject.toml"
2050         result = BlackRunner().invoke(
2051             black.main, ["--verbose", "--config", str(empty_config), str(path)]
2052         )
2053         assert result.exit_code == 1
2054         assert result.stderr_bytes is not None
2055
2056         gitignore = path / "a" / ".gitignore"
2057         assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()
2058
2059     def test_empty_include(self) -> None:
2060         path = DATA_DIR / "include_exclude_tests"
2061         src = [path]
2062         expected = [
2063             Path(path / "b/exclude/a.pie"),
2064             Path(path / "b/exclude/a.py"),
2065             Path(path / "b/exclude/a.pyi"),
2066             Path(path / "b/dont_exclude/a.pie"),
2067             Path(path / "b/dont_exclude/a.py"),
2068             Path(path / "b/dont_exclude/a.pyi"),
2069             Path(path / "b/.definitely_exclude/a.pie"),
2070             Path(path / "b/.definitely_exclude/a.py"),
2071             Path(path / "b/.definitely_exclude/a.pyi"),
2072             Path(path / ".gitignore"),
2073             Path(path / "pyproject.toml"),
2074         ]
2075         # Setting exclude explicitly to an empty string to block .gitignore usage.
2076         assert_collected_sources(src, expected, include="", exclude="")
2077
2078     def test_extend_exclude(self) -> None:
2079         path = DATA_DIR / "include_exclude_tests"
2080         src = [path]
2081         expected = [
2082             Path(path / "b/exclude/a.py"),
2083             Path(path / "b/dont_exclude/a.py"),
2084         ]
2085         assert_collected_sources(
2086             src, expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude"
2087         )
2088
2089     @pytest.mark.incompatible_with_mypyc
2090     def test_symlink_out_of_root_directory(self) -> None:
2091         path = MagicMock()
2092         root = THIS_DIR.resolve()
2093         child = MagicMock()
2094         include = re.compile(black.DEFAULT_INCLUDES)
2095         exclude = re.compile(black.DEFAULT_EXCLUDES)
2096         report = black.Report()
2097         gitignore = PathSpec.from_lines("gitwildmatch", [])
2098         # `child` should behave like a symlink which resolved path is clearly
2099         # outside of the `root` directory.
2100         path.iterdir.return_value = [child]
2101         child.resolve.return_value = Path("/a/b/c")
2102         child.as_posix.return_value = "/a/b/c"
2103         try:
2104             list(
2105                 black.gen_python_files(
2106                     path.iterdir(),
2107                     root,
2108                     include,
2109                     exclude,
2110                     None,
2111                     None,
2112                     report,
2113                     gitignore,
2114                     verbose=False,
2115                     quiet=False,
2116                 )
2117             )
2118         except ValueError as ve:
2119             pytest.fail(f"`get_python_files_in_dir()` failed: {ve}")
2120         path.iterdir.assert_called_once()
2121         child.resolve.assert_called_once()
2122
2123     @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
2124     def test_get_sources_with_stdin(self) -> None:
2125         src = ["-"]
2126         expected = ["-"]
2127         assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")
2128
2129     @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
2130     def test_get_sources_with_stdin_filename(self) -> None:
2131         src = ["-"]
2132         stdin_filename = str(THIS_DIR / "data/collections.py")
2133         expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
2134         assert_collected_sources(
2135             src,
2136             expected,
2137             exclude=r"/exclude/a\.py",
2138             stdin_filename=stdin_filename,
2139         )
2140
2141     @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
2142     def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
2143         # Exclude shouldn't exclude stdin_filename since it is mimicking the
2144         # file being passed directly. This is the same as
2145         # test_exclude_for_issue_1572
2146         path = DATA_DIR / "include_exclude_tests"
2147         src = ["-"]
2148         stdin_filename = str(path / "b/exclude/a.py")
2149         expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
2150         assert_collected_sources(
2151             src,
2152             expected,
2153             exclude=r"/exclude/|a\.py",
2154             stdin_filename=stdin_filename,
2155         )
2156
2157     @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
2158     def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
2159         # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
2160         # file being passed directly. This is the same as
2161         # test_exclude_for_issue_1572
2162         src = ["-"]
2163         path = THIS_DIR / "data" / "include_exclude_tests"
2164         stdin_filename = str(path / "b/exclude/a.py")
2165         expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
2166         assert_collected_sources(
2167             src,
2168             expected,
2169             extend_exclude=r"/exclude/|a\.py",
2170             stdin_filename=stdin_filename,
2171         )
2172
2173     @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
2174     def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
2175         # Force exclude should exclude the file when passing it through
2176         # stdin_filename
2177         path = THIS_DIR / "data" / "include_exclude_tests"
2178         stdin_filename = str(path / "b/exclude/a.py")
2179         assert_collected_sources(
2180             src=["-"],
2181             expected=[],
2182             force_exclude=r"/exclude/|a\.py",
2183             stdin_filename=stdin_filename,
2184         )
2185
2186
# Source lines of black/__init__.py, used by `tracefunc` to print the `def`
# line of each traced call. Bound up front so the name always exists, even
# when the source cannot be read below.
black_source_lines: List[str] = []
try:
    with open(black.__file__, "r", encoding="utf-8") as _bf:
        black_source_lines = _bf.readlines()
except UnicodeDecodeError:
    # When black is compiled, `black.__file__` presumably points at a binary
    # extension module that cannot be decoded as UTF-8; only tolerate the
    # failure in that case and keep the empty list.
    if not black.COMPILED:
        raise
2193
2194
def tracefunc(
    frame: types.FrameType, event: str, arg: Any
) -> Callable[[types.FrameType, str, Any], Any]:
    """Show function calls `from black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    Only "call" events for code in ``black/__init__.py`` are printed, as
    ``<lineno>:<def line>`` indented by the current stack depth; decorator
    lines above the ``def`` are skipped. The function always returns itself,
    so it is also used as the local trace function of the traced frame.
    """
    if event != "call":
        return tracefunc

    # Filter on the file first. `black_source_lines` holds only the source of
    # black/__init__.py, so indexing it with line numbers from other files is
    # meaningless and can raise IndexError. Checking first also skips the
    # costly `inspect.stack()` call for unrelated frames. Separators are
    # normalized so the check matches Windows paths too.
    filename = frame.f_code.co_filename.replace(os.sep, "/")
    if "black/__init__.py" not in filename:
        return tracefunc

    # NOTE(review): 19 looks like an empirical count of frames above the code
    # under test (test runner etc.); it only affects the indentation.
    stack = (len(inspect.stack()) - 19) * 2
    lineno = frame.f_lineno
    source = black_source_lines
    index = lineno - 1
    # Skip decorator lines to reach the actual `def` line, staying in bounds.
    while 0 <= index < len(source) and source[index].strip().startswith("@"):
        index += 1
    funcname = source[index].strip() if 0 <= index < len(source) else "<unknown>"
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc