git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Correctly handle trailing commas that are inside a line's leading non-nested parens...
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import inspect
5 import io
6 import logging
7 import multiprocessing
8 import os
9 import re
10 import sys
11 import types
12 import unittest
13 from concurrent.futures import ThreadPoolExecutor
14 from contextlib import contextmanager, redirect_stderr
15 from dataclasses import replace
16 from io import BytesIO
17 from pathlib import Path
18 from platform import system
19 from tempfile import TemporaryDirectory
20 from typing import (
21     Any,
22     Callable,
23     Dict,
24     Iterator,
25     List,
26     Optional,
27     Sequence,
28     TypeVar,
29     Union,
30 )
31 from unittest.mock import MagicMock, patch
32
33 import click
34 import pytest
35 from click import unstyle
36 from click.testing import CliRunner
37 from pathspec import PathSpec
38
39 import black
40 import black.files
41 from black import Feature, TargetVersion
42 from black import re_compile_maybe_verbose as compile_pattern
43 from black.cache import get_cache_dir, get_cache_file
44 from black.debug import DebugVisitor
45 from black.output import color_diff, diff
46 from black.report import Report
47
48 # Import other test classes
49 from tests.util import (
50     DATA_DIR,
51     DEFAULT_MODE,
52     DETERMINISTIC_HEADER,
53     PROJECT_ROOT,
54     PY36_VERSIONS,
55     THIS_DIR,
56     BlackBaseTestCase,
57     assert_format,
58     change_directory,
59     dump_to_stderr,
60     ff,
61     fs,
62     get_case_path,
63     read_data,
64     read_data_from_file,
65 )
66
THIS_FILE = Path(__file__)
# Config file handed to the CLI via --config (presumably empty, as its name
# suggests — confirm) so runs do not pick up other project settings.
EMPTY_CONFIG = THIS_DIR.joinpath("data", "empty_pyproject.toml")
# One "--target-version=<name>" flag for every entry in PY36_VERSIONS.
PY36_ARGS = ["--target-version=" + target.name.lower() for target in PY36_VERSIONS]
# Black's default include/exclude patterns, compiled with
# black.re_compile_maybe_verbose.
DEFAULT_INCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_INCLUDES)
DEFAULT_EXCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_EXCLUDES)
# Generic type variables for helper signatures.
T = TypeVar("T")
R = TypeVar("R")

# Match the time output in a diff, but nothing else
DIFF_TIME = re.compile(r"\t[\d\-:+\. ]+")
77
78
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point ``black.cache.CACHE_DIR`` at a throwaway directory and yield it.

    With ``exists=True`` the yielded path is the temporary directory itself.
    With ``exists=False`` it is a ``new`` child of that directory, which is
    not created here. The patch and the temporary directory are both undone
    when the context exits.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
87
88
@contextmanager
def event_loop() -> Iterator[None]:
    """Make a fresh asyncio loop current for the duration of the block.

    The loop is created by the active event-loop policy and is closed on exit,
    even if the block raises. It is not unset afterwards, so the closed loop
    remains the current loop.
    """
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        fresh_loop.close()
99
100
class FakeContext(click.Context):
    """Minimal click Context stand-in for calling helpers that expect one.

    ``click.Context.__init__`` is deliberately not invoked; only the two
    attributes set below exist on the instance.
    """

    def __init__(self) -> None:
        self.default_map: Dict[str, Any] = dict()
        # Placeholder root; most tests do not look at it.
        self.obj: Dict[str, Any] = dict(root=PROJECT_ROOT)
108
109
class FakeParameter(click.Parameter):
    """Minimal click Parameter stand-in for calling helpers that expect one."""

    def __init__(self) -> None:
        """Bypass ``click.Parameter.__init__``; no instance attributes are set."""
115
116
class BlackRunner(CliRunner):
    """CliRunner for Black CLI tests that captures stderr apart from stdout."""

    def __init__(self) -> None:
        # mix_stderr=False keeps the streams separate, so callers can inspect
        # result.stdout_bytes and result.stderr_bytes independently.
        CliRunner.__init__(self, mix_stderr=False)
122
123
def invokeBlack(
    args: List[str], exit_code: int = 0, ignore_config: bool = True
) -> None:
    """Run ``black.main`` via ``BlackRunner`` and assert its exit status.

    Args:
        args: command-line arguments for Black.
        exit_code: the exit code the run is expected to finish with.
        ignore_config: when true, ``--verbose --config <THIS_DIR/empty.toml>``
            is prepended so the run does not use any other configuration.

    Exceptions raised inside Black propagate (``catch_exceptions=False``).
    On a mismatched exit code, the assertion message includes the arguments,
    the captured stdout/stderr and ``result.exception``.
    """
    if ignore_config:
        empty_config = str(THIS_DIR / "empty.toml")
        args = ["--verbose", "--config", empty_config] + list(args)
    result = BlackRunner().invoke(black.main, args, catch_exceptions=False)
    out_bytes = result.stdout_bytes
    err_bytes = result.stderr_bytes
    assert out_bytes is not None
    assert err_bytes is not None
    failure_msg = "\n".join(
        [
            f"Failed with args: {args}",
            f"stdout: {out_bytes.decode()!r}",
            f"stderr: {err_bytes.decode()!r}",
            f"exception: {result.exception}",
        ]
    )
    assert result.exit_code == exit_code, failure_msg
140
141
142 class BlackTestCase(BlackBaseTestCase):
    # Expose the module-level invokeBlack helper as ``self.invokeBlack``;
    # staticmethod stops it from receiving ``self`` as its first argument.
    invokeBlack = staticmethod(invokeBlack)
