git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Include underlying error when AST safety check parsing fails (#2693)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import inspect
5 import io
6 import logging
7 import multiprocessing
8 import os
9 import sys
10 import types
11 import unittest
12 from concurrent.futures import ThreadPoolExecutor
13 from contextlib import contextmanager
14 from dataclasses import replace
15 from io import BytesIO
16 from pathlib import Path
17 from platform import system
18 from tempfile import TemporaryDirectory
19 from typing import (
20     Any,
21     Callable,
22     Dict,
23     Iterator,
24     List,
25     Optional,
26     Sequence,
27     TypeVar,
28     Union,
29 )
30 from unittest.mock import MagicMock, patch
31
32 import click
33 import pytest
34 import re
35 from click import unstyle
36 from click.testing import CliRunner
37 from pathspec import PathSpec
38
39 import black
40 import black.files
41 from black import Feature, TargetVersion
42 from black import re_compile_maybe_verbose as compile_pattern
43 from black.cache import get_cache_file
44 from black.debug import DebugVisitor
45 from black.output import color_diff, diff
46 from black.report import Report
47
48 # Import other test classes
49 from tests.util import (
50     DATA_DIR,
51     DEFAULT_MODE,
52     DETERMINISTIC_HEADER,
53     PROJECT_ROOT,
54     PY36_VERSIONS,
55     THIS_DIR,
56     BlackBaseTestCase,
57     assert_format,
58     change_directory,
59     dump_to_stderr,
60     ff,
61     fs,
62     read_data,
63 )
64
# Path of this test module.
THIS_FILE = Path(__file__)
# One ``--target-version=<name>`` CLI flag per entry in PY36_VERSIONS.
PY36_ARGS = ["--target-version=" + target.name.lower() for target in PY36_VERSIONS]
# Black's default exclude/include file patterns, compiled once for reuse.
DEFAULT_EXCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_EXCLUDES)
DEFAULT_INCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_INCLUDES)
# Generic type variables.
T, R = TypeVar("T"), TypeVar("R")

# Match the time output in a diff, but nothing else
DIFF_TIME = re.compile(r"\t[\d\-:+\. ]+")
74
75
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Redirect ``black.cache.CACHE_DIR`` to a temporary directory.

    With ``exists=True`` the yielded path is the temporary directory itself.
    With ``exists=False`` it is a ``new`` subdirectory of it that is not
    created here. The temporary directory is removed when the block exits.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
84
85
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a brand-new asyncio event loop as the current loop for the block.

    The loop is created from the current event loop policy. On exit it is
    first uninstalled (the current loop is reset to ``None``) and then closed,
    so no closed loop is left behind as the "current" loop for code that runs
    after the block.
    """
    policy = asyncio.get_event_loop_policy()
    loop = policy.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        yield
    finally:
        # Uninstall before closing: leaving a closed loop registered would hand
        # it to any later ``get_event_loop()`` caller, which then fails with
        # "Event loop is closed".
        asyncio.set_event_loop(None)
        loop.close()
96
97
class FakeContext(click.Context):
    """Stand-in ``click.Context`` for calling functions that require one.

    ``click.Context.__init__`` is deliberately not invoked; the only state
    provided is an empty ``default_map``.
    """

    def __init__(self) -> None:
        # Bypass click's real initialisation; only this attribute is set.
        no_defaults: Dict[str, Any] = {}
        self.default_map = no_defaults
103
104
class FakeParameter(click.Parameter):
    """Stand-in ``click.Parameter`` for calling functions that require one."""

    def __init__(self) -> None:
        """Skip ``click.Parameter.__init__`` entirely; no state is set up."""
110
111
class BlackRunner(CliRunner):
    """``CliRunner`` that captures Black's stdout and stderr as separate streams."""

    def __init__(self) -> None:
        # ``mix_stderr=False`` keeps stderr out of the combined output so that
        # tests can inspect ``stdout_bytes`` and ``stderr_bytes`` independently.
        super().__init__(mix_stderr=False)
117
118
def invokeBlack(
    args: List[str], exit_code: int = 0, ignore_config: bool = True
) -> None:
    """Run Black's CLI in-process with ``args`` and assert its exit status.

    When ``ignore_config`` is true, ``--verbose`` and ``--config <THIS_DIR>/
    empty.toml`` are prepended to ``args``. Exceptions raised by Black are not
    caught by the runner. If the exit code differs from ``exit_code``, the
    assertion message lists the arguments, both decoded output streams and
    the result's exception.
    """
    if ignore_config:
        args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
    outcome = BlackRunner().invoke(black.main, args, catch_exceptions=False)
    stdout, stderr = outcome.stdout_bytes, outcome.stderr_bytes
    assert stdout is not None
    assert stderr is not None
    details = "\n".join(
        [
            f"Failed with args: {args}",
            f"stdout: {stdout.decode()!r}",
            f"stderr: {stderr.decode()!r}",
            f"exception: {outcome.exception}",
        ]
    )
    assert outcome.exit_code == exit_code, details
135
136
137 class BlackTestCase(BlackBaseTestCase):
138     invokeBlack = staticmethod(invokeBlack)
139
140     def test_empty_ff(self) -> None:
141         expected = ""
142         tmp_file = Path(black.dump_to_file())
143         try:
144             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
145             with open(tmp_file, encoding="utf8") as f:
146                 actual = f.read()
147         finally:
148             os.unlink(tmp_file)
149         self.assertFormatEqual(expected, actual)
150
151     def test_piping(self) -> None:
152         source, expected = read_data("src/black/__init__", data=False)
153         result = BlackRunner().invoke(
154             black.main,
155             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
156             input=BytesIO(source.encode("utf8")),
157         )
158         self.assertEqual(result.exit_code, 0)
159         self.assertFormatEqual(expected, result.output)
160         if source != result.output:
161             black.assert_equivalent(source, result.output)
162             black.assert_stable(source, result.output, DEFAULT_MODE)
163
164     def test_piping_diff(self) -> None:
165         diff_header = re.compile(
166             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
167             r"\+\d\d\d\d"
168         )
169         source, _ = read_data("expression.py")
170         expected, _ = read_data("expression.diff")
171         config = THIS_DIR / "data" / "empty_pyproject.toml"
172         args = [
173             "-",
174             "--fast",
175             f"--line-length={black.DEFAULT_LINE_LENGTH}",
176             "--diff",
177             f"--config={config}",
178         ]
179         result = BlackRunner().invoke(
180             black.main, args, input=BytesIO(source.encode("utf8"))
181         )
182         self.assertEqual(result.exit_code, 0)
183         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
184         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
185         self.assertEqual(expected, actual)
186
187     def test_piping_diff_with_color(self) -> None:
188         source, _ = read_data("expression.py")
189         config = THIS_DIR / "data" / "empty_pyproject.toml"
190         args = [
191             "-",
192             "--fast",
193             f"--line-length={black.DEFAULT_LINE_LENGTH}",
194             "--diff",
195             "--color",
196             f"--config={config}",
197         ]
198         result = BlackRunner().invoke(
199             black.main, args, input=BytesIO(source.encode("utf8"))
200         )
201         actual = result.output
202         # Again, the contents are checked in a different test, so only look for colors.
203         self.assertIn("\033[1m", actual)
204         self.assertIn("\033[36m", actual)
205         self.assertIn("\033[32m", actual)
206         self.assertIn("\033[31m", actual)
207         self.assertIn("\033[0m", actual)
208
209     @patch("black.dump_to_file", dump_to_stderr)
210     def _test_wip(self) -> None:
211         source, expected = read_data("wip")
212         sys.settrace(tracefunc)
213         mode = replace(
214             DEFAULT_MODE,
215             experimental_string_processing=False,
216             target_versions={black.TargetVersion.PY38},
217         )
218         actual = fs(source, mode=mode)
219         sys.settrace(None)
220         self.assertFormatEqual(expected, actual)
221         black.assert_equivalent(source, actual)
222         black.assert_stable(source, actual, black.FileMode())
223
224     @unittest.expectedFailure
225     @patch("black.dump_to_file", dump_to_stderr)
226     def test_trailing_comma_optional_parens_stability1(self) -> None:
227         source, _expected = read_data("trailing_comma_optional_parens1")
228         actual = fs(source)
229         black.assert_stable(source, actual, DEFAULT_MODE)
230
231     @unittest.expectedFailure
232     @patch("black.dump_to_file", dump_to_stderr)
233     def test_trailing_comma_optional_parens_stability2(self) -> None:
234         source, _expected = read_data("trailing_comma_optional_parens2")
235         actual = fs(source)
236         black.assert_stable(source, actual, DEFAULT_MODE)
237
238     @unittest.expectedFailure
239     @patch("black.dump_to_file", dump_to_stderr)
240     def test_trailing_comma_optional_parens_stability3(self) -> None:
241         source, _expected = read_data("trailing_comma_optional_parens3")
242         actual = fs(source)
243         black.assert_stable(source, actual, DEFAULT_MODE)
244
245     @patch("black.dump_to_file", dump_to_stderr)
246     def test_trailing_comma_optional_parens_stability1_pass2(self) -> None:
247         source, _expected = read_data("trailing_comma_optional_parens1")
248         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
249         black.assert_stable(source, actual, DEFAULT_MODE)
250
251     @patch("black.dump_to_file", dump_to_stderr)
252     def test_trailing_comma_optional_parens_stability2_pass2(self) -> None:
253         source, _expected = read_data("trailing_comma_optional_parens2")
254         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
255         black.assert_stable(source, actual, DEFAULT_MODE)
256
257     @patch("black.dump_to_file", dump_to_stderr)
258     def test_trailing_comma_optional_parens_stability3_pass2(self) -> None:
259         source, _expected = read_data("trailing_comma_optional_parens3")
260         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
261         black.assert_stable(source, actual, DEFAULT_MODE)
262
263     def test_pep_572_version_detection(self) -> None:
264         source, _ = read_data("pep_572")
265         root = black.lib2to3_parse(source)
266         features = black.get_features_used(root)
267         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
268         versions = black.detect_target_versions(root)
269         self.assertIn(black.TargetVersion.PY38, versions)
270
271     def test_expression_ff(self) -> None:
272         source, expected = read_data("expression")
273         tmp_file = Path(black.dump_to_file(source))
274         try:
275             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
276             with open(tmp_file, encoding="utf8") as f:
277                 actual = f.read()
278         finally:
279             os.unlink(tmp_file)
280         self.assertFormatEqual(expected, actual)
281         with patch("black.dump_to_file", dump_to_stderr):
282             black.assert_equivalent(source, actual)
283             black.assert_stable(source, actual, DEFAULT_MODE)
284
285     def test_expression_diff(self) -> None:
286         source, _ = read_data("expression.py")
287         config = THIS_DIR / "data" / "empty_pyproject.toml"
288         expected, _ = read_data("expression.diff")
289         tmp_file = Path(black.dump_to_file(source))
290         diff_header = re.compile(
291             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
292             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
293         )
294         try:
295             result = BlackRunner().invoke(
296                 black.main, ["--diff", str(tmp_file), f"--config={config}"]
297             )
298             self.assertEqual(result.exit_code, 0)
299         finally:
300             os.unlink(tmp_file)
301         actual = result.output
302         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
303         if expected != actual:
304             dump = black.dump_to_file(actual)
305             msg = (
306                 "Expected diff isn't equal to the actual. If you made changes to"
307                 " expression.py and this is an anticipated difference, overwrite"
308                 f" tests/data/expression.diff with {dump}"
309             )
310             self.assertEqual(expected, actual, msg)
311
312     def test_expression_diff_with_color(self) -> None:
313         source, _ = read_data("expression.py")
314         config = THIS_DIR / "data" / "empty_pyproject.toml"
315         expected, _ = read_data("expression.diff")
316         tmp_file = Path(black.dump_to_file(source))
317         try:
318             result = BlackRunner().invoke(
319                 black.main, ["--diff", "--color", str(tmp_file), f"--config={config}"]
320             )
321         finally:
322             os.unlink(tmp_file)
323         actual = result.output
324         # We check the contents of the diff in `test_expression_diff`. All
325         # we need to check here is that color codes exist in the result.
326         self.assertIn("\033[1m", actual)
327         self.assertIn("\033[36m", actual)
328         self.assertIn("\033[32m", actual)
329         self.assertIn("\033[31m", actual)
330         self.assertIn("\033[0m", actual)
331
332     def test_detect_pos_only_arguments(self) -> None:
333         source, _ = read_data("pep_570")
334         root = black.lib2to3_parse(source)
335         features = black.get_features_used(root)
336         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
337         versions = black.detect_target_versions(root)
338         self.assertIn(black.TargetVersion.PY38, versions)
339
340     @patch("black.dump_to_file", dump_to_stderr)
341     def test_string_quotes(self) -> None:
342         source, expected = read_data("string_quotes")
343         mode = black.Mode(experimental_string_processing=True)
344         assert_format(source, expected, mode)
345         mode = replace(mode, string_normalization=False)
346         not_normalized = fs(source, mode=mode)
347         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
348         black.assert_equivalent(source, not_normalized)
349         black.assert_stable(source, not_normalized, mode=mode)
350
351     def test_skip_magic_trailing_comma(self) -> None:
352         source, _ = read_data("expression.py")
353         expected, _ = read_data("expression_skip_magic_trailing_comma.diff")
354         tmp_file = Path(black.dump_to_file(source))
355         diff_header = re.compile(
356             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
357             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
358         )
359         try:
360             result = BlackRunner().invoke(black.main, ["-C", "--diff", str(tmp_file)])
361             self.assertEqual(result.exit_code, 0)
362         finally:
363             os.unlink(tmp_file)
364         actual = result.output
365         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
366         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
367         if expected != actual:
368             dump = black.dump_to_file(actual)
369             msg = (
370                 "Expected diff isn't equal to the actual. If you made changes to"
371                 " expression.py and this is an anticipated difference, overwrite"
372                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
373             )
374             self.assertEqual(expected, actual, msg)
375
376     @patch("black.dump_to_file", dump_to_stderr)
377     def test_async_as_identifier(self) -> None:
378         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
379         source, expected = read_data("async_as_identifier")
380         actual = fs(source)
381         self.assertFormatEqual(expected, actual)
382         major, minor = sys.version_info[:2]
383         if major < 3 or (major <= 3 and minor < 7):
384             black.assert_equivalent(source, actual)
385         black.assert_stable(source, actual, DEFAULT_MODE)
386         # ensure black can parse this when the target is 3.6
387         self.invokeBlack([str(source_path), "--target-version", "py36"])
388         # but not on 3.7, because async/await is no longer an identifier
389         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
390
391     @patch("black.dump_to_file", dump_to_stderr)
392     def test_python37(self) -> None:
393         source_path = (THIS_DIR / "data" / "python37.py").resolve()
394         source, expected = read_data("python37")
395         actual = fs(source)
396         self.assertFormatEqual(expected, actual)
397         major, minor = sys.version_info[:2]
398         if major > 3 or (major == 3 and minor >= 7):
399             black.assert_equivalent(source, actual)
400         black.assert_stable(source, actual, DEFAULT_MODE)
401         # ensure black can parse this when the target is 3.7
402         self.invokeBlack([str(source_path), "--target-version", "py37"])
403         # but not on 3.6, because we use async as a reserved keyword
404         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
405
406     def test_tab_comment_indentation(self) -> None:
407         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
408         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
409         self.assertFormatEqual(contents_spc, fs(contents_spc))
410         self.assertFormatEqual(contents_spc, fs(contents_tab))
411
412         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
413         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
414         self.assertFormatEqual(contents_spc, fs(contents_spc))
415         self.assertFormatEqual(contents_spc, fs(contents_tab))
416
417         # mixed tabs and spaces (valid Python 2 code)
418         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
419         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
420         self.assertFormatEqual(contents_spc, fs(contents_spc))
421         self.assertFormatEqual(contents_spc, fs(contents_tab))
422
423         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
424         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
425         self.assertFormatEqual(contents_spc, fs(contents_spc))
426         self.assertFormatEqual(contents_spc, fs(contents_tab))
427
428     def test_report_verbose(self) -> None:
429         report = Report(verbose=True)
430         out_lines = []
431         err_lines = []
432
433         def out(msg: str, **kwargs: Any) -> None:
434             out_lines.append(msg)
435
436         def err(msg: str, **kwargs: Any) -> None:
437             err_lines.append(msg)
438
439         with patch("black.output._out", out), patch("black.output._err", err):
440             report.done(Path("f1"), black.Changed.NO)
441             self.assertEqual(len(out_lines), 1)
442             self.assertEqual(len(err_lines), 0)
443             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
444             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
445             self.assertEqual(report.return_code, 0)
446             report.done(Path("f2"), black.Changed.YES)
447             self.assertEqual(len(out_lines), 2)
448             self.assertEqual(len(err_lines), 0)
449             self.assertEqual(out_lines[-1], "reformatted f2")
450             self.assertEqual(
451                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
452             )
453             report.done(Path("f3"), black.Changed.CACHED)
454             self.assertEqual(len(out_lines), 3)
455             self.assertEqual(len(err_lines), 0)
456             self.assertEqual(
457                 out_lines[-1], "f3 wasn't modified on disk since last run."
458             )
459             self.assertEqual(
460                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
461             )
462             self.assertEqual(report.return_code, 0)
463             report.check = True
464             self.assertEqual(report.return_code, 1)
465             report.check = False
466             report.failed(Path("e1"), "boom")
467             self.assertEqual(len(out_lines), 3)
468             self.assertEqual(len(err_lines), 1)
469             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
470             self.assertEqual(
471                 unstyle(str(report)),
472                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
473                 " reformat.",
474             )
475             self.assertEqual(report.return_code, 123)
476             report.done(Path("f3"), black.Changed.YES)
477             self.assertEqual(len(out_lines), 4)
478             self.assertEqual(len(err_lines), 1)
479             self.assertEqual(out_lines[-1], "reformatted f3")
480             self.assertEqual(
481                 unstyle(str(report)),
482                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
483                 " reformat.",
484             )
485             self.assertEqual(report.return_code, 123)
486             report.failed(Path("e2"), "boom")
487             self.assertEqual(len(out_lines), 4)
488             self.assertEqual(len(err_lines), 2)
489             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
490             self.assertEqual(
491                 unstyle(str(report)),
492                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
493                 " reformat.",
494             )
495             self.assertEqual(report.return_code, 123)
496             report.path_ignored(Path("wat"), "no match")
497             self.assertEqual(len(out_lines), 5)
498             self.assertEqual(len(err_lines), 2)
499             self.assertEqual(out_lines[-1], "wat ignored: no match")
500             self.assertEqual(
501                 unstyle(str(report)),
502                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
503                 " reformat.",
504             )
505             self.assertEqual(report.return_code, 123)
506             report.done(Path("f4"), black.Changed.NO)
507             self.assertEqual(len(out_lines), 6)
508             self.assertEqual(len(err_lines), 2)
509             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
510             self.assertEqual(
511                 unstyle(str(report)),
512                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
513                 " reformat.",
514             )
515             self.assertEqual(report.return_code, 123)
516             report.check = True
517             self.assertEqual(
518                 unstyle(str(report)),
519                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
520                 " would fail to reformat.",
521             )
522             report.check = False
523             report.diff = True
524             self.assertEqual(
525                 unstyle(str(report)),
526                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
527                 " would fail to reformat.",
528             )
529
530     def test_report_quiet(self) -> None:
531         report = Report(quiet=True)
532         out_lines = []
533         err_lines = []
534
535         def out(msg: str, **kwargs: Any) -> None:
536             out_lines.append(msg)
537
538         def err(msg: str, **kwargs: Any) -> None:
539             err_lines.append(msg)
540
541         with patch("black.output._out", out), patch("black.output._err", err):
542             report.done(Path("f1"), black.Changed.NO)
543             self.assertEqual(len(out_lines), 0)
544             self.assertEqual(len(err_lines), 0)
545             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
546             self.assertEqual(report.return_code, 0)
547             report.done(Path("f2"), black.Changed.YES)
548             self.assertEqual(len(out_lines), 0)
549             self.assertEqual(len(err_lines), 0)
550             self.assertEqual(
551                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
552             )
553             report.done(Path("f3"), black.Changed.CACHED)
554             self.assertEqual(len(out_lines), 0)
555             self.assertEqual(len(err_lines), 0)
556             self.assertEqual(
557                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
558             )
559             self.assertEqual(report.return_code, 0)
560             report.check = True
561             self.assertEqual(report.return_code, 1)
562             report.check = False
563             report.failed(Path("e1"), "boom")
564             self.assertEqual(len(out_lines), 0)
565             self.assertEqual(len(err_lines), 1)
566             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
567             self.assertEqual(
568                 unstyle(str(report)),
569                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
570                 " reformat.",
571             )
572             self.assertEqual(report.return_code, 123)
573             report.done(Path("f3"), black.Changed.YES)
574             self.assertEqual(len(out_lines), 0)
575             self.assertEqual(len(err_lines), 1)
576             self.assertEqual(
577                 unstyle(str(report)),
578                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
579                 " reformat.",
580             )
581             self.assertEqual(report.return_code, 123)
582             report.failed(Path("e2"), "boom")
583             self.assertEqual(len(out_lines), 0)
584             self.assertEqual(len(err_lines), 2)
585             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
586             self.assertEqual(
587                 unstyle(str(report)),
588                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
589                 " reformat.",
590             )
591             self.assertEqual(report.return_code, 123)
592             report.path_ignored(Path("wat"), "no match")
593             self.assertEqual(len(out_lines), 0)
594             self.assertEqual(len(err_lines), 2)
595             self.assertEqual(
596                 unstyle(str(report)),
597                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
598                 " reformat.",
599             )
600             self.assertEqual(report.return_code, 123)
601             report.done(Path("f4"), black.Changed.NO)
602             self.assertEqual(len(out_lines), 0)
603             self.assertEqual(len(err_lines), 2)
604             self.assertEqual(
605                 unstyle(str(report)),
606                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
607                 " reformat.",
608             )
609             self.assertEqual(report.return_code, 123)
610             report.check = True
611             self.assertEqual(
612                 unstyle(str(report)),
613                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
614                 " would fail to reformat.",
615             )
616             report.check = False
617             report.diff = True
618             self.assertEqual(
619                 unstyle(str(report)),
620                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
621                 " would fail to reformat.",
622             )
623
624     def test_report_normal(self) -> None:
625         report = black.Report()
626         out_lines = []
627         err_lines = []
628
629         def out(msg: str, **kwargs: Any) -> None:
630             out_lines.append(msg)
631
632         def err(msg: str, **kwargs: Any) -> None:
633             err_lines.append(msg)
634
635         with patch("black.output._out", out), patch("black.output._err", err):
636             report.done(Path("f1"), black.Changed.NO)
637             self.assertEqual(len(out_lines), 0)
638             self.assertEqual(len(err_lines), 0)
639             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
640             self.assertEqual(report.return_code, 0)
641             report.done(Path("f2"), black.Changed.YES)
642             self.assertEqual(len(out_lines), 1)
643             self.assertEqual(len(err_lines), 0)
644             self.assertEqual(out_lines[-1], "reformatted f2")
645             self.assertEqual(
646                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
647             )
648             report.done(Path("f3"), black.Changed.CACHED)
649             self.assertEqual(len(out_lines), 1)
650             self.assertEqual(len(err_lines), 0)
651             self.assertEqual(out_lines[-1], "reformatted f2")
652             self.assertEqual(
653                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
654             )
655             self.assertEqual(report.return_code, 0)
656             report.check = True
657             self.assertEqual(report.return_code, 1)
658             report.check = False
659             report.failed(Path("e1"), "boom")
660             self.assertEqual(len(out_lines), 1)
661             self.assertEqual(len(err_lines), 1)
662             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
663             self.assertEqual(
664                 unstyle(str(report)),
665                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
666                 " reformat.",
667             )
668             self.assertEqual(report.return_code, 123)
669             report.done(Path("f3"), black.Changed.YES)
670             self.assertEqual(len(out_lines), 2)
671             self.assertEqual(len(err_lines), 1)
672             self.assertEqual(out_lines[-1], "reformatted f3")
673             self.assertEqual(
674                 unstyle(str(report)),
675                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
676                 " reformat.",
677             )
678             self.assertEqual(report.return_code, 123)
679             report.failed(Path("e2"), "boom")
680             self.assertEqual(len(out_lines), 2)
681             self.assertEqual(len(err_lines), 2)
682             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
683             self.assertEqual(
684                 unstyle(str(report)),
685                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
686                 " reformat.",
687             )
688             self.assertEqual(report.return_code, 123)
689             report.path_ignored(Path("wat"), "no match")
690             self.assertEqual(len(out_lines), 2)
691             self.assertEqual(len(err_lines), 2)
692             self.assertEqual(
693                 unstyle(str(report)),
694                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
695                 " reformat.",
696             )
697             self.assertEqual(report.return_code, 123)
698             report.done(Path("f4"), black.Changed.NO)
699             self.assertEqual(len(out_lines), 2)
700             self.assertEqual(len(err_lines), 2)
701             self.assertEqual(
702                 unstyle(str(report)),
703                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
704                 " reformat.",
705             )
706             self.assertEqual(report.return_code, 123)
707             report.check = True
708             self.assertEqual(
709                 unstyle(str(report)),
710                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
711                 " would fail to reformat.",
712             )
713             report.check = False
714             report.diff = True
715             self.assertEqual(
716                 unstyle(str(report)),
717                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
718                 " would fail to reformat.",
719             )
720
721     def test_lib2to3_parse(self) -> None:
722         with self.assertRaises(black.InvalidInput):
723             black.lib2to3_parse("invalid syntax")
724
725         straddling = "x + y"
726         black.lib2to3_parse(straddling)
727         black.lib2to3_parse(straddling, {TargetVersion.PY27})
728         black.lib2to3_parse(straddling, {TargetVersion.PY36})
729         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
730
731         py2_only = "print x"
732         black.lib2to3_parse(py2_only)
733         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
734         with self.assertRaises(black.InvalidInput):
735             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
736         with self.assertRaises(black.InvalidInput):
737             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
738
739         py3_only = "exec(x, end=y)"
740         black.lib2to3_parse(py3_only)
741         with self.assertRaises(black.InvalidInput):
742             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
743         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
744         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
745
746     def test_get_features_used_decorator(self) -> None:
747         # Test the feature detection of new decorator syntax
748         # since this makes some test cases of test_get_features_used()
749         # fails if it fails, this is tested first so that a useful case
750         # is identified
751         simples, relaxed = read_data("decorators")
752         # skip explanation comments at the top of the file
753         for simple_test in simples.split("##")[1:]:
754             node = black.lib2to3_parse(simple_test)
755             decorator = str(node.children[0].children[0]).strip()
756             self.assertNotIn(
757                 Feature.RELAXED_DECORATORS,
758                 black.get_features_used(node),
759                 msg=(
760                     f"decorator '{decorator}' follows python<=3.8 syntax"
761                     "but is detected as 3.9+"
762                     # f"The full node is\n{node!r}"
763                 ),
764             )
765         # skip the '# output' comment at the top of the output part
766         for relaxed_test in relaxed.split("##")[1:]:
767             node = black.lib2to3_parse(relaxed_test)
768             decorator = str(node.children[0].children[0]).strip()
769             self.assertIn(
770                 Feature.RELAXED_DECORATORS,
771                 black.get_features_used(node),
772                 msg=(
773                     f"decorator '{decorator}' uses python3.9+ syntax"
774                     "but is detected as python<=3.8"
775                     # f"The full node is\n{node!r}"
776                 ),
777             )
778
779     def test_get_features_used(self) -> None:
780         node = black.lib2to3_parse("def f(*, arg): ...\n")
781         self.assertEqual(black.get_features_used(node), set())
782         node = black.lib2to3_parse("def f(*, arg,): ...\n")
783         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
784         node = black.lib2to3_parse("f(*arg,)\n")
785         self.assertEqual(
786             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
787         )
788         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
789         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
790         node = black.lib2to3_parse("123_456\n")
791         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
792         node = black.lib2to3_parse("123456\n")
793         self.assertEqual(black.get_features_used(node), set())
794         source, expected = read_data("function")
795         node = black.lib2to3_parse(source)
796         expected_features = {
797             Feature.TRAILING_COMMA_IN_CALL,
798             Feature.TRAILING_COMMA_IN_DEF,
799             Feature.F_STRINGS,
800         }
801         self.assertEqual(black.get_features_used(node), expected_features)
802         node = black.lib2to3_parse(expected)
803         self.assertEqual(black.get_features_used(node), expected_features)
804         source, expected = read_data("expression")
805         node = black.lib2to3_parse(source)
806         self.assertEqual(black.get_features_used(node), set())
807         node = black.lib2to3_parse(expected)
808         self.assertEqual(black.get_features_used(node), set())
809         node = black.lib2to3_parse("lambda a, /, b: ...")
810         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
811         node = black.lib2to3_parse("def fn(a, /, b): ...")
812         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
813
814     def test_get_features_used_for_future_flags(self) -> None:
815         for src, features in [
816             ("from __future__ import annotations", {Feature.FUTURE_ANNOTATIONS}),
817             (
818                 "from __future__ import (other, annotations)",
819                 {Feature.FUTURE_ANNOTATIONS},
820             ),
821             ("a = 1 + 2\nfrom something import annotations", set()),
822             ("from __future__ import x, y", set()),
823         ]:
824             with self.subTest(src=src, features=features):
825                 node = black.lib2to3_parse(src)
826                 future_imports = black.get_future_imports(node)
827                 self.assertEqual(
828                     black.get_features_used(node, future_imports=future_imports),
829                     features,
830                 )
831
832     def test_get_future_imports(self) -> None:
833         node = black.lib2to3_parse("\n")
834         self.assertEqual(set(), black.get_future_imports(node))
835         node = black.lib2to3_parse("from __future__ import black\n")
836         self.assertEqual({"black"}, black.get_future_imports(node))
837         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
838         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
839         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
840         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
841         node = black.lib2to3_parse(
842             "from __future__ import multiple\nfrom __future__ import imports\n"
843         )
844         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
845         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
846         self.assertEqual({"black"}, black.get_future_imports(node))
847         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
848         self.assertEqual({"black"}, black.get_future_imports(node))
849         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
850         self.assertEqual(set(), black.get_future_imports(node))
851         node = black.lib2to3_parse("from some.module import black\n")
852         self.assertEqual(set(), black.get_future_imports(node))
853         node = black.lib2to3_parse(
854             "from __future__ import unicode_literals as _unicode_literals"
855         )
856         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
857         node = black.lib2to3_parse(
858             "from __future__ import unicode_literals as _lol, print"
859         )
860         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
861
862     @pytest.mark.incompatible_with_mypyc
863     def test_debug_visitor(self) -> None:
864         source, _ = read_data("debug_visitor.py")
865         expected, _ = read_data("debug_visitor.out")
866         out_lines = []
867         err_lines = []
868
869         def out(msg: str, **kwargs: Any) -> None:
870             out_lines.append(msg)
871
872         def err(msg: str, **kwargs: Any) -> None:
873             err_lines.append(msg)
874
875         with patch("black.debug.out", out):
876             DebugVisitor.show(source)
877         actual = "\n".join(out_lines) + "\n"
878         log_name = ""
879         if expected != actual:
880             log_name = black.dump_to_file(*out_lines)
881         self.assertEqual(
882             expected,
883             actual,
884             f"AST print out is different. Actual version dumped to {log_name}",
885         )
886
887     def test_format_file_contents(self) -> None:
888         empty = ""
889         mode = DEFAULT_MODE
890         with self.assertRaises(black.NothingChanged):
891             black.format_file_contents(empty, mode=mode, fast=False)
892         just_nl = "\n"
893         with self.assertRaises(black.NothingChanged):
894             black.format_file_contents(just_nl, mode=mode, fast=False)
895         same = "j = [1, 2, 3]\n"
896         with self.assertRaises(black.NothingChanged):
897             black.format_file_contents(same, mode=mode, fast=False)
898         different = "j = [1,2,3]"
899         expected = same
900         actual = black.format_file_contents(different, mode=mode, fast=False)
901         self.assertEqual(expected, actual)
902         invalid = "return if you can"
903         with self.assertRaises(black.InvalidInput) as e:
904             black.format_file_contents(invalid, mode=mode, fast=False)
905         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
906
907     def test_endmarker(self) -> None:
908         n = black.lib2to3_parse("\n")
909         self.assertEqual(n.type, black.syms.file_input)
910         self.assertEqual(len(n.children), 1)
911         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
912
913     @pytest.mark.incompatible_with_mypyc
914     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
915     def test_assertFormatEqual(self) -> None:
916         out_lines = []
917         err_lines = []
918
919         def out(msg: str, **kwargs: Any) -> None:
920             out_lines.append(msg)
921
922         def err(msg: str, **kwargs: Any) -> None:
923             err_lines.append(msg)
924
925         with patch("black.output._out", out), patch("black.output._err", err):
926             with self.assertRaises(AssertionError):
927                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
928
929         out_str = "".join(out_lines)
930         self.assertTrue("Expected tree:" in out_str)
931         self.assertTrue("Actual tree:" in out_str)
932         self.assertEqual("".join(err_lines), "")
933
934     @event_loop()
935     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
936     def test_works_in_mono_process_only_environment(self) -> None:
937         with cache_dir() as workspace:
938             for f in [
939                 (workspace / "one.py").resolve(),
940                 (workspace / "two.py").resolve(),
941             ]:
942                 f.write_text('print("hello")\n')
943             self.invokeBlack([str(workspace)])
944
945     @event_loop()
946     def test_check_diff_use_together(self) -> None:
947         with cache_dir():
948             # Files which will be reformatted.
949             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
950             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
951             # Files which will not be reformatted.
952             src2 = (THIS_DIR / "data" / "composition.py").resolve()
953             self.invokeBlack([str(src2), "--diff", "--check"])
954             # Multi file command.
955             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
956
957     def test_no_files(self) -> None:
958         with cache_dir():
959             # Without an argument, black exits with error code 0.
960             self.invokeBlack([])
961
962     def test_broken_symlink(self) -> None:
963         with cache_dir() as workspace:
964             symlink = workspace / "broken_link.py"
965             try:
966                 symlink.symlink_to("nonexistent.py")
967             except (OSError, NotImplementedError) as e:
968                 self.skipTest(f"Can't create symlinks: {e}")
969             self.invokeBlack([str(workspace.resolve())])
970
971     def test_single_file_force_pyi(self) -> None:
972         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
973         contents, expected = read_data("force_pyi")
974         with cache_dir() as workspace:
975             path = (workspace / "file.py").resolve()
976             with open(path, "w") as fh:
977                 fh.write(contents)
978             self.invokeBlack([str(path), "--pyi"])
979             with open(path, "r") as fh:
980                 actual = fh.read()
981             # verify cache with --pyi is separate
982             pyi_cache = black.read_cache(pyi_mode)
983             self.assertIn(str(path), pyi_cache)
984             normal_cache = black.read_cache(DEFAULT_MODE)
985             self.assertNotIn(str(path), normal_cache)
986         self.assertFormatEqual(expected, actual)
987         black.assert_equivalent(contents, actual)
988         black.assert_stable(contents, actual, pyi_mode)
989
990     @event_loop()
991     def test_multi_file_force_pyi(self) -> None:
992         reg_mode = DEFAULT_MODE
993         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
994         contents, expected = read_data("force_pyi")
995         with cache_dir() as workspace:
996             paths = [
997                 (workspace / "file1.py").resolve(),
998                 (workspace / "file2.py").resolve(),
999             ]
1000             for path in paths:
1001                 with open(path, "w") as fh:
1002                     fh.write(contents)
1003             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1004             for path in paths:
1005                 with open(path, "r") as fh:
1006                     actual = fh.read()
1007                 self.assertEqual(actual, expected)
1008             # verify cache with --pyi is separate
1009             pyi_cache = black.read_cache(pyi_mode)
1010             normal_cache = black.read_cache(reg_mode)
1011             for path in paths:
1012                 self.assertIn(str(path), pyi_cache)
1013                 self.assertNotIn(str(path), normal_cache)
1014
1015     def test_pipe_force_pyi(self) -> None:
1016         source, expected = read_data("force_pyi")
1017         result = CliRunner().invoke(
1018             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1019         )
1020         self.assertEqual(result.exit_code, 0)
1021         actual = result.output
1022         self.assertFormatEqual(actual, expected)
1023
1024     def test_single_file_force_py36(self) -> None:
1025         reg_mode = DEFAULT_MODE
1026         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1027         source, expected = read_data("force_py36")
1028         with cache_dir() as workspace:
1029             path = (workspace / "file.py").resolve()
1030             with open(path, "w") as fh:
1031                 fh.write(source)
1032             self.invokeBlack([str(path), *PY36_ARGS])
1033             with open(path, "r") as fh:
1034                 actual = fh.read()
1035             # verify cache with --target-version is separate
1036             py36_cache = black.read_cache(py36_mode)
1037             self.assertIn(str(path), py36_cache)
1038             normal_cache = black.read_cache(reg_mode)
1039             self.assertNotIn(str(path), normal_cache)
1040         self.assertEqual(actual, expected)
1041
1042     @event_loop()
1043     def test_multi_file_force_py36(self) -> None:
1044         reg_mode = DEFAULT_MODE
1045         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1046         source, expected = read_data("force_py36")
1047         with cache_dir() as workspace:
1048             paths = [
1049                 (workspace / "file1.py").resolve(),
1050                 (workspace / "file2.py").resolve(),
1051             ]
1052             for path in paths:
1053                 with open(path, "w") as fh:
1054                     fh.write(source)
1055             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1056             for path in paths:
1057                 with open(path, "r") as fh:
1058                     actual = fh.read()
1059                 self.assertEqual(actual, expected)
1060             # verify cache with --target-version is separate
1061             pyi_cache = black.read_cache(py36_mode)
1062             normal_cache = black.read_cache(reg_mode)
1063             for path in paths:
1064                 self.assertIn(str(path), pyi_cache)
1065                 self.assertNotIn(str(path), normal_cache)
1066
1067     def test_pipe_force_py36(self) -> None:
1068         source, expected = read_data("force_py36")
1069         result = CliRunner().invoke(
1070             black.main,
1071             ["-", "-q", "--target-version=py36"],
1072             input=BytesIO(source.encode("utf8")),
1073         )
1074         self.assertEqual(result.exit_code, 0)
1075         actual = result.output
1076         self.assertFormatEqual(actual, expected)
1077
1078     @pytest.mark.incompatible_with_mypyc
1079     def test_reformat_one_with_stdin(self) -> None:
1080         with patch(
1081             "black.format_stdin_to_stdout",
1082             return_value=lambda *args, **kwargs: black.Changed.YES,
1083         ) as fsts:
1084             report = MagicMock()
1085             path = Path("-")
1086             black.reformat_one(
1087                 path,
1088                 fast=True,
1089                 write_back=black.WriteBack.YES,
1090                 mode=DEFAULT_MODE,
1091                 report=report,
1092             )
1093             fsts.assert_called_once()
1094             report.done.assert_called_with(path, black.Changed.YES)
1095
1096     @pytest.mark.incompatible_with_mypyc
1097     def test_reformat_one_with_stdin_filename(self) -> None:
1098         with patch(
1099             "black.format_stdin_to_stdout",
1100             return_value=lambda *args, **kwargs: black.Changed.YES,
1101         ) as fsts:
1102             report = MagicMock()
1103             p = "foo.py"
1104             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1105             expected = Path(p)
1106             black.reformat_one(
1107                 path,
1108                 fast=True,
1109                 write_back=black.WriteBack.YES,
1110                 mode=DEFAULT_MODE,
1111                 report=report,
1112             )
1113             fsts.assert_called_once_with(
1114                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1115             )
1116             # __BLACK_STDIN_FILENAME__ should have been stripped
1117             report.done.assert_called_with(expected, black.Changed.YES)
1118
1119     @pytest.mark.incompatible_with_mypyc
1120     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1121         with patch(
1122             "black.format_stdin_to_stdout",
1123             return_value=lambda *args, **kwargs: black.Changed.YES,
1124         ) as fsts:
1125             report = MagicMock()
1126             p = "foo.pyi"
1127             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1128             expected = Path(p)
1129             black.reformat_one(
1130                 path,
1131                 fast=True,
1132                 write_back=black.WriteBack.YES,
1133                 mode=DEFAULT_MODE,
1134                 report=report,
1135             )
1136             fsts.assert_called_once_with(
1137                 fast=True,
1138                 write_back=black.WriteBack.YES,
1139                 mode=replace(DEFAULT_MODE, is_pyi=True),
1140             )
1141             # __BLACK_STDIN_FILENAME__ should have been stripped
1142             report.done.assert_called_with(expected, black.Changed.YES)
1143
1144     @pytest.mark.incompatible_with_mypyc
1145     def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
1146         with patch(
1147             "black.format_stdin_to_stdout",
1148             return_value=lambda *args, **kwargs: black.Changed.YES,
1149         ) as fsts:
1150             report = MagicMock()
1151             p = "foo.ipynb"
1152             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1153             expected = Path(p)
1154             black.reformat_one(
1155                 path,
1156                 fast=True,
1157                 write_back=black.WriteBack.YES,
1158                 mode=DEFAULT_MODE,
1159                 report=report,
1160             )
1161             fsts.assert_called_once_with(
1162                 fast=True,
1163                 write_back=black.WriteBack.YES,
1164                 mode=replace(DEFAULT_MODE, is_ipynb=True),
1165             )
1166             # __BLACK_STDIN_FILENAME__ should have been stripped
1167             report.done.assert_called_with(expected, black.Changed.YES)
1168
1169     @pytest.mark.incompatible_with_mypyc
1170     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1171         with patch(
1172             "black.format_stdin_to_stdout",
1173             return_value=lambda *args, **kwargs: black.Changed.YES,
1174         ) as fsts:
1175             report = MagicMock()
1176             # Even with an existing file, since we are forcing stdin, black
1177             # should output to stdout and not modify the file inplace
1178             p = Path(str(THIS_DIR / "data/collections.py"))
1179             # Make sure is_file actually returns True
1180             self.assertTrue(p.is_file())
1181             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1182             expected = Path(p)
1183             black.reformat_one(
1184                 path,
1185                 fast=True,
1186                 write_back=black.WriteBack.YES,
1187                 mode=DEFAULT_MODE,
1188                 report=report,
1189             )
1190             fsts.assert_called_once()
1191             # __BLACK_STDIN_FILENAME__ should have been stripped
1192             report.done.assert_called_with(expected, black.Changed.YES)
1193
1194     def test_reformat_one_with_stdin_empty(self) -> None:
1195         output = io.StringIO()
1196         with patch("io.TextIOWrapper", lambda *args, **kwargs: output):
1197             try:
1198                 black.format_stdin_to_stdout(
1199                     fast=True,
1200                     content="",
1201                     write_back=black.WriteBack.YES,
1202                     mode=DEFAULT_MODE,
1203                 )
1204             except io.UnsupportedOperation:
1205                 pass  # StringIO does not support detach
1206             assert output.getvalue() == ""
1207
1208     def test_invalid_cli_regex(self) -> None:
1209         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1210             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1211
1212     def test_required_version_matches_version(self) -> None:
1213         self.invokeBlack(
1214             ["--required-version", black.__version__], exit_code=0, ignore_config=True
1215         )
1216
1217     def test_required_version_does_not_match_version(self) -> None:
1218         self.invokeBlack(
1219             ["--required-version", "20.99b"], exit_code=1, ignore_config=True
1220         )
1221
1222     def test_preserves_line_endings(self) -> None:
1223         with TemporaryDirectory() as workspace:
1224             test_file = Path(workspace) / "test.py"
1225             for nl in ["\n", "\r\n"]:
1226                 contents = nl.join(["def f(  ):", "    pass"])
1227                 test_file.write_bytes(contents.encode())
1228                 ff(test_file, write_back=black.WriteBack.YES)
1229                 updated_contents: bytes = test_file.read_bytes()
1230                 self.assertIn(nl.encode(), updated_contents)
1231                 if nl == "\n":
1232                     self.assertNotIn(b"\r\n", updated_contents)
1233
1234     def test_preserves_line_endings_via_stdin(self) -> None:
1235         for nl in ["\n", "\r\n"]:
1236             contents = nl.join(["def f(  ):", "    pass"])
1237             runner = BlackRunner()
1238             result = runner.invoke(
1239                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1240             )
1241             self.assertEqual(result.exit_code, 0)
1242             output = result.stdout_bytes
1243             self.assertIn(nl.encode("utf8"), output)
1244             if nl == "\n":
1245                 self.assertNotIn(b"\r\n", output)
1246
1247     def test_assert_equivalent_different_asts(self) -> None:
1248         with self.assertRaises(AssertionError):
1249             black.assert_equivalent("{}", "None")
1250
1251     def test_shhh_click(self) -> None:
1252         try:
1253             from click import _unicodefun
1254         except ModuleNotFoundError:
1255             self.skipTest("Incompatible Click version")
1256         if not hasattr(_unicodefun, "_verify_python3_env"):
1257             self.skipTest("Incompatible Click version")
1258         # First, let's see if Click is crashing with a preferred ASCII charset.
1259         with patch("locale.getpreferredencoding") as gpe:
1260             gpe.return_value = "ASCII"
1261             with self.assertRaises(RuntimeError):
1262                 _unicodefun._verify_python3_env()  # type: ignore
1263         # Now, let's silence Click...
1264         black.patch_click()
1265         # ...and confirm it's silent.
1266         with patch("locale.getpreferredencoding") as gpe:
1267             gpe.return_value = "ASCII"
1268             try:
1269                 _unicodefun._verify_python3_env()  # type: ignore
1270             except RuntimeError as re:
1271                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1272
1273     def test_root_logger_not_used_directly(self) -> None:
1274         def fail(*args: Any, **kwargs: Any) -> None:
1275             self.fail("Record created with root logger")
1276
1277         with patch.multiple(
1278             logging.root,
1279             debug=fail,
1280             info=fail,
1281             warning=fail,
1282             error=fail,
1283             critical=fail,
1284             log=fail,
1285         ):
1286             ff(THIS_DIR / "util.py")
1287
1288     def test_invalid_config_return_code(self) -> None:
1289         tmp_file = Path(black.dump_to_file())
1290         try:
1291             tmp_config = Path(black.dump_to_file())
1292             tmp_config.unlink()
1293             args = ["--config", str(tmp_config), str(tmp_file)]
1294             self.invokeBlack(args, exit_code=2, ignore_config=False)
1295         finally:
1296             tmp_file.unlink()
1297
1298     def test_parse_pyproject_toml(self) -> None:
1299         test_toml_file = THIS_DIR / "test.toml"
1300         config = black.parse_pyproject_toml(str(test_toml_file))
1301         self.assertEqual(config["verbose"], 1)
1302         self.assertEqual(config["check"], "no")
1303         self.assertEqual(config["diff"], "y")
1304         self.assertEqual(config["color"], True)
1305         self.assertEqual(config["line_length"], 79)
1306         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1307         self.assertEqual(config["exclude"], r"\.pyi?$")
1308         self.assertEqual(config["include"], r"\.py?$")
1309
1310     def test_read_pyproject_toml(self) -> None:
1311         test_toml_file = THIS_DIR / "test.toml"
1312         fake_ctx = FakeContext()
1313         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1314         config = fake_ctx.default_map
1315         self.assertEqual(config["verbose"], "1")
1316         self.assertEqual(config["check"], "no")
1317         self.assertEqual(config["diff"], "y")
1318         self.assertEqual(config["color"], "True")
1319         self.assertEqual(config["line_length"], "79")
1320         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1321         self.assertEqual(config["exclude"], r"\.pyi?$")
1322         self.assertEqual(config["include"], r"\.py?$")
1323
1324     @pytest.mark.incompatible_with_mypyc
1325     def test_find_project_root(self) -> None:
1326         with TemporaryDirectory() as workspace:
1327             root = Path(workspace)
1328             test_dir = root / "test"
1329             test_dir.mkdir()
1330
1331             src_dir = root / "src"
1332             src_dir.mkdir()
1333
1334             root_pyproject = root / "pyproject.toml"
1335             root_pyproject.touch()
1336             src_pyproject = src_dir / "pyproject.toml"
1337             src_pyproject.touch()
1338             src_python = src_dir / "foo.py"
1339             src_python.touch()
1340
1341             self.assertEqual(
1342                 black.find_project_root((src_dir, test_dir)), root.resolve()
1343             )
1344             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
1345             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
1346
1347     @patch(
1348         "black.files.find_user_pyproject_toml",
1349         black.files.find_user_pyproject_toml.__wrapped__,
1350     )
1351     def test_find_user_pyproject_toml_linux(self) -> None:
1352         if system() == "Windows":
1353             return
1354
1355         # Test if XDG_CONFIG_HOME is checked
1356         with TemporaryDirectory() as workspace:
1357             tmp_user_config = Path(workspace) / "black"
1358             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1359                 self.assertEqual(
1360                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
1361                 )
1362
1363         # Test fallback for XDG_CONFIG_HOME
1364         with patch.dict("os.environ"):
1365             os.environ.pop("XDG_CONFIG_HOME", None)
1366             fallback_user_config = Path("~/.config").expanduser() / "black"
1367             self.assertEqual(
1368                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
1369             )
1370
1371     def test_find_user_pyproject_toml_windows(self) -> None:
1372         if system() != "Windows":
1373             return
1374
1375         user_config_path = Path.home() / ".black"
1376         self.assertEqual(
1377             black.files.find_user_pyproject_toml(), user_config_path.resolve()
1378         )
1379
1380     def test_bpo_33660_workaround(self) -> None:
1381         if system() == "Windows":
1382             return
1383
1384         # https://bugs.python.org/issue33660
1385         root = Path("/")
1386         with change_directory(root):
1387             path = Path("workspace") / "project"
1388             report = black.Report(verbose=True)
1389             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
1390             self.assertEqual(normalized_path, "workspace/project")
1391
1392     def test_newline_comment_interaction(self) -> None:
1393         source = "class A:\\\r\n# type: ignore\n pass\n"
1394         output = black.format_str(source, mode=DEFAULT_MODE)
1395         black.assert_stable(source, output, mode=DEFAULT_MODE)
1396
1397     def test_bpo_2142_workaround(self) -> None:
1398
1399         # https://bugs.python.org/issue2142
1400
1401         source, _ = read_data("missing_final_newline.py")
1402         # read_data adds a trailing newline
1403         source = source.rstrip()
1404         expected, _ = read_data("missing_final_newline.diff")
1405         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
1406         diff_header = re.compile(
1407             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
1408             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
1409         )
1410         try:
1411             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
1412             self.assertEqual(result.exit_code, 0)
1413         finally:
1414             os.unlink(tmp_file)
1415         actual = result.output
1416         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
1417         self.assertEqual(actual, expected)
1418
1419     @pytest.mark.python2
1420     def test_docstring_reformat_for_py27(self) -> None:
1421         """
1422         Check that stripping trailing whitespace from Python 2 docstrings
1423         doesn't trigger a "not equivalent to source" error
1424         """
1425         source = (
1426             b'def foo():\r\n    """Testing\r\n    Testing """\r\n    print "Foo"\r\n'
1427         )
1428         expected = 'def foo():\n    """Testing\n    Testing"""\n    print "Foo"\n'
1429
1430         result = BlackRunner().invoke(
1431             black.main,
1432             ["-", "-q", "--target-version=py27"],
1433             input=BytesIO(source),
1434         )
1435
1436         self.assertEqual(result.exit_code, 0)
1437         actual = result.stdout
1438         self.assertFormatEqual(actual, expected)
1439
1440     @staticmethod
1441     def compare_results(
1442         result: click.testing.Result, expected_value: str, expected_exit_code: int
1443     ) -> None:
1444         """Helper method to test the value and exit code of a click Result."""
1445         assert (
1446             result.output == expected_value
1447         ), "The output did not match the expected value."
1448         assert result.exit_code == expected_exit_code, "The exit code is incorrect."
1449
1450     def test_code_option(self) -> None:
1451         """Test the code option with no changes."""
1452         code = 'print("Hello world")\n'
1453         args = ["--code", code]
1454         result = CliRunner().invoke(black.main, args)
1455
1456         self.compare_results(result, code, 0)
1457
1458     def test_code_option_changed(self) -> None:
1459         """Test the code option when changes are required."""
1460         code = "print('hello world')"
1461         formatted = black.format_str(code, mode=DEFAULT_MODE)
1462
1463         args = ["--code", code]
1464         result = CliRunner().invoke(black.main, args)
1465
1466         self.compare_results(result, formatted, 0)
1467
1468     def test_code_option_check(self) -> None:
1469         """Test the code option when check is passed."""
1470         args = ["--check", "--code", 'print("Hello world")\n']
1471         result = CliRunner().invoke(black.main, args)
1472         self.compare_results(result, "", 0)
1473
1474     def test_code_option_check_changed(self) -> None:
1475         """Test the code option when changes are required, and check is passed."""
1476         args = ["--check", "--code", "print('hello world')"]
1477         result = CliRunner().invoke(black.main, args)
1478         self.compare_results(result, "", 1)
1479
1480     def test_code_option_diff(self) -> None:
1481         """Test the code option when diff is passed."""
1482         code = "print('hello world')"
1483         formatted = black.format_str(code, mode=DEFAULT_MODE)
1484         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1485
1486         args = ["--diff", "--code", code]
1487         result = CliRunner().invoke(black.main, args)
1488
1489         # Remove time from diff
1490         output = DIFF_TIME.sub("", result.output)
1491
1492         assert output == result_diff, "The output did not match the expected value."
1493         assert result.exit_code == 0, "The exit code is incorrect."
1494
1495     def test_code_option_color_diff(self) -> None:
1496         """Test the code option when color and diff are passed."""
1497         code = "print('hello world')"
1498         formatted = black.format_str(code, mode=DEFAULT_MODE)
1499
1500         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1501         result_diff = color_diff(result_diff)
1502
1503         args = ["--diff", "--color", "--code", code]
1504         result = CliRunner().invoke(black.main, args)
1505
1506         # Remove time from diff
1507         output = DIFF_TIME.sub("", result.output)
1508
1509         assert output == result_diff, "The output did not match the expected value."
1510         assert result.exit_code == 0, "The exit code is incorrect."
1511
1512     @pytest.mark.incompatible_with_mypyc
1513     def test_code_option_safe(self) -> None:
1514         """Test that the code option throws an error when the sanity checks fail."""
1515         # Patch black.assert_equivalent to ensure the sanity checks fail
1516         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1517             code = 'print("Hello world")'
1518             error_msg = f"{code}\nerror: cannot format <string>: \n"
1519
1520             args = ["--safe", "--code", code]
1521             result = CliRunner().invoke(black.main, args)
1522
1523             self.compare_results(result, error_msg, 123)
1524
1525     def test_code_option_fast(self) -> None:
1526         """Test that the code option ignores errors when the sanity checks fail."""
1527         # Patch black.assert_equivalent to ensure the sanity checks fail
1528         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1529             code = 'print("Hello world")'
1530             formatted = black.format_str(code, mode=DEFAULT_MODE)
1531
1532             args = ["--fast", "--code", code]
1533             result = CliRunner().invoke(black.main, args)
1534
1535             self.compare_results(result, formatted, 0)
1536
1537     @pytest.mark.incompatible_with_mypyc
1538     def test_code_option_config(self) -> None:
1539         """
1540         Test that the code option finds the pyproject.toml in the current directory.
1541         """
1542         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1543             args = ["--code", "print"]
1544             # This is the only directory known to contain a pyproject.toml
1545             with change_directory(PROJECT_ROOT):
1546                 CliRunner().invoke(black.main, args)
1547                 pyproject_path = Path(Path.cwd(), "pyproject.toml").resolve()
1548
1549             assert (
1550                 len(parse.mock_calls) >= 1
1551             ), "Expected config parse to be called with the current directory."
1552
1553             _, call_args, _ = parse.mock_calls[0]
1554             assert (
1555                 call_args[0].lower() == str(pyproject_path).lower()
1556             ), "Incorrect config loaded."
1557
1558     @pytest.mark.incompatible_with_mypyc
1559     def test_code_option_parent_config(self) -> None:
1560         """
1561         Test that the code option finds the pyproject.toml in the parent directory.
1562         """
1563         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1564             with change_directory(THIS_DIR):
1565                 args = ["--code", "print"]
1566                 CliRunner().invoke(black.main, args)
1567
1568                 pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
1569                 assert (
1570                     len(parse.mock_calls) >= 1
1571                 ), "Expected config parse to be called with the current directory."
1572
1573                 _, call_args, _ = parse.mock_calls[0]
1574                 assert (
1575                     call_args[0].lower() == str(pyproject_path).lower()
1576                 ), "Incorrect config loaded."
1577
1578     def test_for_handled_unexpected_eof_error(self) -> None:
1579         """
1580         Test that an unexpected EOF SyntaxError is nicely presented.
1581         """
1582         with pytest.raises(black.parsing.InvalidInput) as exc_info:
1583             black.lib2to3_parse("print(", {})
1584
1585         exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")
1586
1587     def test_equivalency_ast_parse_failure_includes_error(self) -> None:
1588         with pytest.raises(AssertionError) as err:
1589             black.assert_equivalent("a«»a  = 1", "a«»a  = 1")
1590
1591         err.match("--safe")
1592         # Unfortunately the SyntaxError message has changed in newer versions so we
1593         # can't match it directly.
1594         err.match("invalid character")
1595         err.match(r"\(<unknown>, line 1\)")
1596
1597
1598 class TestCaching:
1599     def test_cache_broken_file(self) -> None:
1600         mode = DEFAULT_MODE
1601         with cache_dir() as workspace:
1602             cache_file = get_cache_file(mode)
1603             cache_file.write_text("this is not a pickle")
1604             assert black.read_cache(mode) == {}
1605             src = (workspace / "test.py").resolve()
1606             src.write_text("print('hello')")
1607             invokeBlack([str(src)])
1608             cache = black.read_cache(mode)
1609             assert str(src) in cache
1610
1611     def test_cache_single_file_already_cached(self) -> None:
1612         mode = DEFAULT_MODE
1613         with cache_dir() as workspace:
1614             src = (workspace / "test.py").resolve()
1615             src.write_text("print('hello')")
1616             black.write_cache({}, [src], mode)
1617             invokeBlack([str(src)])
1618             assert src.read_text() == "print('hello')"
1619
1620     @event_loop()
1621     def test_cache_multiple_files(self) -> None:
1622         mode = DEFAULT_MODE
1623         with cache_dir() as workspace, patch(
1624             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1625         ):
1626             one = (workspace / "one.py").resolve()
1627             with one.open("w") as fobj:
1628                 fobj.write("print('hello')")
1629             two = (workspace / "two.py").resolve()
1630             with two.open("w") as fobj:
1631                 fobj.write("print('hello')")
1632             black.write_cache({}, [one], mode)
1633             invokeBlack([str(workspace)])
1634             with one.open("r") as fobj:
1635                 assert fobj.read() == "print('hello')"
1636             with two.open("r") as fobj:
1637                 assert fobj.read() == 'print("hello")\n'
1638             cache = black.read_cache(mode)
1639             assert str(one) in cache
1640             assert str(two) in cache
1641
1642     @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
1643     def test_no_cache_when_writeback_diff(self, color: bool) -> None:
1644         mode = DEFAULT_MODE
1645         with cache_dir() as workspace:
1646             src = (workspace / "test.py").resolve()
1647             with src.open("w") as fobj:
1648                 fobj.write("print('hello')")
1649             with patch("black.read_cache") as read_cache, patch(
1650                 "black.write_cache"
1651             ) as write_cache:
1652                 cmd = [str(src), "--diff"]
1653                 if color:
1654                     cmd.append("--color")
1655                 invokeBlack(cmd)
1656                 cache_file = get_cache_file(mode)
1657                 assert cache_file.exists() is False
1658                 write_cache.assert_not_called()
1659                 read_cache.assert_not_called()
1660
1661     @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
1662     @event_loop()
1663     def test_output_locking_when_writeback_diff(self, color: bool) -> None:
1664         with cache_dir() as workspace:
1665             for tag in range(0, 4):
1666                 src = (workspace / f"test{tag}.py").resolve()
1667                 with src.open("w") as fobj:
1668                     fobj.write("print('hello')")
1669             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1670                 cmd = ["--diff", str(workspace)]
1671                 if color:
1672                     cmd.append("--color")
1673                 invokeBlack(cmd, exit_code=0)
1674                 # this isn't quite doing what we want, but if it _isn't_
1675                 # called then we cannot be using the lock it provides
1676                 mgr.assert_called()
1677
1678     def test_no_cache_when_stdin(self) -> None:
1679         mode = DEFAULT_MODE
1680         with cache_dir():
1681             result = CliRunner().invoke(
1682                 black.main, ["-"], input=BytesIO(b"print('hello')")
1683             )
1684             assert not result.exit_code
1685             cache_file = get_cache_file(mode)
1686             assert not cache_file.exists()
1687
1688     def test_read_cache_no_cachefile(self) -> None:
1689         mode = DEFAULT_MODE
1690         with cache_dir():
1691             assert black.read_cache(mode) == {}
1692
1693     def test_write_cache_read_cache(self) -> None:
1694         mode = DEFAULT_MODE
1695         with cache_dir() as workspace:
1696             src = (workspace / "test.py").resolve()
1697             src.touch()
1698             black.write_cache({}, [src], mode)
1699             cache = black.read_cache(mode)
1700             assert str(src) in cache
1701             assert cache[str(src)] == black.get_cache_info(src)
1702
1703     def test_filter_cached(self) -> None:
1704         with TemporaryDirectory() as workspace:
1705             path = Path(workspace)
1706             uncached = (path / "uncached").resolve()
1707             cached = (path / "cached").resolve()
1708             cached_but_changed = (path / "changed").resolve()
1709             uncached.touch()
1710             cached.touch()
1711             cached_but_changed.touch()
1712             cache = {
1713                 str(cached): black.get_cache_info(cached),
1714                 str(cached_but_changed): (0.0, 0),
1715             }
1716             todo, done = black.filter_cached(
1717                 cache, {uncached, cached, cached_but_changed}
1718             )
1719             assert todo == {uncached, cached_but_changed}
1720             assert done == {cached}
1721
1722     def test_write_cache_creates_directory_if_needed(self) -> None:
1723         mode = DEFAULT_MODE
1724         with cache_dir(exists=False) as workspace:
1725             assert not workspace.exists()
1726             black.write_cache({}, [], mode)
1727             assert workspace.exists()
1728
1729     @event_loop()
1730     def test_failed_formatting_does_not_get_cached(self) -> None:
1731         mode = DEFAULT_MODE
1732         with cache_dir() as workspace, patch(
1733             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1734         ):
1735             failing = (workspace / "failing.py").resolve()
1736             with failing.open("w") as fobj:
1737                 fobj.write("not actually python")
1738             clean = (workspace / "clean.py").resolve()
1739             with clean.open("w") as fobj:
1740                 fobj.write('print("hello")\n')
1741             invokeBlack([str(workspace)], exit_code=123)
1742             cache = black.read_cache(mode)
1743             assert str(failing) not in cache
1744             assert str(clean) in cache
1745
1746     def test_write_cache_write_fail(self) -> None:
1747         mode = DEFAULT_MODE
1748         with cache_dir(), patch.object(Path, "open") as mock:
1749             mock.side_effect = OSError
1750             black.write_cache({}, [], mode)
1751
1752     def test_read_cache_line_lengths(self) -> None:
1753         mode = DEFAULT_MODE
1754         short_mode = replace(DEFAULT_MODE, line_length=1)
1755         with cache_dir() as workspace:
1756             path = (workspace / "file.py").resolve()
1757             path.touch()
1758             black.write_cache({}, [path], mode)
1759             one = black.read_cache(mode)
1760             assert str(path) in one
1761             two = black.read_cache(short_mode)
1762             assert str(path) not in two
1763
1764
1765 def assert_collected_sources(
1766     src: Sequence[Union[str, Path]],
1767     expected: Sequence[Union[str, Path]],
1768     *,
1769     exclude: Optional[str] = None,
1770     include: Optional[str] = None,
1771     extend_exclude: Optional[str] = None,
1772     force_exclude: Optional[str] = None,
1773     stdin_filename: Optional[str] = None,
1774 ) -> None:
1775     gs_src = tuple(str(Path(s)) for s in src)
1776     gs_expected = [Path(s) for s in expected]
1777     gs_exclude = None if exclude is None else compile_pattern(exclude)
1778     gs_include = DEFAULT_INCLUDE if include is None else compile_pattern(include)
1779     gs_extend_exclude = (
1780         None if extend_exclude is None else compile_pattern(extend_exclude)
1781     )
1782     gs_force_exclude = None if force_exclude is None else compile_pattern(force_exclude)
1783     collected = black.get_sources(
1784         ctx=FakeContext(),
1785         src=gs_src,
1786         quiet=False,
1787         verbose=False,
1788         include=gs_include,
1789         exclude=gs_exclude,
1790         extend_exclude=gs_extend_exclude,
1791         force_exclude=gs_force_exclude,
1792         report=black.Report(),
1793         stdin_filename=stdin_filename,
1794     )
1795     assert sorted(collected) == sorted(gs_expected)
1796
1797
1798 class TestFileCollection:
1799     def test_include_exclude(self) -> None:
1800         path = THIS_DIR / "data" / "include_exclude_tests"
1801         src = [path]
1802         expected = [
1803             Path(path / "b/dont_exclude/a.py"),
1804             Path(path / "b/dont_exclude/a.pyi"),
1805         ]
1806         assert_collected_sources(
1807             src,
1808             expected,
1809             include=r"\.pyi?$",
1810             exclude=r"/exclude/|/\.definitely_exclude/",
1811         )
1812
1813     def test_gitignore_used_as_default(self) -> None:
1814         base = Path(DATA_DIR / "include_exclude_tests")
1815         expected = [
1816             base / "b/.definitely_exclude/a.py",
1817             base / "b/.definitely_exclude/a.pyi",
1818         ]
1819         src = [base / "b/"]
1820         assert_collected_sources(src, expected, extend_exclude=r"/exclude/")
1821
1822     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1823     def test_exclude_for_issue_1572(self) -> None:
1824         # Exclude shouldn't touch files that were explicitly given to Black through the
1825         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
1826         # https://github.com/psf/black/issues/1572
1827         path = DATA_DIR / "include_exclude_tests"
1828         src = [path / "b/exclude/a.py"]
1829         expected = [path / "b/exclude/a.py"]
1830         assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")
1831
1832     def test_gitignore_exclude(self) -> None:
1833         path = THIS_DIR / "data" / "include_exclude_tests"
1834         include = re.compile(r"\.pyi?$")
1835         exclude = re.compile(r"")
1836         report = black.Report()
1837         gitignore = PathSpec.from_lines(
1838             "gitwildmatch", ["exclude/", ".definitely_exclude"]
1839         )
1840         sources: List[Path] = []
1841         expected = [
1842             Path(path / "b/dont_exclude/a.py"),
1843             Path(path / "b/dont_exclude/a.pyi"),
1844         ]
1845         this_abs = THIS_DIR.resolve()
1846         sources.extend(
1847             black.gen_python_files(
1848                 path.iterdir(),
1849                 this_abs,
1850                 include,
1851                 exclude,
1852                 None,
1853                 None,
1854                 report,
1855                 gitignore,
1856                 verbose=False,
1857                 quiet=False,
1858             )
1859         )
1860         assert sorted(expected) == sorted(sources)
1861
1862     def test_nested_gitignore(self) -> None:
1863         path = Path(THIS_DIR / "data" / "nested_gitignore_tests")
1864         include = re.compile(r"\.pyi?$")
1865         exclude = re.compile(r"")
1866         root_gitignore = black.files.get_gitignore(path)
1867         report = black.Report()
1868         expected: List[Path] = [
1869             Path(path / "x.py"),
1870             Path(path / "root/b.py"),
1871             Path(path / "root/c.py"),
1872             Path(path / "root/child/c.py"),
1873         ]
1874         this_abs = THIS_DIR.resolve()
1875         sources = list(
1876             black.gen_python_files(
1877                 path.iterdir(),
1878                 this_abs,
1879                 include,
1880                 exclude,
1881                 None,
1882                 None,
1883                 report,
1884                 root_gitignore,
1885                 verbose=False,
1886                 quiet=False,
1887             )
1888         )
1889         assert sorted(expected) == sorted(sources)
1890
1891     def test_invalid_gitignore(self) -> None:
1892         path = THIS_DIR / "data" / "invalid_gitignore_tests"
1893         empty_config = path / "pyproject.toml"
1894         result = BlackRunner().invoke(
1895             black.main, ["--verbose", "--config", str(empty_config), str(path)]
1896         )
1897         assert result.exit_code == 1
1898         assert result.stderr_bytes is not None
1899
1900         gitignore = path / ".gitignore"
1901         assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()
1902
1903     def test_invalid_nested_gitignore(self) -> None:
1904         path = THIS_DIR / "data" / "invalid_nested_gitignore_tests"
1905         empty_config = path / "pyproject.toml"
1906         result = BlackRunner().invoke(
1907             black.main, ["--verbose", "--config", str(empty_config), str(path)]
1908         )
1909         assert result.exit_code == 1
1910         assert result.stderr_bytes is not None
1911
1912         gitignore = path / "a" / ".gitignore"
1913         assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()
1914
1915     def test_empty_include(self) -> None:
1916         path = DATA_DIR / "include_exclude_tests"
1917         src = [path]
1918         expected = [
1919             Path(path / "b/exclude/a.pie"),
1920             Path(path / "b/exclude/a.py"),
1921             Path(path / "b/exclude/a.pyi"),
1922             Path(path / "b/dont_exclude/a.pie"),
1923             Path(path / "b/dont_exclude/a.py"),
1924             Path(path / "b/dont_exclude/a.pyi"),
1925             Path(path / "b/.definitely_exclude/a.pie"),
1926             Path(path / "b/.definitely_exclude/a.py"),
1927             Path(path / "b/.definitely_exclude/a.pyi"),
1928             Path(path / ".gitignore"),
1929             Path(path / "pyproject.toml"),
1930         ]
1931         # Setting exclude explicitly to an empty string to block .gitignore usage.
1932         assert_collected_sources(src, expected, include="", exclude="")
1933
1934     def test_extend_exclude(self) -> None:
1935         path = DATA_DIR / "include_exclude_tests"
1936         src = [path]
1937         expected = [
1938             Path(path / "b/exclude/a.py"),
1939             Path(path / "b/dont_exclude/a.py"),
1940         ]
1941         assert_collected_sources(
1942             src, expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude"
1943         )
1944
1945     @pytest.mark.incompatible_with_mypyc
1946     def test_symlink_out_of_root_directory(self) -> None:
1947         path = MagicMock()
1948         root = THIS_DIR.resolve()
1949         child = MagicMock()
1950         include = re.compile(black.DEFAULT_INCLUDES)
1951         exclude = re.compile(black.DEFAULT_EXCLUDES)
1952         report = black.Report()
1953         gitignore = PathSpec.from_lines("gitwildmatch", [])
1954         # `child` should behave like a symlink which resolved path is clearly
1955         # outside of the `root` directory.
1956         path.iterdir.return_value = [child]
1957         child.resolve.return_value = Path("/a/b/c")
1958         child.as_posix.return_value = "/a/b/c"
1959         child.is_symlink.return_value = True
1960         try:
1961             list(
1962                 black.gen_python_files(
1963                     path.iterdir(),
1964                     root,
1965                     include,
1966                     exclude,
1967                     None,
1968                     None,
1969                     report,
1970                     gitignore,
1971                     verbose=False,
1972                     quiet=False,
1973                 )
1974             )
1975         except ValueError as ve:
1976             pytest.fail(f"`get_python_files_in_dir()` failed: {ve}")
1977         path.iterdir.assert_called_once()
1978         child.resolve.assert_called_once()
1979         child.is_symlink.assert_called_once()
1980         # `child` should behave like a strange file which resolved path is clearly
1981         # outside of the `root` directory.
1982         child.is_symlink.return_value = False
1983         with pytest.raises(ValueError):
1984             list(
1985                 black.gen_python_files(
1986                     path.iterdir(),
1987                     root,
1988                     include,
1989                     exclude,
1990                     None,
1991                     None,
1992                     report,
1993                     gitignore,
1994                     verbose=False,
1995                     quiet=False,
1996                 )
1997             )
1998         path.iterdir.assert_called()
1999         assert path.iterdir.call_count == 2
2000         child.resolve.assert_called()
2001         assert child.resolve.call_count == 2
2002         child.is_symlink.assert_called()
2003         assert child.is_symlink.call_count == 2
2004
2005     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
2006     def test_get_sources_with_stdin(self) -> None:
2007         src = ["-"]
2008         expected = ["-"]
2009         assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")
2010
2011     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
2012     def test_get_sources_with_stdin_filename(self) -> None:
2013         src = ["-"]
2014         stdin_filename = str(THIS_DIR / "data/collections.py")
2015         expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
2016         assert_collected_sources(
2017             src,
2018             expected,
2019             exclude=r"/exclude/a\.py",
2020             stdin_filename=stdin_filename,
2021         )
2022
2023     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
2024     def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
2025         # Exclude shouldn't exclude stdin_filename since it is mimicking the
2026         # file being passed directly. This is the same as
2027         # test_exclude_for_issue_1572
2028         path = DATA_DIR / "include_exclude_tests"
2029         src = ["-"]
2030         stdin_filename = str(path / "b/exclude/a.py")
2031         expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
2032         assert_collected_sources(
2033             src,
2034             expected,
2035             exclude=r"/exclude/|a\.py",
2036             stdin_filename=stdin_filename,
2037         )
2038
2039     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
2040     def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
2041         # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
2042         # file being passed directly. This is the same as
2043         # test_exclude_for_issue_1572
2044         src = ["-"]
2045         path = THIS_DIR / "data" / "include_exclude_tests"
2046         stdin_filename = str(path / "b/exclude/a.py")
2047         expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
2048         assert_collected_sources(
2049             src,
2050             expected,
2051             extend_exclude=r"/exclude/|a\.py",
2052             stdin_filename=stdin_filename,
2053         )
2054
2055     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
2056     def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
2057         # Force exclude should exclude the file when passing it through
2058         # stdin_filename
2059         path = THIS_DIR / "data" / "include_exclude_tests"
2060         stdin_filename = str(path / "b/exclude/a.py")
2061         assert_collected_sources(
2062             src=["-"],
2063             expected=[],
2064             force_exclude=r"/exclude/|a\.py",
2065             stdin_filename=stdin_filename,
2066         )
2067
2068
2069 @pytest.mark.python2
2070 @pytest.mark.parametrize("explicit", [True, False], ids=["explicit", "autodetection"])
2071 def test_python_2_deprecation_with_target_version(explicit: bool) -> None:
2072     args = [
2073         "--config",
2074         str(THIS_DIR / "empty.toml"),
2075         str(DATA_DIR / "python2.py"),
2076         "--check",
2077     ]
2078     if explicit:
2079         args.append("--target-version=py27")
2080     with cache_dir():
2081         result = BlackRunner().invoke(black.main, args)
2082     assert "DEPRECATION: Python 2 support will be removed" in result.stderr
2083
2084
2085 @pytest.mark.python2
2086 def test_python_2_deprecation_autodetection_extended() -> None:
2087     # this test has a similar construction to test_get_features_used_decorator
2088     python2, non_python2 = read_data("python2_detection")
2089     for python2_case in python2.split("###"):
2090         node = black.lib2to3_parse(python2_case)
2091         assert black.detect_target_versions(node) == {TargetVersion.PY27}, python2_case
2092     for non_python2_case in non_python2.split("###"):
2093         node = black.lib2to3_parse(non_python2_case)
2094         assert black.detect_target_versions(node) != {
2095             TargetVersion.PY27
2096         }, non_python2_case
2097
2098
2099 try:
2100     with open(black.__file__, "r", encoding="utf-8") as _bf:
2101         black_source_lines = _bf.readlines()
2102 except UnicodeDecodeError:
2103     if not black.COMPILED:
2104         raise
2105
2106
2107 def tracefunc(
2108     frame: types.FrameType, event: str, arg: Any
2109 ) -> Callable[[types.FrameType, str, Any], Any]:
2110     """Show function calls `from black/__init__.py` as they happen.
2111
2112     Register this with `sys.settrace()` in a test you're debugging.
2113     """
2114     if event != "call":
2115         return tracefunc
2116
2117     stack = len(inspect.stack()) - 19
2118     stack *= 2
2119     filename = frame.f_code.co_filename
2120     lineno = frame.f_lineno
2121     func_sig_lineno = lineno - 1
2122     funcname = black_source_lines[func_sig_lineno].strip()
2123     while funcname.startswith("@"):
2124         func_sig_lineno += 1
2125         funcname = black_source_lines[func_sig_lineno].strip()
2126     if "black/__init__.py" in filename:
2127         print(f"{' ' * stack}{lineno}:{funcname}")
2128     return tracefunc