144
145     def test_empty_ff(self) -> None:
146         expected = ""
147         tmp_file = Path(black.dump_to_file())
148         try:
149             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
150             with open(tmp_file, encoding="utf8") as f:
151                 actual = f.read()
152         finally:
153             os.unlink(tmp_file)
154         self.assertFormatEqual(expected, actual)
155
156     def test_experimental_string_processing_warns(self) -> None:
157         self.assertWarns(
158             black.mode.Deprecated, black.Mode, experimental_string_processing=True
159         )
160
161     def test_piping(self) -> None:
162         source, expected = read_data_from_file(PROJECT_ROOT / "src/black/__init__.py")
163         result = BlackRunner().invoke(
164             black.main,
165             [
166                 "-",
167                 "--fast",
168                 f"--line-length={black.DEFAULT_LINE_LENGTH}",
169                 f"--config={EMPTY_CONFIG}",
170             ],
171             input=BytesIO(source.encode("utf8")),
172         )
173         self.assertEqual(result.exit_code, 0)
174         self.assertFormatEqual(expected, result.output)
175         if source != result.output:
176             black.assert_equivalent(source, result.output)
177             black.assert_stable(source, result.output, DEFAULT_MODE)
178
179     def test_piping_diff(self) -> None:
180         diff_header = re.compile(
181             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
182             r"\+\d\d\d\d"
183         )
184         source, _ = read_data("simple_cases", "expression.py")
185         expected, _ = read_data("simple_cases", "expression.diff")
186         args = [
187             "-",
188             "--fast",
189             f"--line-length={black.DEFAULT_LINE_LENGTH}",
190             "--diff",
191             f"--config={EMPTY_CONFIG}",
192         ]
193         result = BlackRunner().invoke(
194             black.main, args, input=BytesIO(source.encode("utf8"))
195         )
196         self.assertEqual(result.exit_code, 0)
197         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
198         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
199         self.assertEqual(expected, actual)
200
201     def test_piping_diff_with_color(self) -> None:
202         source, _ = read_data("simple_cases", "expression.py")
203         args = [
204             "-",
205             "--fast",
206             f"--line-length={black.DEFAULT_LINE_LENGTH}",
207             "--diff",
208             "--color",
209             f"--config={EMPTY_CONFIG}",
210         ]
211         result = BlackRunner().invoke(
212             black.main, args, input=BytesIO(source.encode("utf8"))
213         )
214         actual = result.output
215         # Again, the contents are checked in a different test, so only look for colors.
216         self.assertIn("\033[1m", actual)
217         self.assertIn("\033[36m", actual)
218         self.assertIn("\033[32m", actual)
219         self.assertIn("\033[31m", actual)
220         self.assertIn("\033[0m", actual)
221
222     @patch("black.dump_to_file", dump_to_stderr)
223     def _test_wip(self) -> None:
224         source, expected = read_data("miscellaneous", "wip")
225         sys.settrace(tracefunc)
226         mode = replace(
227             DEFAULT_MODE,
228             experimental_string_processing=False,
229             target_versions={black.TargetVersion.PY38},
230         )
231         actual = fs(source, mode=mode)
232         sys.settrace(None)
233         self.assertFormatEqual(expected, actual)
234         black.assert_equivalent(source, actual)
235         black.assert_stable(source, actual, black.FileMode())
236
237     def test_pep_572_version_detection(self) -> None:
238         source, _ = read_data("py_38", "pep_572")
239         root = black.lib2to3_parse(source)
240         features = black.get_features_used(root)
241         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
242         versions = black.detect_target_versions(root)
243         self.assertIn(black.TargetVersion.PY38, versions)
244
245     def test_expression_ff(self) -> None:
246         source, expected = read_data("simple_cases", "expression.py")
247         tmp_file = Path(black.dump_to_file(source))
248         try:
249             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
250             with open(tmp_file, encoding="utf8") as f:
251                 actual = f.read()
252         finally:
253             os.unlink(tmp_file)
254         self.assertFormatEqual(expected, actual)
255         with patch("black.dump_to_file", dump_to_stderr):
256             black.assert_equivalent(source, actual)
257             black.assert_stable(source, actual, DEFAULT_MODE)
258
259     def test_expression_diff(self) -> None:
260         source, _ = read_data("simple_cases", "expression.py")
261         expected, _ = read_data("simple_cases", "expression.diff")
262         tmp_file = Path(black.dump_to_file(source))
263         diff_header = re.compile(
264             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
265             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
266         )
267         try:
268             result = BlackRunner().invoke(
269                 black.main, ["--diff", str(tmp_file), f"--config={EMPTY_CONFIG}"]
270             )
271             self.assertEqual(result.exit_code, 0)
272         finally:
273             os.unlink(tmp_file)
274         actual = result.output
275         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
276         if expected != actual:
277             dump = black.dump_to_file(actual)
278             msg = (
279                 "Expected diff isn't equal to the actual. If you made changes to"
280                 " expression.py and this is an anticipated difference, overwrite"
281                 f" tests/data/expression.diff with {dump}"
282             )
283             self.assertEqual(expected, actual, msg)
284
285     def test_expression_diff_with_color(self) -> None:
286         source, _ = read_data("simple_cases", "expression.py")
287         expected, _ = read_data("simple_cases", "expression.diff")
288         tmp_file = Path(black.dump_to_file(source))
289         try:
290             result = BlackRunner().invoke(
291                 black.main,
292                 ["--diff", "--color", str(tmp_file), f"--config={EMPTY_CONFIG}"],
293             )
294         finally:
295             os.unlink(tmp_file)
296         actual = result.output
297         # We check the contents of the diff in `test_expression_diff`. All
298         # we need to check here is that color codes exist in the result.
299         self.assertIn("\033[1m", actual)
300         self.assertIn("\033[36m", actual)
301         self.assertIn("\033[32m", actual)
302         self.assertIn("\033[31m", actual)
303         self.assertIn("\033[0m", actual)
304
305     def test_detect_pos_only_arguments(self) -> None:
306         source, _ = read_data("py_38", "pep_570")
307         root = black.lib2to3_parse(source)
308         features = black.get_features_used(root)
309         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
310         versions = black.detect_target_versions(root)
311         self.assertIn(black.TargetVersion.PY38, versions)
312
313     def test_detect_debug_f_strings(self) -> None:
314         root = black.lib2to3_parse("""f"{x=}" """)
315         features = black.get_features_used(root)
316         self.assertIn(black.Feature.DEBUG_F_STRINGS, features)
317         versions = black.detect_target_versions(root)
318         self.assertIn(black.TargetVersion.PY38, versions)
319
320         root = black.lib2to3_parse(
321             """f"{x}"\nf'{"="}'\nf'{(x:=5)}'\nf'{f(a="3=")}'\nf'{x:=10}'\n"""
322         )
323         features = black.get_features_used(root)
324         self.assertNotIn(black.Feature.DEBUG_F_STRINGS, features)
325
326         # We don't yet support feature version detection in nested f-strings
327         root = black.lib2to3_parse(
328             """f"heard a rumour that { f'{1+1=}' } ... seems like it could be true" """
329         )
330         features = black.get_features_used(root)
331         self.assertNotIn(black.Feature.DEBUG_F_STRINGS, features)
332
333     @patch("black.dump_to_file", dump_to_stderr)
334     def test_string_quotes(self) -> None:
335         source, expected = read_data("miscellaneous", "string_quotes")
336         mode = black.Mode(preview=True)
337         assert_format(source, expected, mode)
338         mode = replace(mode, string_normalization=False)
339         not_normalized = fs(source, mode=mode)
340         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
341         black.assert_equivalent(source, not_normalized)
342         black.assert_stable(source, not_normalized, mode=mode)
343
344     def test_skip_source_first_line(self) -> None:
345         source, _ = read_data("miscellaneous", "invalid_header")
346         tmp_file = Path(black.dump_to_file(source))
347         # Full source should fail (invalid syntax at header)
348         self.invokeBlack([str(tmp_file), "--diff", "--check"], exit_code=123)
349         # So, skipping the first line should work
350         result = BlackRunner().invoke(
351             black.main, [str(tmp_file), "-x", f"--config={EMPTY_CONFIG}"]
352         )
353         self.assertEqual(result.exit_code, 0)
354         with open(tmp_file, encoding="utf8") as f:
355             actual = f.read()
356         self.assertFormatEqual(source, actual)
357
358     def test_skip_source_first_line_when_mixing_newlines(self) -> None:
359         code_mixing_newlines = b"Header will be skipped\r\ni = [1,2,3]\nj = [1,2,3]\n"
360         expected = b"Header will be skipped\r\ni = [1, 2, 3]\nj = [1, 2, 3]\n"
361         with TemporaryDirectory() as workspace:
362             test_file = Path(workspace) / "skip_header.py"
363             test_file.write_bytes(code_mixing_newlines)
364             mode = replace(DEFAULT_MODE, skip_source_first_line=True)
365             ff(test_file, mode=mode, write_back=black.WriteBack.YES)
366             self.assertEqual(test_file.read_bytes(), expected)
367
368     def test_skip_magic_trailing_comma(self) -> None:
369         source, _ = read_data("simple_cases", "expression")
370         expected, _ = read_data(
371             "miscellaneous", "expression_skip_magic_trailing_comma.diff"
372         )
373         tmp_file = Path(black.dump_to_file(source))
374         diff_header = re.compile(
375             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
376             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
377         )
378         try:
379             result = BlackRunner().invoke(
380                 black.main, ["-C", "--diff", str(tmp_file), f"--config={EMPTY_CONFIG}"]
381             )
382             self.assertEqual(result.exit_code, 0)
383         finally:
384             os.unlink(tmp_file)
385         actual = result.output
386         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
387         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
388         if expected != actual:
389             dump = black.dump_to_file(actual)
390             msg = (
391                 "Expected diff isn't equal to the actual. If you made changes to"
392                 " expression.py and this is an anticipated difference, overwrite"
393                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
394             )
395             self.assertEqual(expected, actual, msg)
396
397     @patch("black.dump_to_file", dump_to_stderr)
398     def test_async_as_identifier(self) -> None:
399         source_path = get_case_path("miscellaneous", "async_as_identifier")
400         source, expected = read_data_from_file(source_path)
401         actual = fs(source)
402         self.assertFormatEqual(expected, actual)
403         major, minor = sys.version_info[:2]
404         if major < 3 or (major <= 3 and minor < 7):
405             black.assert_equivalent(source, actual)
406         black.assert_stable(source, actual, DEFAULT_MODE)
407         # ensure black can parse this when the target is 3.6
408         self.invokeBlack([str(source_path), "--target-version", "py36"])
409         # but not on 3.7, because async/await is no longer an identifier
410         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
411
412     @patch("black.dump_to_file", dump_to_stderr)
413     def test_python37(self) -> None:
414         source_path = get_case_path("py_37", "python37")
415         source, expected = read_data_from_file(source_path)
416         actual = fs(source)
417         self.assertFormatEqual(expected, actual)
418         major, minor = sys.version_info[:2]
419         if major > 3 or (major == 3 and minor >= 7):
420             black.assert_equivalent(source, actual)
421         black.assert_stable(source, actual, DEFAULT_MODE)
422         # ensure black can parse this when the target is 3.7
423         self.invokeBlack([str(source_path), "--target-version", "py37"])
424         # but not on 3.6, because we use async as a reserved keyword
425         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
426
427     def test_tab_comment_indentation(self) -> None:
428         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
429         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
430         self.assertFormatEqual(contents_spc, fs(contents_spc))
431         self.assertFormatEqual(contents_spc, fs(contents_tab))
432
433         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
434         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
435         self.assertFormatEqual(contents_spc, fs(contents_spc))
436         self.assertFormatEqual(contents_spc, fs(contents_tab))
437
438         # mixed tabs and spaces (valid Python 2 code)
439         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
440         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
441         self.assertFormatEqual(contents_spc, fs(contents_spc))
442         self.assertFormatEqual(contents_spc, fs(contents_tab))
443
444         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
445         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
446         self.assertFormatEqual(contents_spc, fs(contents_spc))
447         self.assertFormatEqual(contents_spc, fs(contents_tab))
448
    def test_report_verbose(self) -> None:
        """Drive a verbose ``Report`` through a fixed sequence of outcomes.

        After each ``done``/``failed``/``path_ignored`` call, the test checks:
        - what reached the patched stdout/stderr sinks;
        - the unstyled summary string;
        - ``return_code``.

        In verbose mode every outcome produces a message.
        """
        report = Report(verbose=True)
        # Messages captured from black.output._out / _err via the patches below.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            # Unchanged file: still echoed to stdout in verbose mode.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # Reformatted file.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # CACHED is counted as "left unchanged" in the summary.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, having reformatted a file yields return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure is reported on stderr and the return code becomes 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                (
                    "1 file reformatted, 2 files left unchanged, 1 file failed to"
                    " reformat."
                ),
            )
            self.assertEqual(report.return_code, 123)
            # Reporting f3 again as YES bumps the reformatted count; the earlier
            # CACHED entry still counts as unchanged.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                (
                    "2 files reformatted, 2 files left unchanged, 1 file failed to"
                    " reformat."
                ),
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                (
                    "2 files reformatted, 2 files left unchanged, 2 files failed to"
                    " reformat."
                ),
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are echoed in verbose mode but leave the counts alone.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                (
                    "2 files reformatted, 2 files left unchanged, 2 files failed to"
                    " reformat."
                ),
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                (
                    "2 files reformatted, 3 files left unchanged, 2 files failed to"
                    " reformat."
                ),
            )
            self.assertEqual(report.return_code, 123)
            # Both check mode and diff mode switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                (
                    "2 files would be reformatted, 3 files would be left unchanged, 2"
                    " files would fail to reformat."
                ),
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                (
                    "2 files would be reformatted, 3 files would be left unchanged, 2"
                    " files would fail to reformat."
                ),
            )
564
    def test_report_quiet(self) -> None:
        """Drive a quiet ``Report`` through the same outcome sequence as verbose.

        In quiet mode nothing is written to stdout; only failures reach stderr.
        The summary strings and return codes are the same as in the verbose
        test.
        """
        report = Report(quiet=True)
        # Messages captured from black.output._out / _err via the patches below.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # CACHED is counted as "left unchanged" in the summary.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, having reformatted a file yields return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures still reach stderr in quiet mode; return code becomes 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                (
                    "1 file reformatted, 2 files left unchanged, 1 file failed to"
                    " reformat."
                ),
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                (
                    "2 files reformatted, 2 files left unchanged, 1 file failed to"
                    " reformat."
                ),
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                (
                    "2 files reformatted, 2 files left unchanged, 2 files failed to"
                    " reformat."
                ),
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are silent in quiet mode and leave the counts alone.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                (
                    "2 files reformatted, 2 files left unchanged, 2 files failed to"
                    " reformat."
                ),
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                (
                    "2 files reformatted, 3 files left unchanged, 2 files failed to"
                    " reformat."
                ),
            )
            self.assertEqual(report.return_code, 123)
            # Both check mode and diff mode switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                (
                    "2 files would be reformatted, 3 files would be left unchanged, 2"
                    " files would fail to reformat."
                ),
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                (
                    "2 files would be reformatted, 3 files would be left unchanged, 2"
                    " files would fail to reformat."
                ),
            )
672
673     def test_report_normal(self) -> None:
674         report = black.Report()
675         out_lines = []
676         err_lines = []
677
678         def out(msg: str, **kwargs: Any) -> None:
679             out_lines.append(msg)
680
681         def err(msg: str, **kwargs: Any) -> None:
682             err_lines.append(msg)
683
684         with patch("black.output._out", out), patch("black.output._err", err):
685             report.done(Path("f1"), black.Changed.NO)
686             self.assertEqual(len(out_lines), 0)
687             self.assertEqual(len(err_lines), 0)
688             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
689             self.assertEqual(report.return_code, 0)
690             report.done(Path("f2"), black.Changed.YES)
691             self.assertEqual(len(out_lines), 1)
692             self.assertEqual(len(err_lines), 0)
693             self.assertEqual(out_lines[-1], "reformatted f2")
694             self.assertEqual(
695                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
696             )
697             report.done(Path("f3"), black.Changed.CACHED)
698             self.assertEqual(len(out_lines), 1)
699             self.assertEqual(len(err_lines), 0)
700             self.assertEqual(out_lines[-1], "reformatted f2")
701             self.assertEqual(
702                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
703             )
704             self.assertEqual(report.return_code, 0)
705             report.check = True
706             self.assertEqual(report.return_code, 1)
707             report.check = False
708             report.failed(Path("e1"), "boom")
709             self.assertEqual(len(out_lines), 1)
710             self.assertEqual(len(err_lines), 1)
711             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
712             self.assertEqual(
713                 unstyle(str(report)),
714                 (
715                     "1 file reformatted, 2 files left unchanged, 1 file failed to"
716                     " reformat."
717                 ),
718             )
719             self.assertEqual(report.return_code, 123)
720             report.done(Path("f3"), black.Changed.YES)
721             self.assertEqual(len(out_lines), 2)
722             self.assertEqual(len(err_lines), 1)
723             self.assertEqual(out_lines[-1], "reformatted f3")
724             self.assertEqual(
725                 unstyle(str(report)),
726                 (
727                     "2 files reformatted, 2 files left unchanged, 1 file failed to"
728                     " reformat."
729                 ),
730             )
731             self.assertEqual(report.return_code, 123)
732             report.failed(Path("e2"), "boom")
733             self.assertEqual(len(out_lines), 2)
734             self.assertEqual(len(err_lines), 2)
735             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
736             self.assertEqual(
737                 unstyle(str(report)),
738                 (
739                     "2 files reformatted, 2 files left unchanged, 2 files failed to"
740                     " reformat."
741                 ),
742             )
743             self.assertEqual(report.return_code, 123)
744             report.path_ignored(Path("wat"), "no match")
745             self.assertEqual(len(out_lines), 2)
746             self.assertEqual(len(err_lines), 2)
747             self.assertEqual(
748                 unstyle(str(report)),
749                 (
750                     "2 files reformatted, 2 files left unchanged, 2 files failed to"
751                     " reformat."
752                 ),
753             )
754             self.assertEqual(report.return_code, 123)
755             report.done(Path("f4"), black.Changed.NO)
756             self.assertEqual(len(out_lines), 2)
757             self.assertEqual(len(err_lines), 2)
758             self.assertEqual(
759                 unstyle(str(report)),
760                 (
761                     "2 files reformatted, 3 files left unchanged, 2 files failed to"
762                     " reformat."
763                 ),
764             )
765             self.assertEqual(report.return_code, 123)
766             report.check = True
767             self.assertEqual(
768                 unstyle(str(report)),
769                 (
770                     "2 files would be reformatted, 3 files would be left unchanged, 2"
771                     " files would fail to reformat."
772                 ),
773             )
774             report.check = False
775             report.diff = True
776             self.assertEqual(
777                 unstyle(str(report)),
778                 (
779                     "2 files would be reformatted, 3 files would be left unchanged, 2"
780                     " files would fail to reformat."
781                 ),
782             )
783
784     def test_lib2to3_parse(self) -> None:
785         with self.assertRaises(black.InvalidInput):
786             black.lib2to3_parse("invalid syntax")
787
788         straddling = "x + y"
789         black.lib2to3_parse(straddling)
790         black.lib2to3_parse(straddling, {TargetVersion.PY36})
791
792         py2_only = "print x"
793         with self.assertRaises(black.InvalidInput):
794             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
795
796         py3_only = "exec(x, end=y)"
797         black.lib2to3_parse(py3_only)
798         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
799
800     def test_get_features_used_decorator(self) -> None:
801         # Test the feature detection of new decorator syntax
802         # since this makes some test cases of test_get_features_used()
803         # fails if it fails, this is tested first so that a useful case
804         # is identified
805         simples, relaxed = read_data("miscellaneous", "decorators")
806         # skip explanation comments at the top of the file
807         for simple_test in simples.split("##")[1:]:
808             node = black.lib2to3_parse(simple_test)
809             decorator = str(node.children[0].children[0]).strip()
810             self.assertNotIn(
811                 Feature.RELAXED_DECORATORS,
812                 black.get_features_used(node),
813                 msg=(
814                     f"decorator '{decorator}' follows python<=3.8 syntax"
815                     "but is detected as 3.9+"
816                     # f"The full node is\n{node!r}"
817                 ),
818             )
819         # skip the '# output' comment at the top of the output part
820         for relaxed_test in relaxed.split("##")[1:]:
821             node = black.lib2to3_parse(relaxed_test)
822             decorator = str(node.children[0].children[0]).strip()
823             self.assertIn(
824                 Feature.RELAXED_DECORATORS,
825                 black.get_features_used(node),
826                 msg=(
827                     f"decorator '{decorator}' uses python3.9+ syntax"
828                     "but is detected as python<=3.8"
829                     # f"The full node is\n{node!r}"
830                 ),
831             )
832
833     def test_get_features_used(self) -> None:
834         node = black.lib2to3_parse("def f(*, arg): ...\n")
835         self.assertEqual(black.get_features_used(node), set())
836         node = black.lib2to3_parse("def f(*, arg,): ...\n")
837         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
838         node = black.lib2to3_parse("f(*arg,)\n")
839         self.assertEqual(
840             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
841         )
842         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
843         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
844         node = black.lib2to3_parse("123_456\n")
845         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
846         node = black.lib2to3_parse("123456\n")
847         self.assertEqual(black.get_features_used(node), set())
848         source, expected = read_data("simple_cases", "function")
849         node = black.lib2to3_parse(source)
850         expected_features = {
851             Feature.TRAILING_COMMA_IN_CALL,
852             Feature.TRAILING_COMMA_IN_DEF,
853             Feature.F_STRINGS,
854         }
855         self.assertEqual(black.get_features_used(node), expected_features)
856         node = black.lib2to3_parse(expected)
857         self.assertEqual(black.get_features_used(node), expected_features)
858         source, expected = read_data("simple_cases", "expression")
859         node = black.lib2to3_parse(source)
860         self.assertEqual(black.get_features_used(node), set())
861         node = black.lib2to3_parse(expected)
862         self.assertEqual(black.get_features_used(node), set())
863         node = black.lib2to3_parse("lambda a, /, b: ...")
864         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
865         node = black.lib2to3_parse("def fn(a, /, b): ...")
866         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
867         node = black.lib2to3_parse("def fn(): yield a, b")
868         self.assertEqual(black.get_features_used(node), set())
869         node = black.lib2to3_parse("def fn(): return a, b")
870         self.assertEqual(black.get_features_used(node), set())
871         node = black.lib2to3_parse("def fn(): yield *b, c")
872         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
873         node = black.lib2to3_parse("def fn(): return a, *b, c")
874         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
875         node = black.lib2to3_parse("x = a, *b, c")
876         self.assertEqual(black.get_features_used(node), set())
877         node = black.lib2to3_parse("x: Any = regular")
878         self.assertEqual(black.get_features_used(node), set())
879         node = black.lib2to3_parse("x: Any = (regular, regular)")
880         self.assertEqual(black.get_features_used(node), set())
881         node = black.lib2to3_parse("x: Any = Complex(Type(1))[something]")
882         self.assertEqual(black.get_features_used(node), set())
883         node = black.lib2to3_parse("x: Tuple[int, ...] = a, b, c")
884         self.assertEqual(
885             black.get_features_used(node), {Feature.ANN_ASSIGN_EXTENDED_RHS}
886         )
887         node = black.lib2to3_parse("try: pass\nexcept Something: pass")
888         self.assertEqual(black.get_features_used(node), set())
889         node = black.lib2to3_parse("try: pass\nexcept (*Something,): pass")
890         self.assertEqual(black.get_features_used(node), set())
891         node = black.lib2to3_parse("try: pass\nexcept *Group: pass")
892         self.assertEqual(black.get_features_used(node), {Feature.EXCEPT_STAR})
893         node = black.lib2to3_parse("a[*b]")
894         self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS})
895         node = black.lib2to3_parse("a[x, *y(), z] = t")
896         self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS})
897         node = black.lib2to3_parse("def fn(*args: *T): pass")
898         self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS})
899
900     def test_get_features_used_for_future_flags(self) -> None:
901         for src, features in [
902             ("from __future__ import annotations", {Feature.FUTURE_ANNOTATIONS}),
903             (
904                 "from __future__ import (other, annotations)",
905                 {Feature.FUTURE_ANNOTATIONS},
906             ),
907             ("a = 1 + 2\nfrom something import annotations", set()),
908             ("from __future__ import x, y", set()),
909         ]:
910             with self.subTest(src=src, features=features):
911                 node = black.lib2to3_parse(src)
912                 future_imports = black.get_future_imports(node)
913                 self.assertEqual(
914                     black.get_features_used(node, future_imports=future_imports),
915                     features,
916                 )
917
918     def test_get_future_imports(self) -> None:
919         node = black.lib2to3_parse("\n")
920         self.assertEqual(set(), black.get_future_imports(node))
921         node = black.lib2to3_parse("from __future__ import black\n")
922         self.assertEqual({"black"}, black.get_future_imports(node))
923         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
924         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
925         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
926         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
927         node = black.lib2to3_parse(
928             "from __future__ import multiple\nfrom __future__ import imports\n"
929         )
930         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
931         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
932         self.assertEqual({"black"}, black.get_future_imports(node))
933         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
934         self.assertEqual({"black"}, black.get_future_imports(node))
935         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
936         self.assertEqual(set(), black.get_future_imports(node))
937         node = black.lib2to3_parse("from some.module import black\n")
938         self.assertEqual(set(), black.get_future_imports(node))
939         node = black.lib2to3_parse(
940             "from __future__ import unicode_literals as _unicode_literals"
941         )
942         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
943         node = black.lib2to3_parse(
944             "from __future__ import unicode_literals as _lol, print"
945         )
946         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
947
948     @pytest.mark.incompatible_with_mypyc
949     def test_debug_visitor(self) -> None:
950         source, _ = read_data("miscellaneous", "debug_visitor")
951         expected, _ = read_data("miscellaneous", "debug_visitor.out")
952         out_lines = []
953         err_lines = []
954
955         def out(msg: str, **kwargs: Any) -> None:
956             out_lines.append(msg)
957
958         def err(msg: str, **kwargs: Any) -> None:
959             err_lines.append(msg)
960
961         with patch("black.debug.out", out):
962             DebugVisitor.show(source)
963         actual = "\n".join(out_lines) + "\n"
964         log_name = ""
965         if expected != actual:
966             log_name = black.dump_to_file(*out_lines)
967         self.assertEqual(
968             expected,
969             actual,
970             f"AST print out is different. Actual version dumped to {log_name}",
971         )
972
973     def test_format_file_contents(self) -> None:
974         empty = ""
975         mode = DEFAULT_MODE
976         with self.assertRaises(black.NothingChanged):
977             black.format_file_contents(empty, mode=mode, fast=False)
978         just_nl = "\n"
979         with self.assertRaises(black.NothingChanged):
980             black.format_file_contents(just_nl, mode=mode, fast=False)
981         same = "j = [1, 2, 3]\n"
982         with self.assertRaises(black.NothingChanged):
983             black.format_file_contents(same, mode=mode, fast=False)
984         different = "j = [1,2,3]"
985         expected = same
986         actual = black.format_file_contents(different, mode=mode, fast=False)
987         self.assertEqual(expected, actual)
988         invalid = "return if you can"
989         with self.assertRaises(black.InvalidInput) as e:
990             black.format_file_contents(invalid, mode=mode, fast=False)
991         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
992
993     def test_endmarker(self) -> None:
994         n = black.lib2to3_parse("\n")
995         self.assertEqual(n.type, black.syms.file_input)
996         self.assertEqual(len(n.children), 1)
997         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
998
999     @pytest.mark.incompatible_with_mypyc
1000     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1001     def test_assertFormatEqual(self) -> None:
1002         out_lines = []
1003         err_lines = []
1004
1005         def out(msg: str, **kwargs: Any) -> None:
1006             out_lines.append(msg)
1007
1008         def err(msg: str, **kwargs: Any) -> None:
1009             err_lines.append(msg)
1010
1011         with patch("black.output._out", out), patch("black.output._err", err):
1012             with self.assertRaises(AssertionError):
1013                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1014
1015         out_str = "".join(out_lines)
1016         self.assertIn("Expected tree:", out_str)
1017         self.assertIn("Actual tree:", out_str)
1018         self.assertEqual("".join(err_lines), "")
1019
1020     @event_loop()
1021     @patch("concurrent.futures.ProcessPoolExecutor", MagicMock(side_effect=OSError))
1022     def test_works_in_mono_process_only_environment(self) -> None:
1023         with cache_dir() as workspace:
1024             for f in [
1025                 (workspace / "one.py").resolve(),
1026                 (workspace / "two.py").resolve(),
1027             ]:
1028                 f.write_text('print("hello")\n')
1029             self.invokeBlack([str(workspace)])
1030
1031     @event_loop()
1032     def test_check_diff_use_together(self) -> None:
1033         with cache_dir():
1034             # Files which will be reformatted.
1035             src1 = get_case_path("miscellaneous", "string_quotes")
1036             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1037             # Files which will not be reformatted.
1038             src2 = get_case_path("simple_cases", "composition")
1039             self.invokeBlack([str(src2), "--diff", "--check"])
1040             # Multi file command.
1041             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1042
1043     def test_no_src_fails(self) -> None:
1044         with cache_dir():
1045             self.invokeBlack([], exit_code=1)
1046
1047     def test_src_and_code_fails(self) -> None:
1048         with cache_dir():
1049             self.invokeBlack([".", "-c", "0"], exit_code=1)
1050
1051     def test_broken_symlink(self) -> None:
1052         with cache_dir() as workspace:
1053             symlink = workspace / "broken_link.py"
1054             try:
1055                 symlink.symlink_to("nonexistent.py")
1056             except (OSError, NotImplementedError) as e:
1057                 self.skipTest(f"Can't create symlinks: {e}")
1058             self.invokeBlack([str(workspace.resolve())])
1059
1060     def test_single_file_force_pyi(self) -> None:
1061         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1062         contents, expected = read_data("miscellaneous", "force_pyi")
1063         with cache_dir() as workspace:
1064             path = (workspace / "file.py").resolve()
1065             with open(path, "w") as fh:
1066                 fh.write(contents)
1067             self.invokeBlack([str(path), "--pyi"])
1068             with open(path, "r") as fh:
1069                 actual = fh.read()
1070             # verify cache with --pyi is separate
1071             pyi_cache = black.read_cache(pyi_mode)
1072             self.assertIn(str(path), pyi_cache)
1073             normal_cache = black.read_cache(DEFAULT_MODE)
1074             self.assertNotIn(str(path), normal_cache)
1075         self.assertFormatEqual(expected, actual)
1076         black.assert_equivalent(contents, actual)
1077         black.assert_stable(contents, actual, pyi_mode)
1078
1079     @event_loop()
1080     def test_multi_file_force_pyi(self) -> None:
1081         reg_mode = DEFAULT_MODE
1082         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1083         contents, expected = read_data("miscellaneous", "force_pyi")
1084         with cache_dir() as workspace:
1085             paths = [
1086                 (workspace / "file1.py").resolve(),
1087                 (workspace / "file2.py").resolve(),
1088             ]
1089             for path in paths:
1090                 with open(path, "w") as fh:
1091                     fh.write(contents)
1092             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1093             for path in paths:
1094                 with open(path, "r") as fh:
1095                     actual = fh.read()
1096                 self.assertEqual(actual, expected)
1097             # verify cache with --pyi is separate
1098             pyi_cache = black.read_cache(pyi_mode)
1099             normal_cache = black.read_cache(reg_mode)
1100             for path in paths:
1101                 self.assertIn(str(path), pyi_cache)
1102                 self.assertNotIn(str(path), normal_cache)
1103
1104     def test_pipe_force_pyi(self) -> None:
1105         source, expected = read_data("miscellaneous", "force_pyi")
1106         result = CliRunner().invoke(
1107             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1108         )
1109         self.assertEqual(result.exit_code, 0)
1110         actual = result.output
1111         self.assertFormatEqual(actual, expected)
1112
1113     def test_single_file_force_py36(self) -> None:
1114         reg_mode = DEFAULT_MODE
1115         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1116         source, expected = read_data("miscellaneous", "force_py36")
1117         with cache_dir() as workspace:
1118             path = (workspace / "file.py").resolve()
1119             with open(path, "w") as fh:
1120                 fh.write(source)
1121             self.invokeBlack([str(path), *PY36_ARGS])
1122             with open(path, "r") as fh:
1123                 actual = fh.read()
1124             # verify cache with --target-version is separate
1125             py36_cache = black.read_cache(py36_mode)
1126             self.assertIn(str(path), py36_cache)
1127             normal_cache = black.read_cache(reg_mode)
1128             self.assertNotIn(str(path), normal_cache)
1129         self.assertEqual(actual, expected)
1130
1131     @event_loop()
1132     def test_multi_file_force_py36(self) -> None:
1133         reg_mode = DEFAULT_MODE
1134         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1135         source, expected = read_data("miscellaneous", "force_py36")
1136         with cache_dir() as workspace:
1137             paths = [
1138                 (workspace / "file1.py").resolve(),
1139                 (workspace / "file2.py").resolve(),
1140             ]
1141             for path in paths:
1142                 with open(path, "w") as fh:
1143                     fh.write(source)
1144             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1145             for path in paths:
1146                 with open(path, "r") as fh:
1147                     actual = fh.read()
1148                 self.assertEqual(actual, expected)
1149             # verify cache with --target-version is separate
1150             pyi_cache = black.read_cache(py36_mode)
1151             normal_cache = black.read_cache(reg_mode)
1152             for path in paths:
1153                 self.assertIn(str(path), pyi_cache)
1154                 self.assertNotIn(str(path), normal_cache)
1155
1156     def test_pipe_force_py36(self) -> None:
1157         source, expected = read_data("miscellaneous", "force_py36")
1158         result = CliRunner().invoke(
1159             black.main,
1160             ["-", "-q", "--target-version=py36"],
1161             input=BytesIO(source.encode("utf8")),
1162         )
1163         self.assertEqual(result.exit_code, 0)
1164         actual = result.output
1165         self.assertFormatEqual(actual, expected)
1166
1167     @pytest.mark.incompatible_with_mypyc
1168     def test_reformat_one_with_stdin(self) -> None:
1169         with patch(
1170             "black.format_stdin_to_stdout",
1171             return_value=lambda *args, **kwargs: black.Changed.YES,
1172         ) as fsts:
1173             report = MagicMock()
1174             path = Path("-")
1175             black.reformat_one(
1176                 path,
1177                 fast=True,
1178                 write_back=black.WriteBack.YES,
1179                 mode=DEFAULT_MODE,
1180                 report=report,
1181             )
1182             fsts.assert_called_once()
1183             report.done.assert_called_with(path, black.Changed.YES)
1184
1185     @pytest.mark.incompatible_with_mypyc
1186     def test_reformat_one_with_stdin_filename(self) -> None:
1187         with patch(
1188             "black.format_stdin_to_stdout",
1189             return_value=lambda *args, **kwargs: black.Changed.YES,
1190         ) as fsts:
1191             report = MagicMock()
1192             p = "foo.py"
1193             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1194             expected = Path(p)
1195             black.reformat_one(
1196                 path,
1197                 fast=True,
1198                 write_back=black.WriteBack.YES,
1199                 mode=DEFAULT_MODE,
1200                 report=report,
1201             )
1202             fsts.assert_called_once_with(
1203                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1204             )
1205             # __BLACK_STDIN_FILENAME__ should have been stripped
1206             report.done.assert_called_with(expected, black.Changed.YES)
1207
1208     @pytest.mark.incompatible_with_mypyc
1209     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1210         with patch(
1211             "black.format_stdin_to_stdout",
1212             return_value=lambda *args, **kwargs: black.Changed.YES,
1213         ) as fsts:
1214             report = MagicMock()
1215             p = "foo.pyi"
1216             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1217             expected = Path(p)
1218             black.reformat_one(
1219                 path,
1220                 fast=True,
1221                 write_back=black.WriteBack.YES,
1222                 mode=DEFAULT_MODE,
1223                 report=report,
1224             )
1225             fsts.assert_called_once_with(
1226                 fast=True,
1227                 write_back=black.WriteBack.YES,
1228                 mode=replace(DEFAULT_MODE, is_pyi=True),
1229             )
1230             # __BLACK_STDIN_FILENAME__ should have been stripped
1231             report.done.assert_called_with(expected, black.Changed.YES)
1232
1233     @pytest.mark.incompatible_with_mypyc
1234     def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
1235         with patch(
1236             "black.format_stdin_to_stdout",
1237             return_value=lambda *args, **kwargs: black.Changed.YES,
1238         ) as fsts:
1239             report = MagicMock()
1240             p = "foo.ipynb"
1241             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1242             expected = Path(p)
1243             black.reformat_one(
1244                 path,
1245                 fast=True,
1246                 write_back=black.WriteBack.YES,
1247                 mode=DEFAULT_MODE,
1248                 report=report,
1249             )
1250             fsts.assert_called_once_with(
1251                 fast=True,
1252                 write_back=black.WriteBack.YES,
1253                 mode=replace(DEFAULT_MODE, is_ipynb=True),
1254             )
1255             # __BLACK_STDIN_FILENAME__ should have been stripped
1256             report.done.assert_called_with(expected, black.Changed.YES)
1257
1258     @pytest.mark.incompatible_with_mypyc
1259     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1260         with patch(
1261             "black.format_stdin_to_stdout",
1262             return_value=lambda *args, **kwargs: black.Changed.YES,
1263         ) as fsts:
1264             report = MagicMock()
1265             # Even with an existing file, since we are forcing stdin, black
1266             # should output to stdout and not modify the file inplace
1267             p = THIS_DIR / "data" / "simple_cases" / "collections.py"
1268             # Make sure is_file actually returns True
1269             self.assertTrue(p.is_file())
1270             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1271             expected = Path(p)
1272             black.reformat_one(
1273                 path,
1274                 fast=True,
1275                 write_back=black.WriteBack.YES,
1276                 mode=DEFAULT_MODE,
1277                 report=report,
1278             )
1279             fsts.assert_called_once()
1280             # __BLACK_STDIN_FILENAME__ should have been stripped
1281             report.done.assert_called_with(expected, black.Changed.YES)
1282
1283     def test_reformat_one_with_stdin_empty(self) -> None:
1284         output = io.StringIO()
1285         with patch("io.TextIOWrapper", lambda *args, **kwargs: output):
1286             try:
1287                 black.format_stdin_to_stdout(
1288                     fast=True,
1289                     content="",
1290                     write_back=black.WriteBack.YES,
1291                     mode=DEFAULT_MODE,
1292                 )
1293             except io.UnsupportedOperation:
1294                 pass  # StringIO does not support detach
1295             assert output.getvalue() == ""
1296
1297     def test_invalid_cli_regex(self) -> None:
1298         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1299             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1300
1301     def test_required_version_matches_version(self) -> None:
1302         self.invokeBlack(
1303             ["--required-version", black.__version__, "-c", "0"],
1304             exit_code=0,
1305             ignore_config=True,
1306         )
1307
1308     def test_required_version_matches_partial_version(self) -> None:
1309         self.invokeBlack(
1310             ["--required-version", black.__version__.split(".")[0], "-c", "0"],
1311             exit_code=0,
1312             ignore_config=True,
1313         )
1314
1315     def test_required_version_does_not_match_on_minor_version(self) -> None:
1316         self.invokeBlack(
1317             ["--required-version", black.__version__.split(".")[0] + ".999", "-c", "0"],
1318             exit_code=1,
1319             ignore_config=True,
1320         )
1321
1322     def test_required_version_does_not_match_version(self) -> None:
1323         result = BlackRunner().invoke(
1324             black.main,
1325             ["--required-version", "20.99b", "-c", "0"],
1326         )
1327         self.assertEqual(result.exit_code, 1)
1328         self.assertIn("required version", result.stderr)
1329
1330     def test_preserves_line_endings(self) -> None:
1331         with TemporaryDirectory() as workspace:
1332             test_file = Path(workspace) / "test.py"
1333             for nl in ["\n", "\r\n"]:
1334                 contents = nl.join(["def f(  ):", "    pass"])
1335                 test_file.write_bytes(contents.encode())
1336                 ff(test_file, write_back=black.WriteBack.YES)
1337                 updated_contents: bytes = test_file.read_bytes()
1338                 self.assertIn(nl.encode(), updated_contents)
1339                 if nl == "\n":
1340                     self.assertNotIn(b"\r\n", updated_contents)
1341
1342     def test_preserves_line_endings_via_stdin(self) -> None:
1343         for nl in ["\n", "\r\n"]:
1344             contents = nl.join(["def f(  ):", "    pass"])
1345             runner = BlackRunner()
1346             result = runner.invoke(
1347                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1348             )
1349             self.assertEqual(result.exit_code, 0)
1350             output = result.stdout_bytes
1351             self.assertIn(nl.encode("utf8"), output)
1352             if nl == "\n":
1353                 self.assertNotIn(b"\r\n", output)
1354
1355     def test_normalize_line_endings(self) -> None:
1356         with TemporaryDirectory() as workspace:
1357             test_file = Path(workspace) / "test.py"
1358             for data, expected in (
1359                 (b"c\r\nc\n ", b"c\r\nc\r\n"),
1360                 (b"l\nl\r\n ", b"l\nl\n"),
1361             ):
1362                 test_file.write_bytes(data)
1363                 ff(test_file, write_back=black.WriteBack.YES)
1364                 self.assertEqual(test_file.read_bytes(), expected)
1365
1366     def test_assert_equivalent_different_asts(self) -> None:
1367         with self.assertRaises(AssertionError):
1368             black.assert_equivalent("{}", "None")
1369
1370     def test_shhh_click(self) -> None:
1371         try:
1372             from click import _unicodefun  # type: ignore
1373         except ImportError:
1374             self.skipTest("Incompatible Click version")
1375
1376         if not hasattr(_unicodefun, "_verify_python_env"):
1377             self.skipTest("Incompatible Click version")
1378
1379         # First, let's see if Click is crashing with a preferred ASCII charset.
1380         with patch("locale.getpreferredencoding") as gpe:
1381             gpe.return_value = "ASCII"
1382             with self.assertRaises(RuntimeError):
1383                 _unicodefun._verify_python_env()
1384         # Now, let's silence Click...
1385         black.patch_click()
1386         # ...and confirm it's silent.
1387         with patch("locale.getpreferredencoding") as gpe:
1388             gpe.return_value = "ASCII"
1389             try:
1390                 _unicodefun._verify_python_env()
1391             except RuntimeError as re:
1392                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1393
1394     def test_root_logger_not_used_directly(self) -> None:
1395         def fail(*args: Any, **kwargs: Any) -> None:
1396             self.fail("Record created with root logger")
1397
1398         with patch.multiple(
1399             logging.root,
1400             debug=fail,
1401             info=fail,
1402             warning=fail,
1403             error=fail,
1404             critical=fail,
1405             log=fail,
1406         ):
1407             ff(THIS_DIR / "util.py")
1408
1409     def test_invalid_config_return_code(self) -> None:
1410         tmp_file = Path(black.dump_to_file())
1411         try:
1412             tmp_config = Path(black.dump_to_file())
1413             tmp_config.unlink()
1414             args = ["--config", str(tmp_config), str(tmp_file)]
1415             self.invokeBlack(args, exit_code=2, ignore_config=False)
1416         finally:
1417             tmp_file.unlink()
1418
1419     def test_parse_pyproject_toml(self) -> None:
1420         test_toml_file = THIS_DIR / "test.toml"
1421         config = black.parse_pyproject_toml(str(test_toml_file))
1422         self.assertEqual(config["verbose"], 1)
1423         self.assertEqual(config["check"], "no")
1424         self.assertEqual(config["diff"], "y")
1425         self.assertEqual(config["color"], True)
1426         self.assertEqual(config["line_length"], 79)
1427         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1428         self.assertEqual(config["python_cell_magics"], ["custom1", "custom2"])
1429         self.assertEqual(config["exclude"], r"\.pyi?$")
1430         self.assertEqual(config["include"], r"\.py?$")
1431
1432     def test_read_pyproject_toml(self) -> None:
1433         test_toml_file = THIS_DIR / "test.toml"
1434         fake_ctx = FakeContext()
1435         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1436         config = fake_ctx.default_map
1437         self.assertEqual(config["verbose"], "1")
1438         self.assertEqual(config["check"], "no")
1439         self.assertEqual(config["diff"], "y")
1440         self.assertEqual(config["color"], "True")
1441         self.assertEqual(config["line_length"], "79")
1442         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1443         self.assertEqual(config["exclude"], r"\.pyi?$")
1444         self.assertEqual(config["include"], r"\.py?$")
1445
1446     @pytest.mark.incompatible_with_mypyc
1447     def test_find_project_root(self) -> None:
1448         with TemporaryDirectory() as workspace:
1449             root = Path(workspace)
1450             test_dir = root / "test"
1451             test_dir.mkdir()
1452
1453             src_dir = root / "src"
1454             src_dir.mkdir()
1455
1456             root_pyproject = root / "pyproject.toml"
1457             root_pyproject.touch()
1458             src_pyproject = src_dir / "pyproject.toml"
1459             src_pyproject.touch()
1460             src_python = src_dir / "foo.py"
1461             src_python.touch()
1462
1463             self.assertEqual(
1464                 black.find_project_root((src_dir, test_dir)),
1465                 (root.resolve(), "pyproject.toml"),
1466             )
1467             self.assertEqual(
1468                 black.find_project_root((src_dir,)),
1469                 (src_dir.resolve(), "pyproject.toml"),
1470             )
1471             self.assertEqual(
1472                 black.find_project_root((src_python,)),
1473                 (src_dir.resolve(), "pyproject.toml"),
1474             )
1475
1476             with change_directory(test_dir):
1477                 self.assertEqual(
1478                     black.find_project_root(("-",), stdin_filename="../src/a.py"),
1479                     (src_dir.resolve(), "pyproject.toml"),
1480                 )
1481
1482     @patch(
1483         "black.files.find_user_pyproject_toml",
1484     )
1485     def test_find_pyproject_toml(self, find_user_pyproject_toml: MagicMock) -> None:
1486         find_user_pyproject_toml.side_effect = RuntimeError()
1487
1488         with redirect_stderr(io.StringIO()) as stderr:
1489             result = black.files.find_pyproject_toml(
1490                 path_search_start=(str(Path.cwd().root),)
1491             )
1492
1493         assert result is None
1494         err = stderr.getvalue()
1495         assert "Ignoring user configuration" in err
1496
1497     @patch(
1498         "black.files.find_user_pyproject_toml",
1499         black.files.find_user_pyproject_toml.__wrapped__,
1500     )
1501     def test_find_user_pyproject_toml_linux(self) -> None:
1502         if system() == "Windows":
1503             return
1504
1505         # Test if XDG_CONFIG_HOME is checked
1506         with TemporaryDirectory() as workspace:
1507             tmp_user_config = Path(workspace) / "black"
1508             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1509                 self.assertEqual(
1510                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
1511                 )
1512
1513         # Test fallback for XDG_CONFIG_HOME
1514         with patch.dict("os.environ"):
1515             os.environ.pop("XDG_CONFIG_HOME", None)
1516             fallback_user_config = Path("~/.config").expanduser() / "black"
1517             self.assertEqual(
1518                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
1519             )
1520
1521     def test_find_user_pyproject_toml_windows(self) -> None:
1522         if system() != "Windows":
1523             return
1524
1525         user_config_path = Path.home() / ".black"
1526         self.assertEqual(
1527             black.files.find_user_pyproject_toml(), user_config_path.resolve()
1528         )
1529
1530     def test_bpo_33660_workaround(self) -> None:
1531         if system() == "Windows":
1532             return
1533
1534         # https://bugs.python.org/issue33660
1535         root = Path("/")
1536         with change_directory(root):
1537             path = Path("workspace") / "project"
1538             report = black.Report(verbose=True)
1539             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
1540             self.assertEqual(normalized_path, "workspace/project")
1541
1542     def test_normalize_path_ignore_windows_junctions_outside_of_root(self) -> None:
1543         if system() != "Windows":
1544             return
1545
1546         with TemporaryDirectory() as workspace:
1547             root = Path(workspace)
1548             junction_dir = root / "junction"
1549             junction_target_outside_of_root = root / ".."
1550             os.system(f"mklink /J {junction_dir} {junction_target_outside_of_root}")
1551
1552             report = black.Report(verbose=True)
1553             normalized_path = black.normalize_path_maybe_ignore(
1554                 junction_dir, root, report
1555             )
1556             # Manually delete for Python < 3.8
1557             os.system(f"rmdir {junction_dir}")
1558
1559             self.assertEqual(normalized_path, None)
1560
1561     def test_newline_comment_interaction(self) -> None:
1562         source = "class A:\\\r\n# type: ignore\n pass\n"
1563         output = black.format_str(source, mode=DEFAULT_MODE)
1564         black.assert_stable(source, output, mode=DEFAULT_MODE)
1565
1566     def test_bpo_2142_workaround(self) -> None:
1567         # https://bugs.python.org/issue2142
1568
1569         source, _ = read_data("miscellaneous", "missing_final_newline")
1570         # read_data adds a trailing newline
1571         source = source.rstrip()
1572         expected, _ = read_data("miscellaneous", "missing_final_newline.diff")
1573         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
1574         diff_header = re.compile(
1575             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
1576             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
1577         )
1578         try:
1579             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
1580             self.assertEqual(result.exit_code, 0)
1581         finally:
1582             os.unlink(tmp_file)
1583         actual = result.output
1584         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
1585         self.assertEqual(actual, expected)
1586
1587     @staticmethod
1588     def compare_results(
1589         result: click.testing.Result, expected_value: str, expected_exit_code: int
1590     ) -> None:
1591         """Helper method to test the value and exit code of a click Result."""
1592         assert (
1593             result.output == expected_value
1594         ), "The output did not match the expected value."
1595         assert result.exit_code == expected_exit_code, "The exit code is incorrect."
1596
1597     def test_code_option(self) -> None:
1598         """Test the code option with no changes."""
1599         code = 'print("Hello world")\n'
1600         args = ["--code", code]
1601         result = CliRunner().invoke(black.main, args)
1602
1603         self.compare_results(result, code, 0)
1604
1605     def test_code_option_changed(self) -> None:
1606         """Test the code option when changes are required."""
1607         code = "print('hello world')"
1608         formatted = black.format_str(code, mode=DEFAULT_MODE)
1609
1610         args = ["--code", code]
1611         result = CliRunner().invoke(black.main, args)
1612
1613         self.compare_results(result, formatted, 0)
1614
1615     def test_code_option_check(self) -> None:
1616         """Test the code option when check is passed."""
1617         args = ["--check", "--code", 'print("Hello world")\n']
1618         result = CliRunner().invoke(black.main, args)
1619         self.compare_results(result, "", 0)
1620
1621     def test_code_option_check_changed(self) -> None:
1622         """Test the code option when changes are required, and check is passed."""
1623         args = ["--check", "--code", "print('hello world')"]
1624         result = CliRunner().invoke(black.main, args)
1625         self.compare_results(result, "", 1)
1626
1627     def test_code_option_diff(self) -> None:
1628         """Test the code option when diff is passed."""
1629         code = "print('hello world')"
1630         formatted = black.format_str(code, mode=DEFAULT_MODE)
1631         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1632
1633         args = ["--diff", "--code", code]
1634         result = CliRunner().invoke(black.main, args)
1635
1636         # Remove time from diff
1637         output = DIFF_TIME.sub("", result.output)
1638
1639         assert output == result_diff, "The output did not match the expected value."
1640         assert result.exit_code == 0, "The exit code is incorrect."
1641
1642     def test_code_option_color_diff(self) -> None:
1643         """Test the code option when color and diff are passed."""
1644         code = "print('hello world')"
1645         formatted = black.format_str(code, mode=DEFAULT_MODE)
1646
1647         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1648         result_diff = color_diff(result_diff)
1649
1650         args = ["--diff", "--color", "--code", code]
1651         result = CliRunner().invoke(black.main, args)
1652
1653         # Remove time from diff
1654         output = DIFF_TIME.sub("", result.output)
1655
1656         assert output == result_diff, "The output did not match the expected value."
1657         assert result.exit_code == 0, "The exit code is incorrect."
1658
1659     @pytest.mark.incompatible_with_mypyc
1660     def test_code_option_safe(self) -> None:
1661         """Test that the code option throws an error when the sanity checks fail."""
1662         # Patch black.assert_equivalent to ensure the sanity checks fail
1663         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1664             code = 'print("Hello world")'
1665             error_msg = f"{code}\nerror: cannot format <string>: \n"
1666
1667             args = ["--safe", "--code", code]
1668             result = CliRunner().invoke(black.main, args)
1669
1670             self.compare_results(result, error_msg, 123)
1671
1672     def test_code_option_fast(self) -> None:
1673         """Test that the code option ignores errors when the sanity checks fail."""
1674         # Patch black.assert_equivalent to ensure the sanity checks fail
1675         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1676             code = 'print("Hello world")'
1677             formatted = black.format_str(code, mode=DEFAULT_MODE)
1678
1679             args = ["--fast", "--code", code]
1680             result = CliRunner().invoke(black.main, args)
1681
1682             self.compare_results(result, formatted, 0)
1683
1684     @pytest.mark.incompatible_with_mypyc
1685     def test_code_option_config(self) -> None:
1686         """
1687         Test that the code option finds the pyproject.toml in the current directory.
1688         """
1689         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1690             args = ["--code", "print"]
1691             # This is the only directory known to contain a pyproject.toml
1692             with change_directory(PROJECT_ROOT):
1693                 CliRunner().invoke(black.main, args)
1694                 pyproject_path = Path(Path.cwd(), "pyproject.toml").resolve()
1695
1696             assert (
1697                 len(parse.mock_calls) >= 1
1698             ), "Expected config parse to be called with the current directory."
1699
1700             _, call_args, _ = parse.mock_calls[0]
1701             assert (
1702                 call_args[0].lower() == str(pyproject_path).lower()
1703             ), "Incorrect config loaded."
1704
1705     @pytest.mark.incompatible_with_mypyc
1706     def test_code_option_parent_config(self) -> None:
1707         """
1708         Test that the code option finds the pyproject.toml in the parent directory.
1709         """
1710         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1711             with change_directory(THIS_DIR):
1712                 args = ["--code", "print"]
1713                 CliRunner().invoke(black.main, args)
1714
1715                 pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
1716                 assert (
1717                     len(parse.mock_calls) >= 1
1718                 ), "Expected config parse to be called with the current directory."
1719
1720                 _, call_args, _ = parse.mock_calls[0]
1721                 assert (
1722                     call_args[0].lower() == str(pyproject_path).lower()
1723                 ), "Incorrect config loaded."
1724
1725     def test_for_handled_unexpected_eof_error(self) -> None:
1726         """
1727         Test that an unexpected EOF SyntaxError is nicely presented.
1728         """
1729         with pytest.raises(black.parsing.InvalidInput) as exc_info:
1730             black.lib2to3_parse("print(", {})
1731
1732         exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")
1733
1734     def test_equivalency_ast_parse_failure_includes_error(self) -> None:
1735         with pytest.raises(AssertionError) as err:
1736             black.assert_equivalent("a«»a  = 1", "a«»a  = 1")
1737
1738         err.match("--safe")
1739         # Unfortunately the SyntaxError message has changed in newer versions so we
1740         # can't match it directly.
1741         err.match("invalid character")
1742         err.match(r"\(<unknown>, line 1\)")
1743
1744
class TestCaching:
    """Tests for reading, writing and honouring Black's per-mode file cache."""

    def test_get_cache_dir(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """BLACK_CACHE_DIR, when set, overrides the user cache directory."""
        # Create multiple cache directories
        workspace1 = tmp_path / "ws1"
        workspace1.mkdir()
        workspace2 = tmp_path / "ws2"
        workspace2.mkdir()

        # Force user_cache_dir to use the temporary directory for easier assertions
        patch_user_cache_dir = patch(
            target="black.cache.user_cache_dir",
            autospec=True,
            return_value=str(workspace1),
        )

        # If BLACK_CACHE_DIR is not set, use user_cache_dir
        monkeypatch.delenv("BLACK_CACHE_DIR", raising=False)
        with patch_user_cache_dir:
            assert get_cache_dir() == workspace1

        # If it is set, use the path provided in the env var.
        monkeypatch.setenv("BLACK_CACHE_DIR", str(workspace2))
        assert get_cache_dir() == workspace2

    def test_cache_broken_file(self) -> None:
        """A corrupt cache file reads as empty and is rebuilt on the next run."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            cache_file = get_cache_file(mode)
            cache_file.write_text("this is not a pickle")
            assert black.read_cache(mode) == {}
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            invokeBlack([str(src)])
            cache = black.read_cache(mode)
            assert str(src) in cache

    def test_cache_single_file_already_cached(self) -> None:
        """A file already recorded in the cache is left untouched."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            black.write_cache({}, [src], mode)
            invokeBlack([str(src)])
            assert src.read_text() == "print('hello')"

    @event_loop()
    def test_cache_multiple_files(self) -> None:
        """Only uncached files are reformatted; afterwards both are cached."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace, patch(
            "concurrent.futures.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            one = (workspace / "one.py").resolve()
            one.write_text("print('hello')")
            two = (workspace / "two.py").resolve()
            two.write_text("print('hello')")
            black.write_cache({}, [one], mode)
            invokeBlack([str(workspace)])
            assert one.read_text() == "print('hello')"
            assert two.read_text() == 'print("hello")\n'
            cache = black.read_cache(mode)
            assert str(one) in cache
            assert str(two) in cache

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    def test_no_cache_when_writeback_diff(self, color: bool) -> None:
        """--diff (with or without --color) neither reads nor writes the cache."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            with patch("black.read_cache") as read_cache, patch(
                "black.write_cache"
            ) as write_cache:
                cmd = [str(src), "--diff"]
                if color:
                    cmd.append("--color")
                invokeBlack(cmd)
                cache_file = get_cache_file(mode)
                assert not cache_file.exists()
                write_cache.assert_not_called()
                read_cache.assert_not_called()

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    @event_loop()
    def test_output_locking_when_writeback_diff(self, color: bool) -> None:
        """Diff output for several files is produced with a Manager available."""
        with cache_dir() as workspace:
            for tag in range(4):
                src = (workspace / f"test{tag}.py").resolve()
                src.write_text("print('hello')")
            with patch(
                "black.concurrency.Manager", wraps=multiprocessing.Manager
            ) as mgr:
                cmd = ["--diff", str(workspace)]
                if color:
                    cmd.append("--color")
                invokeBlack(cmd, exit_code=0)
                # this isn't quite doing what we want, but if it _isn't_
                # called then we cannot be using the lock it provides
                mgr.assert_called()

    def test_no_cache_when_stdin(self) -> None:
        """Formatting stdin does not create a cache file."""
        mode = DEFAULT_MODE
        with cache_dir():
            result = CliRunner().invoke(
                black.main, ["-"], input=BytesIO(b"print('hello')")
            )
            assert not result.exit_code
            cache_file = get_cache_file(mode)
            assert not cache_file.exists()

    def test_read_cache_no_cachefile(self) -> None:
        """A missing cache file reads as an empty cache."""
        mode = DEFAULT_MODE
        with cache_dir():
            assert black.read_cache(mode) == {}

    def test_write_cache_read_cache(self) -> None:
        """A written entry reads back with the file's current cache info."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.touch()
            black.write_cache({}, [src], mode)
            cache = black.read_cache(mode)
            assert str(src) in cache
            assert cache[str(src)] == black.get_cache_info(src)

    def test_filter_cached(self) -> None:
        """filter_cached returns (todo, done); stale or missing entries are todo."""
        with TemporaryDirectory() as workspace:
            path = Path(workspace)
            uncached = (path / "uncached").resolve()
            cached = (path / "cached").resolve()
            cached_but_changed = (path / "changed").resolve()
            uncached.touch()
            cached.touch()
            cached_but_changed.touch()
            cache = {
                str(cached): black.get_cache_info(cached),
                str(cached_but_changed): (0.0, 0),
            }
            todo, done = black.cache.filter_cached(
                cache, {uncached, cached, cached_but_changed}
            )
            assert todo == {uncached, cached_but_changed}
            assert done == {cached}

    def test_write_cache_creates_directory_if_needed(self) -> None:
        """write_cache creates the cache directory when it does not exist."""
        mode = DEFAULT_MODE
        with cache_dir(exists=False) as workspace:
            assert not workspace.exists()
            black.write_cache({}, [], mode)
            assert workspace.exists()

    @event_loop()
    def test_failed_formatting_does_not_get_cached(self) -> None:
        """Files that fail to format are left out of the cache."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace, patch(
            "concurrent.futures.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            failing = (workspace / "failing.py").resolve()
            failing.write_text("not actually python")
            clean = (workspace / "clean.py").resolve()
            clean.write_text('print("hello")\n')
            invokeBlack([str(workspace)], exit_code=123)
            cache = black.read_cache(mode)
            assert str(failing) not in cache
            assert str(clean) in cache

    def test_write_cache_write_fail(self) -> None:
        """write_cache must not raise when ``Path.open`` fails with OSError.

        There is no explicit assertion: the test passes as long as the call
        returns without propagating the error.
        """
        mode = DEFAULT_MODE
        with cache_dir(), patch.object(Path, "open") as mock:
            mock.side_effect = OSError
            black.write_cache({}, [], mode)

    def test_read_cache_line_lengths(self) -> None:
        """Cache entries are kept per mode; another line length does not see them."""
        mode = DEFAULT_MODE
        short_mode = replace(DEFAULT_MODE, line_length=1)
        with cache_dir() as workspace:
            path = (workspace / "file.py").resolve()
            path.touch()
            black.write_cache({}, [path], mode)
            one = black.read_cache(mode)
            assert str(path) in one
            two = black.read_cache(short_mode)
            assert str(path) not in two
1940
def assert_collected_sources(
    src: Sequence[Union[str, Path]],
    expected: Sequence[Union[str, Path]],
    *,
    ctx: Optional[FakeContext] = None,
    exclude: Optional[str] = None,
    include: Optional[str] = None,
    extend_exclude: Optional[str] = None,
    force_exclude: Optional[str] = None,
    stdin_filename: Optional[str] = None,
) -> None:
    """Assert that ``black.get_sources`` collects exactly ``expected`` from ``src``.

    ``exclude``, ``extend_exclude`` and ``force_exclude`` are compiled when given
    and passed as ``None`` otherwise. ``include`` falls back to
    ``DEFAULT_INCLUDE``. The comparison ignores ordering.
    """

    def _compile_or_none(pattern: Optional[str]) -> Any:
        # Leave absent patterns as None rather than compiling anything.
        return None if pattern is None else compile_pattern(pattern)

    sources = tuple(str(Path(item)) for item in src)
    wanted = [Path(item) for item in expected]
    exclude_re = _compile_or_none(exclude)
    include_re = DEFAULT_INCLUDE if include is None else compile_pattern(include)
    extend_exclude_re = _compile_or_none(extend_exclude)
    force_exclude_re = _compile_or_none(force_exclude)

    collected = black.get_sources(
        ctx=ctx or FakeContext(),
        src=sources,
        quiet=False,
        verbose=False,
        include=include_re,
        exclude=exclude_re,
        extend_exclude=extend_exclude_re,
        force_exclude=force_exclude_re,
        report=black.Report(),
        stdin_filename=stdin_filename,
    )
    assert sorted(collected) == sorted(wanted)
1973
1974
1975 class TestFileCollection:
1976     def test_include_exclude(self) -> None:
1977         path = THIS_DIR / "data" / "include_exclude_tests"
1978         src = [path]
1979         expected = [
1980             Path(path / "b/dont_exclude/a.py"),
1981             Path(path / "b/dont_exclude/a.pyi"),
1982         ]
1983         assert_collected_sources(
1984             src,
1985             expected,
1986             include=r"\.pyi?$",
1987             exclude=r"/exclude/|/\.definitely_exclude/",
1988         )
1989
1990     def test_gitignore_used_as_default(self) -> None:
1991         base = Path(DATA_DIR / "include_exclude_tests")
1992         expected = [
1993             base / "b/.definitely_exclude/a.py",
1994             base / "b/.definitely_exclude/a.pyi",
1995         ]
1996         src = [base / "b/"]
1997         ctx = FakeContext()
1998         ctx.obj["root"] = base
1999         assert_collected_sources(src, expected, ctx=ctx, extend_exclude=r"/exclude/")
2000
2001     def test_gitignore_used_on_multiple_sources(self) -> None:
2002         root = Path(DATA_DIR / "gitignore_used_on_multiple_sources")
2003         expected = [
2004             root / "dir1" / "b.py",
2005             root / "dir2" / "b.py",
2006         ]
2007         ctx = FakeContext()
2008         ctx.obj["root"] = root
2009         src = [root / "dir1", root / "dir2"]
2010         assert_collected_sources(src, expected, ctx=ctx)
2011
2012     @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
2013     def test_exclude_for_issue_1572(self) -> None:
2014         # Exclude shouldn't touch files that were explicitly given to Black through the
2015         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
2016         # https://github.com/psf/black/issues/1572
2017         path = DATA_DIR / "include_exclude_tests"
2018         src = [path / "b/exclude/a.py"]
2019         expected = [path / "b/exclude/a.py"]
2020         assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")
2021
2022     def test_gitignore_exclude(self) -> None:
2023         path = THIS_DIR / "data" / "include_exclude_tests"
2024         include = re.compile(r"\.pyi?$")
2025         exclude = re.compile(r"")
2026         report = black.Report()
2027         gitignore = PathSpec.from_lines(
2028             "gitwildmatch", ["exclude/", ".definitely_exclude"]
2029         )
2030         sources: List[Path] = []
2031         expected = [
2032             Path(path / "b/dont_exclude/a.py"),
2033             Path(path / "b/dont_exclude/a.pyi"),
2034         ]
2035         this_abs = THIS_DIR.resolve()
2036         sources.extend(
2037             black.gen_python_files(
2038                 path.iterdir(),
2039                 this_abs,
2040                 include,
2041                 exclude,
2042                 None,
2043                 None,
2044                 report,
2045                 {path: gitignore},
2046                 verbose=False,
2047                 quiet=False,
2048             )
2049         )
2050         assert sorted(expected) == sorted(sources)
2051
2052     def test_nested_gitignore(self) -> None:
2053         path = Path(THIS_DIR / "data" / "nested_gitignore_tests")
2054         include = re.compile(r"\.pyi?$")
2055         exclude = re.compile(r"")
2056         root_gitignore = black.files.get_gitignore(path)
2057         report = black.Report()
2058         expected: List[Path] = [
2059             Path(path / "x.py"),
2060             Path(path / "root/b.py"),
2061             Path(path / "root/c.py"),
2062             Path(path / "root/child/c.py"),
2063         ]
2064         this_abs = THIS_DIR.resolve()
2065         sources = list(
2066             black.gen_python_files(
2067                 path.iterdir(),
2068                 this_abs,
2069                 include,
2070                 exclude,
2071                 None,
2072                 None,
2073                 report,
2074                 {path: root_gitignore},
2075                 verbose=False,
2076                 quiet=False,
2077             )
2078         )
2079         assert sorted(expected) == sorted(sources)
2080
2081     def test_nested_gitignore_directly_in_source_directory(self) -> None:
2082         # https://github.com/psf/black/issues/2598
2083         path = Path(DATA_DIR / "nested_gitignore_tests")
2084         src = Path(path / "root" / "child")
2085         expected = [src / "a.py", src / "c.py"]
2086         assert_collected_sources([src], expected)
2087
2088     def test_invalid_gitignore(self) -> None:
2089         path = THIS_DIR / "data" / "invalid_gitignore_tests"
2090         empty_config = path / "pyproject.toml"
2091         result = BlackRunner().invoke(
2092             black.main, ["--verbose", "--config", str(empty_config), str(path)]
2093         )
2094         assert result.exit_code == 1
2095         assert result.stderr_bytes is not None
2096
2097         gitignore = path / ".gitignore"
2098         assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()
2099
2100     def test_invalid_nested_gitignore(self) -> None:
2101         path = THIS_DIR / "data" / "invalid_nested_gitignore_tests"
2102         empty_config = path / "pyproject.toml"
2103         result = BlackRunner().invoke(
2104             black.main, ["--verbose", "--config", str(empty_config), str(path)]
2105         )
2106         assert result.exit_code == 1
2107         assert result.stderr_bytes is not None
2108
2109         gitignore = path / "a" / ".gitignore"
2110         assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()
2111
2112     def test_gitignore_that_ignores_subfolders(self) -> None:
2113         # If gitignore with */* is in root
2114         root = Path(DATA_DIR / "ignore_subfolders_gitignore_tests" / "subdir")
2115         expected = [root / "b.py"]
2116         ctx = FakeContext()
2117         ctx.obj["root"] = root
2118         assert_collected_sources([root], expected, ctx=ctx)
2119
2120         # If .gitignore with */* is nested
2121         root = Path(DATA_DIR / "ignore_subfolders_gitignore_tests")
2122         expected = [
2123             root / "a.py",
2124             root / "subdir" / "b.py",
2125         ]
2126         ctx = FakeContext()
2127         ctx.obj["root"] = root
2128         assert_collected_sources([root], expected, ctx=ctx)
2129
2130         # If command is executed from outer dir
2131         root = Path(DATA_DIR / "ignore_subfolders_gitignore_tests")
2132         target = root / "subdir"
2133         expected = [target / "b.py"]
2134         ctx = FakeContext()
2135         ctx.obj["root"] = root
2136         assert_collected_sources([target], expected, ctx=ctx)
2137
2138     def test_empty_include(self) -> None:
2139         path = DATA_DIR / "include_exclude_tests"
2140         src = [path]
2141         expected = [
2142             Path(path / "b/exclude/a.pie"),
2143             Path(path / "b/exclude/a.py"),
2144             Path(path / "b/exclude/a.pyi"),
2145             Path(path / "b/dont_exclude/a.pie"),
2146             Path(path / "b/dont_exclude/a.py"),
2147             Path(path / "b/dont_exclude/a.pyi"),
2148             Path(path / "b/.definitely_exclude/a.pie"),
2149             Path(path / "b/.definitely_exclude/a.py"),
2150             Path(path / "b/.definitely_exclude/a.pyi"),
2151             Path(path / ".gitignore"),
2152             Path(path / "pyproject.toml"),
2153         ]
2154         # Setting exclude explicitly to an empty string to block .gitignore usage.
2155         assert_collected_sources(src, expected, include="", exclude="")
2156
2157     def test_extend_exclude(self) -> None:
2158         path = DATA_DIR / "include_exclude_tests"
2159         src = [path]
2160         expected = [
2161             Path(path / "b/exclude/a.py"),
2162             Path(path / "b/dont_exclude/a.py"),
2163         ]
2164         assert_collected_sources(
2165             src, expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude"
2166         )
2167
2168     @pytest.mark.incompatible_with_mypyc
2169     def test_symlink_out_of_root_directory(self) -> None:
2170         path = MagicMock()
2171         root = THIS_DIR.resolve()
2172         child = MagicMock()
2173         include = re.compile(black.DEFAULT_INCLUDES)
2174         exclude = re.compile(black.DEFAULT_EXCLUDES)
2175         report = black.Report()
2176         gitignore = PathSpec.from_lines("gitwildmatch", [])
2177         # `child` should behave like a symlink which resolved path is clearly
2178         # outside of the `root` directory.
2179         path.iterdir.return_value = [child]
2180         child.resolve.return_value = Path("/a/b/c")
2181         child.as_posix.return_value = "/a/b/c"
2182         try:
2183             list(
2184                 black.gen_python_files(
2185                     path.iterdir(),
2186                     root,
2187                     include,
2188                     exclude,
2189                     None,
2190                     None,
2191                     report,
2192                     {path: gitignore},
2193                     verbose=False,
2194                     quiet=False,
2195                 )
2196             )
2197         except ValueError as ve:
2198             pytest.fail(f"`get_python_files_in_dir()` failed: {ve}")
2199         path.iterdir.assert_called_once()
2200         child.resolve.assert_called_once()
2201
2202     @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
2203     def test_get_sources_with_stdin(self) -> None:
2204         src = ["-"]
2205         expected = ["-"]
2206         assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")
2207
2208     @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
2209     def test_get_sources_with_stdin_filename(self) -> None:
2210         src = ["-"]
2211         stdin_filename = str(THIS_DIR / "data/collections.py")
2212         expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
2213         assert_collected_sources(
2214             src,
2215             expected,
2216             exclude=r"/exclude/a\.py",
2217             stdin_filename=stdin_filename,
2218         )
2219
2220     @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
2221     def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
2222         # Exclude shouldn't exclude stdin_filename since it is mimicking the
2223         # file being passed directly. This is the same as
2224         # test_exclude_for_issue_1572
2225         path = DATA_DIR / "include_exclude_tests"
2226         src = ["-"]
2227         stdin_filename = str(path / "b/exclude/a.py")
2228         expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
2229         assert_collected_sources(
2230             src,
2231             expected,
2232             exclude=r"/exclude/|a\.py",
2233             stdin_filename=stdin_filename,
2234         )
2235
2236     @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
2237     def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
2238         # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
2239         # file being passed directly. This is the same as
2240         # test_exclude_for_issue_1572
2241         src = ["-"]
2242         path = THIS_DIR / "data" / "include_exclude_tests"
2243         stdin_filename = str(path / "b/exclude/a.py")
2244         expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
2245         assert_collected_sources(
2246             src,
2247             expected,
2248             extend_exclude=r"/exclude/|a\.py",
2249             stdin_filename=stdin_filename,
2250         )
2251
2252     @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
2253     def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
2254         # Force exclude should exclude the file when passing it through
2255         # stdin_filename
2256         path = THIS_DIR / "data" / "include_exclude_tests"
2257         stdin_filename = str(path / "b/exclude/a.py")
2258         assert_collected_sources(
2259             src=["-"],
2260             expected=[],
2261             force_exclude=r"/exclude/|a\.py",
2262             stdin_filename=stdin_filename,
2263         )
2264
2265
# Cache the text of black/__init__.py; `tracefunc` looks up the signature line
# of each traced function in it.
try:
    with open(black.__file__, "r", encoding="utf-8") as _bf:
        black_source_lines = _bf.readlines()
except UnicodeDecodeError:
    # Tolerated only for a compiled black, whose `__file__` is presumably a
    # binary extension module rather than UTF-8 source.
    if not black.COMPILED:
        raise
    # Always bind the name so later references don't hit a NameError; an
    # empty list means "source unavailable".
    black_source_lines = []
2272
2273
def tracefunc(
    frame: types.FrameType, event: str, arg: Any
) -> Callable[[types.FrameType, str, Any], Any]:
    """Show function calls from `black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging. For each
    "call" event whose code lives in `black/__init__.py`, print the line
    number and the `def` line of the called function (decorators skipped),
    indented by the current stack depth.

    Args:
        frame: the frame that triggered the trace event.
        event: the trace event name; only "call" events are reported.
        arg: event-specific argument (unused).

    Returns:
        This function, so tracing keeps going.
    """
    if event != "call":
        return tracefunc

    # Normalize separators so the check below also matches on Windows.
    filename = frame.f_code.co_filename.replace(os.sep, "/")
    # Filter first: `black_source_lines` only describes black/__init__.py, so
    # indexing it with another file's line numbers is meaningless and can
    # raise IndexError. This also skips the costly `inspect.stack()` call for
    # every unrelated function.
    if "black/__init__.py" not in filename:
        return tracefunc

    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1  # `black_source_lines` is 0-indexed.
    # Step past decorator lines to reach the `def` line itself.
    while (
        0 <= func_sig_lineno < len(black_source_lines)
        and black_source_lines[func_sig_lineno].strip().startswith("@")
    ):
        func_sig_lineno += 1
    # Guard against line numbers outside the cached source (e.g. the cache is
    # empty or stale) instead of crashing inside the tracer.
    if not 0 <= func_sig_lineno < len(black_source_lines):
        return tracefunc
    funcname = black_source_lines[func_sig_lineno].strip()

    # Two spaces of indent per stack frame. NOTE(review): 19 presumably
    # discounts the test-runner frames below the code being debugged; confirm.
    stack = (len(inspect.stack()) - 19) * 2
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc