git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you have read the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Unpacking on flow constructs (return/yield) now implies 3.8+ (#2700)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import inspect
5 import io
6 import logging
7 import multiprocessing
8 import os
9 import sys
10 import types
11 import unittest
12 from concurrent.futures import ThreadPoolExecutor
13 from contextlib import contextmanager
14 from dataclasses import replace
15 from io import BytesIO
16 from pathlib import Path
17 from platform import system
18 from tempfile import TemporaryDirectory
19 from typing import (
20     Any,
21     Callable,
22     Dict,
23     Iterator,
24     List,
25     Optional,
26     Sequence,
27     TypeVar,
28     Union,
29 )
30 from unittest.mock import MagicMock, patch
31
32 import click
33 import pytest
34 import re
35 from click import unstyle
36 from click.testing import CliRunner
37 from pathspec import PathSpec
38
39 import black
40 import black.files
41 from black import Feature, TargetVersion
42 from black import re_compile_maybe_verbose as compile_pattern
43 from black.cache import get_cache_file
44 from black.debug import DebugVisitor
45 from black.output import color_diff, diff
46 from black.report import Report
47
48 # Import other test classes
49 from tests.util import (
50     DATA_DIR,
51     DEFAULT_MODE,
52     DETERMINISTIC_HEADER,
53     PROJECT_ROOT,
54     PY36_VERSIONS,
55     THIS_DIR,
56     BlackBaseTestCase,
57     assert_format,
58     change_directory,
59     dump_to_stderr,
60     ff,
61     fs,
62     read_data,
63 )
64
THIS_FILE = Path(__file__)
# One ``--target-version=<name>`` CLI flag per Python 3.6+ target version.
PY36_ARGS = [
    "--target-version=" + target.name.lower() for target in PY36_VERSIONS
]
# Black's default include/exclude path regexes, compiled via the same helper
# that ``black`` exposes as ``re_compile_maybe_verbose``.
DEFAULT_EXCLUDE = compile_pattern(black.const.DEFAULT_EXCLUDES)
DEFAULT_INCLUDE = compile_pattern(black.const.DEFAULT_INCLUDES)
T = TypeVar("T")
R = TypeVar("R")

# Match the time output in a diff, but nothing else
DIFF_TIME = re.compile(r"\t[\d\-:+\. ]+")
74
75
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch ``black.cache.CACHE_DIR`` to a throwaway directory for the ``with`` body.

    With ``exists=True`` the yielded path is the temporary directory itself.
    With ``exists=False`` it is a ``new`` subdirectory of it, which has not
    been created. The directory is removed when the context exits.
    """
    with TemporaryDirectory() as workspace:
        target = Path(workspace) if exists else Path(workspace) / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
84
85
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a fresh asyncio event loop as the current loop for the ``with`` body.

    The loop is created from the current event loop policy and is closed when
    the context exits. Afterwards the current loop is reset to ``None``, so
    later code cannot pick up the closed loop through ``get_event_loop()``.
    """
    policy = asyncio.get_event_loop_policy()
    loop = policy.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        yield
    finally:
        loop.close()
        # Don't leave a closed loop registered as the current one.
        asyncio.set_event_loop(None)
96
97
class FakeContext(click.Context):
    """Stand-in ``click.Context`` for calling functions that require one.

    ``__init__`` deliberately does not call ``click.Context.__init__``; the
    only instance attribute it sets is an empty ``default_map``.
    """

    def __init__(self) -> None:
        self.default_map: Dict[str, Any] = dict()
103
104
class FakeParameter(click.Parameter):
    """Stand-in ``click.Parameter`` for calling functions that require one."""

    def __init__(self) -> None:
        """Deliberately skip ``click.Parameter.__init__``; no state is set up."""
110
111
class BlackRunner(CliRunner):
    """``CliRunner`` that captures Black's stdout and stderr separately."""

    def __init__(self) -> None:
        # mix_stderr=False keeps the two streams apart, so tests can inspect
        # ``stdout_bytes`` and ``stderr_bytes`` independently.
        super(BlackRunner, self).__init__(mix_stderr=False)
117
118
def invokeBlack(
    args: List[str], exit_code: int = 0, ignore_config: bool = True
) -> None:
    """Run Black's CLI in-process and assert that it exits with ``exit_code``.

    When ``ignore_config`` is true, ``--verbose`` and ``--config`` pointing at
    ``THIS_DIR / "empty.toml"`` are prepended to ``args``. Exceptions are not
    caught by the runner. On an exit-code mismatch, the assertion message
    includes the arguments, both output streams and any recorded exception.
    """
    if ignore_config:
        empty_config = str(THIS_DIR / "empty.toml")
        args = ["--verbose", "--config", empty_config] + list(args)
    outcome = BlackRunner().invoke(black.main, args, catch_exceptions=False)
    out_bytes = outcome.stdout_bytes
    err_bytes = outcome.stderr_bytes
    assert out_bytes is not None
    assert err_bytes is not None
    details = "\n".join(
        [
            f"Failed with args: {args}",
            f"stdout: {out_bytes.decode()!r}",
            f"stderr: {err_bytes.decode()!r}",
            f"exception: {outcome.exception}",
        ]
    )
    assert outcome.exit_code == exit_code, details
135
136
137 class BlackTestCase(BlackBaseTestCase):
138     invokeBlack = staticmethod(invokeBlack)
139
140     def test_empty_ff(self) -> None:
141         expected = ""
142         tmp_file = Path(black.dump_to_file())
143         try:
144             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
145             with open(tmp_file, encoding="utf8") as f:
146                 actual = f.read()
147         finally:
148             os.unlink(tmp_file)
149         self.assertFormatEqual(expected, actual)
150
151     def test_piping(self) -> None:
152         source, expected = read_data("src/black/__init__", data=False)
153         result = BlackRunner().invoke(
154             black.main,
155             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
156             input=BytesIO(source.encode("utf8")),
157         )
158         self.assertEqual(result.exit_code, 0)
159         self.assertFormatEqual(expected, result.output)
160         if source != result.output:
161             black.assert_equivalent(source, result.output)
162             black.assert_stable(source, result.output, DEFAULT_MODE)
163
164     def test_piping_diff(self) -> None:
165         diff_header = re.compile(
166             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
167             r"\+\d\d\d\d"
168         )
169         source, _ = read_data("expression.py")
170         expected, _ = read_data("expression.diff")
171         config = THIS_DIR / "data" / "empty_pyproject.toml"
172         args = [
173             "-",
174             "--fast",
175             f"--line-length={black.DEFAULT_LINE_LENGTH}",
176             "--diff",
177             f"--config={config}",
178         ]
179         result = BlackRunner().invoke(
180             black.main, args, input=BytesIO(source.encode("utf8"))
181         )
182         self.assertEqual(result.exit_code, 0)
183         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
184         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
185         self.assertEqual(expected, actual)
186
187     def test_piping_diff_with_color(self) -> None:
188         source, _ = read_data("expression.py")
189         config = THIS_DIR / "data" / "empty_pyproject.toml"
190         args = [
191             "-",
192             "--fast",
193             f"--line-length={black.DEFAULT_LINE_LENGTH}",
194             "--diff",
195             "--color",
196             f"--config={config}",
197         ]
198         result = BlackRunner().invoke(
199             black.main, args, input=BytesIO(source.encode("utf8"))
200         )
201         actual = result.output
202         # Again, the contents are checked in a different test, so only look for colors.
203         self.assertIn("\033[1m", actual)
204         self.assertIn("\033[36m", actual)
205         self.assertIn("\033[32m", actual)
206         self.assertIn("\033[31m", actual)
207         self.assertIn("\033[0m", actual)
208
209     @patch("black.dump_to_file", dump_to_stderr)
210     def _test_wip(self) -> None:
211         source, expected = read_data("wip")
212         sys.settrace(tracefunc)
213         mode = replace(
214             DEFAULT_MODE,
215             experimental_string_processing=False,
216             target_versions={black.TargetVersion.PY38},
217         )
218         actual = fs(source, mode=mode)
219         sys.settrace(None)
220         self.assertFormatEqual(expected, actual)
221         black.assert_equivalent(source, actual)
222         black.assert_stable(source, actual, black.FileMode())
223
224     @unittest.expectedFailure
225     @patch("black.dump_to_file", dump_to_stderr)
226     def test_trailing_comma_optional_parens_stability1(self) -> None:
227         source, _expected = read_data("trailing_comma_optional_parens1")
228         actual = fs(source)
229         black.assert_stable(source, actual, DEFAULT_MODE)
230
231     @unittest.expectedFailure
232     @patch("black.dump_to_file", dump_to_stderr)
233     def test_trailing_comma_optional_parens_stability2(self) -> None:
234         source, _expected = read_data("trailing_comma_optional_parens2")
235         actual = fs(source)
236         black.assert_stable(source, actual, DEFAULT_MODE)
237
238     @unittest.expectedFailure
239     @patch("black.dump_to_file", dump_to_stderr)
240     def test_trailing_comma_optional_parens_stability3(self) -> None:
241         source, _expected = read_data("trailing_comma_optional_parens3")
242         actual = fs(source)
243         black.assert_stable(source, actual, DEFAULT_MODE)
244
245     @patch("black.dump_to_file", dump_to_stderr)
246     def test_trailing_comma_optional_parens_stability1_pass2(self) -> None:
247         source, _expected = read_data("trailing_comma_optional_parens1")
248         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
249         black.assert_stable(source, actual, DEFAULT_MODE)
250
251     @patch("black.dump_to_file", dump_to_stderr)
252     def test_trailing_comma_optional_parens_stability2_pass2(self) -> None:
253         source, _expected = read_data("trailing_comma_optional_parens2")
254         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
255         black.assert_stable(source, actual, DEFAULT_MODE)
256
257     @patch("black.dump_to_file", dump_to_stderr)
258     def test_trailing_comma_optional_parens_stability3_pass2(self) -> None:
259         source, _expected = read_data("trailing_comma_optional_parens3")
260         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
261         black.assert_stable(source, actual, DEFAULT_MODE)
262
263     def test_pep_572_version_detection(self) -> None:
264         source, _ = read_data("pep_572")
265         root = black.lib2to3_parse(source)
266         features = black.get_features_used(root)
267         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
268         versions = black.detect_target_versions(root)
269         self.assertIn(black.TargetVersion.PY38, versions)
270
271     def test_expression_ff(self) -> None:
272         source, expected = read_data("expression")
273         tmp_file = Path(black.dump_to_file(source))
274         try:
275             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
276             with open(tmp_file, encoding="utf8") as f:
277                 actual = f.read()
278         finally:
279             os.unlink(tmp_file)
280         self.assertFormatEqual(expected, actual)
281         with patch("black.dump_to_file", dump_to_stderr):
282             black.assert_equivalent(source, actual)
283             black.assert_stable(source, actual, DEFAULT_MODE)
284
285     def test_expression_diff(self) -> None:
286         source, _ = read_data("expression.py")
287         config = THIS_DIR / "data" / "empty_pyproject.toml"
288         expected, _ = read_data("expression.diff")
289         tmp_file = Path(black.dump_to_file(source))
290         diff_header = re.compile(
291             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
292             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
293         )
294         try:
295             result = BlackRunner().invoke(
296                 black.main, ["--diff", str(tmp_file), f"--config={config}"]
297             )
298             self.assertEqual(result.exit_code, 0)
299         finally:
300             os.unlink(tmp_file)
301         actual = result.output
302         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
303         if expected != actual:
304             dump = black.dump_to_file(actual)
305             msg = (
306                 "Expected diff isn't equal to the actual. If you made changes to"
307                 " expression.py and this is an anticipated difference, overwrite"
308                 f" tests/data/expression.diff with {dump}"
309             )
310             self.assertEqual(expected, actual, msg)
311
312     def test_expression_diff_with_color(self) -> None:
313         source, _ = read_data("expression.py")
314         config = THIS_DIR / "data" / "empty_pyproject.toml"
315         expected, _ = read_data("expression.diff")
316         tmp_file = Path(black.dump_to_file(source))
317         try:
318             result = BlackRunner().invoke(
319                 black.main, ["--diff", "--color", str(tmp_file), f"--config={config}"]
320             )
321         finally:
322             os.unlink(tmp_file)
323         actual = result.output
324         # We check the contents of the diff in `test_expression_diff`. All
325         # we need to check here is that color codes exist in the result.
326         self.assertIn("\033[1m", actual)
327         self.assertIn("\033[36m", actual)
328         self.assertIn("\033[32m", actual)
329         self.assertIn("\033[31m", actual)
330         self.assertIn("\033[0m", actual)
331
332     def test_detect_pos_only_arguments(self) -> None:
333         source, _ = read_data("pep_570")
334         root = black.lib2to3_parse(source)
335         features = black.get_features_used(root)
336         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
337         versions = black.detect_target_versions(root)
338         self.assertIn(black.TargetVersion.PY38, versions)
339
340     @patch("black.dump_to_file", dump_to_stderr)
341     def test_string_quotes(self) -> None:
342         source, expected = read_data("string_quotes")
343         mode = black.Mode(experimental_string_processing=True)
344         assert_format(source, expected, mode)
345         mode = replace(mode, string_normalization=False)
346         not_normalized = fs(source, mode=mode)
347         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
348         black.assert_equivalent(source, not_normalized)
349         black.assert_stable(source, not_normalized, mode=mode)
350
351     def test_skip_magic_trailing_comma(self) -> None:
352         source, _ = read_data("expression.py")
353         expected, _ = read_data("expression_skip_magic_trailing_comma.diff")
354         tmp_file = Path(black.dump_to_file(source))
355         diff_header = re.compile(
356             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
357             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
358         )
359         try:
360             result = BlackRunner().invoke(black.main, ["-C", "--diff", str(tmp_file)])
361             self.assertEqual(result.exit_code, 0)
362         finally:
363             os.unlink(tmp_file)
364         actual = result.output
365         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
366         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
367         if expected != actual:
368             dump = black.dump_to_file(actual)
369             msg = (
370                 "Expected diff isn't equal to the actual. If you made changes to"
371                 " expression.py and this is an anticipated difference, overwrite"
372                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
373             )
374             self.assertEqual(expected, actual, msg)
375
376     @patch("black.dump_to_file", dump_to_stderr)
377     def test_async_as_identifier(self) -> None:
378         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
379         source, expected = read_data("async_as_identifier")
380         actual = fs(source)
381         self.assertFormatEqual(expected, actual)
382         major, minor = sys.version_info[:2]
383         if major < 3 or (major <= 3 and minor < 7):
384             black.assert_equivalent(source, actual)
385         black.assert_stable(source, actual, DEFAULT_MODE)
386         # ensure black can parse this when the target is 3.6
387         self.invokeBlack([str(source_path), "--target-version", "py36"])
388         # but not on 3.7, because async/await is no longer an identifier
389         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
390
391     @patch("black.dump_to_file", dump_to_stderr)
392     def test_python37(self) -> None:
393         source_path = (THIS_DIR / "data" / "python37.py").resolve()
394         source, expected = read_data("python37")
395         actual = fs(source)
396         self.assertFormatEqual(expected, actual)
397         major, minor = sys.version_info[:2]
398         if major > 3 or (major == 3 and minor >= 7):
399             black.assert_equivalent(source, actual)
400         black.assert_stable(source, actual, DEFAULT_MODE)
401         # ensure black can parse this when the target is 3.7
402         self.invokeBlack([str(source_path), "--target-version", "py37"])
403         # but not on 3.6, because we use async as a reserved keyword
404         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
405
406     def test_tab_comment_indentation(self) -> None:
407         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
408         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
409         self.assertFormatEqual(contents_spc, fs(contents_spc))
410         self.assertFormatEqual(contents_spc, fs(contents_tab))
411
412         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
413         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
414         self.assertFormatEqual(contents_spc, fs(contents_spc))
415         self.assertFormatEqual(contents_spc, fs(contents_tab))
416
417         # mixed tabs and spaces (valid Python 2 code)
418         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
419         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
420         self.assertFormatEqual(contents_spc, fs(contents_spc))
421         self.assertFormatEqual(contents_spc, fs(contents_tab))
422
423         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
424         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
425         self.assertFormatEqual(contents_spc, fs(contents_spc))
426         self.assertFormatEqual(contents_spc, fs(contents_tab))
427
    def test_report_verbose(self) -> None:
        """Walk a verbose ``Report`` through a fixed sequence of outcomes.

        After each call, the test checks:
        - which messages went to stdout (``out``) vs stderr (``err``);
        - the unstyled summary string;
        - ``return_code``.

        In verbose mode a stdout line is emitted for every file, including
        unchanged, cached and ignored ones. The assertions depend on the
        exact order of the calls below.
        """
        report = Report(verbose=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        # Capture Report output by replacing black.output's _out/_err writers.
        with patch("black.output._out", out), patch("black.output._err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With ``check`` set, a report containing reformatted files returns 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures go to stderr and switch the return code to 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are logged in verbose mode but do not affect the counts.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both ``check`` and ``diff`` switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
529
    def test_report_quiet(self) -> None:
        """Walk a quiet ``Report`` through the same outcome sequence as the verbose test.

        In quiet mode nothing is written to stdout; only failures produce
        stderr lines. Summary strings and return codes match the other modes.
        The assertions depend on the exact order of the calls below.
        """
        report = Report(quiet=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        # Capture Report output by replacing black.output's _out/_err writers.
        with patch("black.output._out", out), patch("black.output._err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With ``check`` set, a report containing reformatted files returns 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures are still reported on stderr in quiet mode.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both ``check`` and ``diff`` switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
623
    def test_report_normal(self) -> None:
        """Walk a default ``Report`` through the same outcome sequence as the verbose test.

        In the default mode, only reformatted files produce stdout lines.
        Unchanged, cached and ignored files are silent, and failures go to
        stderr. The assertions depend on the exact order of the calls below.
        """
        report = black.Report()
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        # Capture Report output by replacing black.output's _out/_err writers.
        with patch("black.output._out", out), patch("black.output._err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # Cached files are counted as unchanged but not printed.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With ``check`` set, a report containing reformatted files returns 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures go to stderr and switch the return code to 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both ``check`` and ``diff`` switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
720
721     def test_lib2to3_parse(self) -> None:
722         with self.assertRaises(black.InvalidInput):
723             black.lib2to3_parse("invalid syntax")
724
725         straddling = "x + y"
726         black.lib2to3_parse(straddling)
727         black.lib2to3_parse(straddling, {TargetVersion.PY27})
728         black.lib2to3_parse(straddling, {TargetVersion.PY36})
729         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
730
731         py2_only = "print x"
732         black.lib2to3_parse(py2_only)
733         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
734         with self.assertRaises(black.InvalidInput):
735             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
736         with self.assertRaises(black.InvalidInput):
737             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
738
739         py3_only = "exec(x, end=y)"
740         black.lib2to3_parse(py3_only)
741         with self.assertRaises(black.InvalidInput):
742             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
743         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
744         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
745
746     def test_get_features_used_decorator(self) -> None:
747         # Test the feature detection of new decorator syntax
748         # since this makes some test cases of test_get_features_used()
749         # fails if it fails, this is tested first so that a useful case
750         # is identified
751         simples, relaxed = read_data("decorators")
752         # skip explanation comments at the top of the file
753         for simple_test in simples.split("##")[1:]:
754             node = black.lib2to3_parse(simple_test)
755             decorator = str(node.children[0].children[0]).strip()
756             self.assertNotIn(
757                 Feature.RELAXED_DECORATORS,
758                 black.get_features_used(node),
759                 msg=(
760                     f"decorator '{decorator}' follows python<=3.8 syntax"
761                     "but is detected as 3.9+"
762                     # f"The full node is\n{node!r}"
763                 ),
764             )
765         # skip the '# output' comment at the top of the output part
766         for relaxed_test in relaxed.split("##")[1:]:
767             node = black.lib2to3_parse(relaxed_test)
768             decorator = str(node.children[0].children[0]).strip()
769             self.assertIn(
770                 Feature.RELAXED_DECORATORS,
771                 black.get_features_used(node),
772                 msg=(
773                     f"decorator '{decorator}' uses python3.9+ syntax"
774                     "but is detected as python<=3.8"
775                     # f"The full node is\n{node!r}"
776                 ),
777             )
778
779     def test_get_features_used(self) -> None:
780         node = black.lib2to3_parse("def f(*, arg): ...\n")
781         self.assertEqual(black.get_features_used(node), set())
782         node = black.lib2to3_parse("def f(*, arg,): ...\n")
783         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
784         node = black.lib2to3_parse("f(*arg,)\n")
785         self.assertEqual(
786             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
787         )
788         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
789         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
790         node = black.lib2to3_parse("123_456\n")
791         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
792         node = black.lib2to3_parse("123456\n")
793         self.assertEqual(black.get_features_used(node), set())
794         source, expected = read_data("function")
795         node = black.lib2to3_parse(source)
796         expected_features = {
797             Feature.TRAILING_COMMA_IN_CALL,
798             Feature.TRAILING_COMMA_IN_DEF,
799             Feature.F_STRINGS,
800         }
801         self.assertEqual(black.get_features_used(node), expected_features)
802         node = black.lib2to3_parse(expected)
803         self.assertEqual(black.get_features_used(node), expected_features)
804         source, expected = read_data("expression")
805         node = black.lib2to3_parse(source)
806         self.assertEqual(black.get_features_used(node), set())
807         node = black.lib2to3_parse(expected)
808         self.assertEqual(black.get_features_used(node), set())
809         node = black.lib2to3_parse("lambda a, /, b: ...")
810         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
811         node = black.lib2to3_parse("def fn(a, /, b): ...")
812         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
813         node = black.lib2to3_parse("def fn(): yield a, b")
814         self.assertEqual(black.get_features_used(node), set())
815         node = black.lib2to3_parse("def fn(): return a, b")
816         self.assertEqual(black.get_features_used(node), set())
817         node = black.lib2to3_parse("def fn(): yield *b, c")
818         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
819         node = black.lib2to3_parse("def fn(): return a, *b, c")
820         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
821
822     def test_get_features_used_for_future_flags(self) -> None:
823         for src, features in [
824             ("from __future__ import annotations", {Feature.FUTURE_ANNOTATIONS}),
825             (
826                 "from __future__ import (other, annotations)",
827                 {Feature.FUTURE_ANNOTATIONS},
828             ),
829             ("a = 1 + 2\nfrom something import annotations", set()),
830             ("from __future__ import x, y", set()),
831         ]:
832             with self.subTest(src=src, features=features):
833                 node = black.lib2to3_parse(src)
834                 future_imports = black.get_future_imports(node)
835                 self.assertEqual(
836                     black.get_features_used(node, future_imports=future_imports),
837                     features,
838                 )
839
840     def test_get_future_imports(self) -> None:
841         node = black.lib2to3_parse("\n")
842         self.assertEqual(set(), black.get_future_imports(node))
843         node = black.lib2to3_parse("from __future__ import black\n")
844         self.assertEqual({"black"}, black.get_future_imports(node))
845         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
846         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
847         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
848         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
849         node = black.lib2to3_parse(
850             "from __future__ import multiple\nfrom __future__ import imports\n"
851         )
852         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
853         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
854         self.assertEqual({"black"}, black.get_future_imports(node))
855         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
856         self.assertEqual({"black"}, black.get_future_imports(node))
857         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
858         self.assertEqual(set(), black.get_future_imports(node))
859         node = black.lib2to3_parse("from some.module import black\n")
860         self.assertEqual(set(), black.get_future_imports(node))
861         node = black.lib2to3_parse(
862             "from __future__ import unicode_literals as _unicode_literals"
863         )
864         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
865         node = black.lib2to3_parse(
866             "from __future__ import unicode_literals as _lol, print"
867         )
868         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
869
870     @pytest.mark.incompatible_with_mypyc
871     def test_debug_visitor(self) -> None:
872         source, _ = read_data("debug_visitor.py")
873         expected, _ = read_data("debug_visitor.out")
874         out_lines = []
875         err_lines = []
876
877         def out(msg: str, **kwargs: Any) -> None:
878             out_lines.append(msg)
879
880         def err(msg: str, **kwargs: Any) -> None:
881             err_lines.append(msg)
882
883         with patch("black.debug.out", out):
884             DebugVisitor.show(source)
885         actual = "\n".join(out_lines) + "\n"
886         log_name = ""
887         if expected != actual:
888             log_name = black.dump_to_file(*out_lines)
889         self.assertEqual(
890             expected,
891             actual,
892             f"AST print out is different. Actual version dumped to {log_name}",
893         )
894
895     def test_format_file_contents(self) -> None:
896         empty = ""
897         mode = DEFAULT_MODE
898         with self.assertRaises(black.NothingChanged):
899             black.format_file_contents(empty, mode=mode, fast=False)
900         just_nl = "\n"
901         with self.assertRaises(black.NothingChanged):
902             black.format_file_contents(just_nl, mode=mode, fast=False)
903         same = "j = [1, 2, 3]\n"
904         with self.assertRaises(black.NothingChanged):
905             black.format_file_contents(same, mode=mode, fast=False)
906         different = "j = [1,2,3]"
907         expected = same
908         actual = black.format_file_contents(different, mode=mode, fast=False)
909         self.assertEqual(expected, actual)
910         invalid = "return if you can"
911         with self.assertRaises(black.InvalidInput) as e:
912             black.format_file_contents(invalid, mode=mode, fast=False)
913         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
914
915     def test_endmarker(self) -> None:
916         n = black.lib2to3_parse("\n")
917         self.assertEqual(n.type, black.syms.file_input)
918         self.assertEqual(len(n.children), 1)
919         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
920
921     @pytest.mark.incompatible_with_mypyc
922     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
923     def test_assertFormatEqual(self) -> None:
924         out_lines = []
925         err_lines = []
926
927         def out(msg: str, **kwargs: Any) -> None:
928             out_lines.append(msg)
929
930         def err(msg: str, **kwargs: Any) -> None:
931             err_lines.append(msg)
932
933         with patch("black.output._out", out), patch("black.output._err", err):
934             with self.assertRaises(AssertionError):
935                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
936
937         out_str = "".join(out_lines)
938         self.assertTrue("Expected tree:" in out_str)
939         self.assertTrue("Actual tree:" in out_str)
940         self.assertEqual("".join(err_lines), "")
941
942     @event_loop()
943     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
944     def test_works_in_mono_process_only_environment(self) -> None:
945         with cache_dir() as workspace:
946             for f in [
947                 (workspace / "one.py").resolve(),
948                 (workspace / "two.py").resolve(),
949             ]:
950                 f.write_text('print("hello")\n')
951             self.invokeBlack([str(workspace)])
952
953     @event_loop()
954     def test_check_diff_use_together(self) -> None:
955         with cache_dir():
956             # Files which will be reformatted.
957             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
958             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
959             # Files which will not be reformatted.
960             src2 = (THIS_DIR / "data" / "composition.py").resolve()
961             self.invokeBlack([str(src2), "--diff", "--check"])
962             # Multi file command.
963             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
964
965     def test_no_files(self) -> None:
966         with cache_dir():
967             # Without an argument, black exits with error code 0.
968             self.invokeBlack([])
969
970     def test_broken_symlink(self) -> None:
971         with cache_dir() as workspace:
972             symlink = workspace / "broken_link.py"
973             try:
974                 symlink.symlink_to("nonexistent.py")
975             except (OSError, NotImplementedError) as e:
976                 self.skipTest(f"Can't create symlinks: {e}")
977             self.invokeBlack([str(workspace.resolve())])
978
979     def test_single_file_force_pyi(self) -> None:
980         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
981         contents, expected = read_data("force_pyi")
982         with cache_dir() as workspace:
983             path = (workspace / "file.py").resolve()
984             with open(path, "w") as fh:
985                 fh.write(contents)
986             self.invokeBlack([str(path), "--pyi"])
987             with open(path, "r") as fh:
988                 actual = fh.read()
989             # verify cache with --pyi is separate
990             pyi_cache = black.read_cache(pyi_mode)
991             self.assertIn(str(path), pyi_cache)
992             normal_cache = black.read_cache(DEFAULT_MODE)
993             self.assertNotIn(str(path), normal_cache)
994         self.assertFormatEqual(expected, actual)
995         black.assert_equivalent(contents, actual)
996         black.assert_stable(contents, actual, pyi_mode)
997
998     @event_loop()
999     def test_multi_file_force_pyi(self) -> None:
1000         reg_mode = DEFAULT_MODE
1001         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1002         contents, expected = read_data("force_pyi")
1003         with cache_dir() as workspace:
1004             paths = [
1005                 (workspace / "file1.py").resolve(),
1006                 (workspace / "file2.py").resolve(),
1007             ]
1008             for path in paths:
1009                 with open(path, "w") as fh:
1010                     fh.write(contents)
1011             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1012             for path in paths:
1013                 with open(path, "r") as fh:
1014                     actual = fh.read()
1015                 self.assertEqual(actual, expected)
1016             # verify cache with --pyi is separate
1017             pyi_cache = black.read_cache(pyi_mode)
1018             normal_cache = black.read_cache(reg_mode)
1019             for path in paths:
1020                 self.assertIn(str(path), pyi_cache)
1021                 self.assertNotIn(str(path), normal_cache)
1022
1023     def test_pipe_force_pyi(self) -> None:
1024         source, expected = read_data("force_pyi")
1025         result = CliRunner().invoke(
1026             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1027         )
1028         self.assertEqual(result.exit_code, 0)
1029         actual = result.output
1030         self.assertFormatEqual(actual, expected)
1031
1032     def test_single_file_force_py36(self) -> None:
1033         reg_mode = DEFAULT_MODE
1034         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1035         source, expected = read_data("force_py36")
1036         with cache_dir() as workspace:
1037             path = (workspace / "file.py").resolve()
1038             with open(path, "w") as fh:
1039                 fh.write(source)
1040             self.invokeBlack([str(path), *PY36_ARGS])
1041             with open(path, "r") as fh:
1042                 actual = fh.read()
1043             # verify cache with --target-version is separate
1044             py36_cache = black.read_cache(py36_mode)
1045             self.assertIn(str(path), py36_cache)
1046             normal_cache = black.read_cache(reg_mode)
1047             self.assertNotIn(str(path), normal_cache)
1048         self.assertEqual(actual, expected)
1049
1050     @event_loop()
1051     def test_multi_file_force_py36(self) -> None:
1052         reg_mode = DEFAULT_MODE
1053         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1054         source, expected = read_data("force_py36")
1055         with cache_dir() as workspace:
1056             paths = [
1057                 (workspace / "file1.py").resolve(),
1058                 (workspace / "file2.py").resolve(),
1059             ]
1060             for path in paths:
1061                 with open(path, "w") as fh:
1062                     fh.write(source)
1063             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1064             for path in paths:
1065                 with open(path, "r") as fh:
1066                     actual = fh.read()
1067                 self.assertEqual(actual, expected)
1068             # verify cache with --target-version is separate
1069             pyi_cache = black.read_cache(py36_mode)
1070             normal_cache = black.read_cache(reg_mode)
1071             for path in paths:
1072                 self.assertIn(str(path), pyi_cache)
1073                 self.assertNotIn(str(path), normal_cache)
1074
1075     def test_pipe_force_py36(self) -> None:
1076         source, expected = read_data("force_py36")
1077         result = CliRunner().invoke(
1078             black.main,
1079             ["-", "-q", "--target-version=py36"],
1080             input=BytesIO(source.encode("utf8")),
1081         )
1082         self.assertEqual(result.exit_code, 0)
1083         actual = result.output
1084         self.assertFormatEqual(actual, expected)
1085
1086     @pytest.mark.incompatible_with_mypyc
1087     def test_reformat_one_with_stdin(self) -> None:
1088         with patch(
1089             "black.format_stdin_to_stdout",
1090             return_value=lambda *args, **kwargs: black.Changed.YES,
1091         ) as fsts:
1092             report = MagicMock()
1093             path = Path("-")
1094             black.reformat_one(
1095                 path,
1096                 fast=True,
1097                 write_back=black.WriteBack.YES,
1098                 mode=DEFAULT_MODE,
1099                 report=report,
1100             )
1101             fsts.assert_called_once()
1102             report.done.assert_called_with(path, black.Changed.YES)
1103
1104     @pytest.mark.incompatible_with_mypyc
1105     def test_reformat_one_with_stdin_filename(self) -> None:
1106         with patch(
1107             "black.format_stdin_to_stdout",
1108             return_value=lambda *args, **kwargs: black.Changed.YES,
1109         ) as fsts:
1110             report = MagicMock()
1111             p = "foo.py"
1112             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1113             expected = Path(p)
1114             black.reformat_one(
1115                 path,
1116                 fast=True,
1117                 write_back=black.WriteBack.YES,
1118                 mode=DEFAULT_MODE,
1119                 report=report,
1120             )
1121             fsts.assert_called_once_with(
1122                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1123             )
1124             # __BLACK_STDIN_FILENAME__ should have been stripped
1125             report.done.assert_called_with(expected, black.Changed.YES)
1126
1127     @pytest.mark.incompatible_with_mypyc
1128     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1129         with patch(
1130             "black.format_stdin_to_stdout",
1131             return_value=lambda *args, **kwargs: black.Changed.YES,
1132         ) as fsts:
1133             report = MagicMock()
1134             p = "foo.pyi"
1135             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1136             expected = Path(p)
1137             black.reformat_one(
1138                 path,
1139                 fast=True,
1140                 write_back=black.WriteBack.YES,
1141                 mode=DEFAULT_MODE,
1142                 report=report,
1143             )
1144             fsts.assert_called_once_with(
1145                 fast=True,
1146                 write_back=black.WriteBack.YES,
1147                 mode=replace(DEFAULT_MODE, is_pyi=True),
1148             )
1149             # __BLACK_STDIN_FILENAME__ should have been stripped
1150             report.done.assert_called_with(expected, black.Changed.YES)
1151
1152     @pytest.mark.incompatible_with_mypyc
1153     def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
1154         with patch(
1155             "black.format_stdin_to_stdout",
1156             return_value=lambda *args, **kwargs: black.Changed.YES,
1157         ) as fsts:
1158             report = MagicMock()
1159             p = "foo.ipynb"
1160             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1161             expected = Path(p)
1162             black.reformat_one(
1163                 path,
1164                 fast=True,
1165                 write_back=black.WriteBack.YES,
1166                 mode=DEFAULT_MODE,
1167                 report=report,
1168             )
1169             fsts.assert_called_once_with(
1170                 fast=True,
1171                 write_back=black.WriteBack.YES,
1172                 mode=replace(DEFAULT_MODE, is_ipynb=True),
1173             )
1174             # __BLACK_STDIN_FILENAME__ should have been stripped
1175             report.done.assert_called_with(expected, black.Changed.YES)
1176
1177     @pytest.mark.incompatible_with_mypyc
1178     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1179         with patch(
1180             "black.format_stdin_to_stdout",
1181             return_value=lambda *args, **kwargs: black.Changed.YES,
1182         ) as fsts:
1183             report = MagicMock()
1184             # Even with an existing file, since we are forcing stdin, black
1185             # should output to stdout and not modify the file inplace
1186             p = Path(str(THIS_DIR / "data/collections.py"))
1187             # Make sure is_file actually returns True
1188             self.assertTrue(p.is_file())
1189             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1190             expected = Path(p)
1191             black.reformat_one(
1192                 path,
1193                 fast=True,
1194                 write_back=black.WriteBack.YES,
1195                 mode=DEFAULT_MODE,
1196                 report=report,
1197             )
1198             fsts.assert_called_once()
1199             # __BLACK_STDIN_FILENAME__ should have been stripped
1200             report.done.assert_called_with(expected, black.Changed.YES)
1201
1202     def test_reformat_one_with_stdin_empty(self) -> None:
1203         output = io.StringIO()
1204         with patch("io.TextIOWrapper", lambda *args, **kwargs: output):
1205             try:
1206                 black.format_stdin_to_stdout(
1207                     fast=True,
1208                     content="",
1209                     write_back=black.WriteBack.YES,
1210                     mode=DEFAULT_MODE,
1211                 )
1212             except io.UnsupportedOperation:
1213                 pass  # StringIO does not support detach
1214             assert output.getvalue() == ""
1215
1216     def test_invalid_cli_regex(self) -> None:
1217         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1218             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1219
1220     def test_required_version_matches_version(self) -> None:
1221         self.invokeBlack(
1222             ["--required-version", black.__version__], exit_code=0, ignore_config=True
1223         )
1224
1225     def test_required_version_does_not_match_version(self) -> None:
1226         self.invokeBlack(
1227             ["--required-version", "20.99b"], exit_code=1, ignore_config=True
1228         )
1229
1230     def test_preserves_line_endings(self) -> None:
1231         with TemporaryDirectory() as workspace:
1232             test_file = Path(workspace) / "test.py"
1233             for nl in ["\n", "\r\n"]:
1234                 contents = nl.join(["def f(  ):", "    pass"])
1235                 test_file.write_bytes(contents.encode())
1236                 ff(test_file, write_back=black.WriteBack.YES)
1237                 updated_contents: bytes = test_file.read_bytes()
1238                 self.assertIn(nl.encode(), updated_contents)
1239                 if nl == "\n":
1240                     self.assertNotIn(b"\r\n", updated_contents)
1241
1242     def test_preserves_line_endings_via_stdin(self) -> None:
1243         for nl in ["\n", "\r\n"]:
1244             contents = nl.join(["def f(  ):", "    pass"])
1245             runner = BlackRunner()
1246             result = runner.invoke(
1247                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1248             )
1249             self.assertEqual(result.exit_code, 0)
1250             output = result.stdout_bytes
1251             self.assertIn(nl.encode("utf8"), output)
1252             if nl == "\n":
1253                 self.assertNotIn(b"\r\n", output)
1254
1255     def test_assert_equivalent_different_asts(self) -> None:
1256         with self.assertRaises(AssertionError):
1257             black.assert_equivalent("{}", "None")
1258
1259     def test_shhh_click(self) -> None:
1260         try:
1261             from click import _unicodefun
1262         except ModuleNotFoundError:
1263             self.skipTest("Incompatible Click version")
1264         if not hasattr(_unicodefun, "_verify_python3_env"):
1265             self.skipTest("Incompatible Click version")
1266         # First, let's see if Click is crashing with a preferred ASCII charset.
1267         with patch("locale.getpreferredencoding") as gpe:
1268             gpe.return_value = "ASCII"
1269             with self.assertRaises(RuntimeError):
1270                 _unicodefun._verify_python3_env()  # type: ignore
1271         # Now, let's silence Click...
1272         black.patch_click()
1273         # ...and confirm it's silent.
1274         with patch("locale.getpreferredencoding") as gpe:
1275             gpe.return_value = "ASCII"
1276             try:
1277                 _unicodefun._verify_python3_env()  # type: ignore
1278             except RuntimeError as re:
1279                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1280
1281     def test_root_logger_not_used_directly(self) -> None:
1282         def fail(*args: Any, **kwargs: Any) -> None:
1283             self.fail("Record created with root logger")
1284
1285         with patch.multiple(
1286             logging.root,
1287             debug=fail,
1288             info=fail,
1289             warning=fail,
1290             error=fail,
1291             critical=fail,
1292             log=fail,
1293         ):
1294             ff(THIS_DIR / "util.py")
1295
1296     def test_invalid_config_return_code(self) -> None:
1297         tmp_file = Path(black.dump_to_file())
1298         try:
1299             tmp_config = Path(black.dump_to_file())
1300             tmp_config.unlink()
1301             args = ["--config", str(tmp_config), str(tmp_file)]
1302             self.invokeBlack(args, exit_code=2, ignore_config=False)
1303         finally:
1304             tmp_file.unlink()
1305
1306     def test_parse_pyproject_toml(self) -> None:
1307         test_toml_file = THIS_DIR / "test.toml"
1308         config = black.parse_pyproject_toml(str(test_toml_file))
1309         self.assertEqual(config["verbose"], 1)
1310         self.assertEqual(config["check"], "no")
1311         self.assertEqual(config["diff"], "y")
1312         self.assertEqual(config["color"], True)
1313         self.assertEqual(config["line_length"], 79)
1314         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1315         self.assertEqual(config["exclude"], r"\.pyi?$")
1316         self.assertEqual(config["include"], r"\.py?$")
1317
1318     def test_read_pyproject_toml(self) -> None:
1319         test_toml_file = THIS_DIR / "test.toml"
1320         fake_ctx = FakeContext()
1321         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1322         config = fake_ctx.default_map
1323         self.assertEqual(config["verbose"], "1")
1324         self.assertEqual(config["check"], "no")
1325         self.assertEqual(config["diff"], "y")
1326         self.assertEqual(config["color"], "True")
1327         self.assertEqual(config["line_length"], "79")
1328         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1329         self.assertEqual(config["exclude"], r"\.pyi?$")
1330         self.assertEqual(config["include"], r"\.py?$")
1331
1332     @pytest.mark.incompatible_with_mypyc
1333     def test_find_project_root(self) -> None:
1334         with TemporaryDirectory() as workspace:
1335             root = Path(workspace)
1336             test_dir = root / "test"
1337             test_dir.mkdir()
1338
1339             src_dir = root / "src"
1340             src_dir.mkdir()
1341
1342             root_pyproject = root / "pyproject.toml"
1343             root_pyproject.touch()
1344             src_pyproject = src_dir / "pyproject.toml"
1345             src_pyproject.touch()
1346             src_python = src_dir / "foo.py"
1347             src_python.touch()
1348
1349             self.assertEqual(
1350                 black.find_project_root((src_dir, test_dir)), root.resolve()
1351             )
1352             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
1353             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
1354
1355     @patch(
1356         "black.files.find_user_pyproject_toml",
1357         black.files.find_user_pyproject_toml.__wrapped__,
1358     )
1359     def test_find_user_pyproject_toml_linux(self) -> None:
1360         if system() == "Windows":
1361             return
1362
1363         # Test if XDG_CONFIG_HOME is checked
1364         with TemporaryDirectory() as workspace:
1365             tmp_user_config = Path(workspace) / "black"
1366             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1367                 self.assertEqual(
1368                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
1369                 )
1370
1371         # Test fallback for XDG_CONFIG_HOME
1372         with patch.dict("os.environ"):
1373             os.environ.pop("XDG_CONFIG_HOME", None)
1374             fallback_user_config = Path("~/.config").expanduser() / "black"
1375             self.assertEqual(
1376                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
1377             )
1378
1379     def test_find_user_pyproject_toml_windows(self) -> None:
1380         if system() != "Windows":
1381             return
1382
1383         user_config_path = Path.home() / ".black"
1384         self.assertEqual(
1385             black.files.find_user_pyproject_toml(), user_config_path.resolve()
1386         )
1387
1388     def test_bpo_33660_workaround(self) -> None:
1389         if system() == "Windows":
1390             return
1391
1392         # https://bugs.python.org/issue33660
1393         root = Path("/")
1394         with change_directory(root):
1395             path = Path("workspace") / "project"
1396             report = black.Report(verbose=True)
1397             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
1398             self.assertEqual(normalized_path, "workspace/project")
1399
1400     def test_newline_comment_interaction(self) -> None:
1401         source = "class A:\\\r\n# type: ignore\n pass\n"
1402         output = black.format_str(source, mode=DEFAULT_MODE)
1403         black.assert_stable(source, output, mode=DEFAULT_MODE)
1404
1405     def test_bpo_2142_workaround(self) -> None:
1406
1407         # https://bugs.python.org/issue2142
1408
1409         source, _ = read_data("missing_final_newline.py")
1410         # read_data adds a trailing newline
1411         source = source.rstrip()
1412         expected, _ = read_data("missing_final_newline.diff")
1413         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
1414         diff_header = re.compile(
1415             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
1416             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
1417         )
1418         try:
1419             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
1420             self.assertEqual(result.exit_code, 0)
1421         finally:
1422             os.unlink(tmp_file)
1423         actual = result.output
1424         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
1425         self.assertEqual(actual, expected)
1426
1427     @pytest.mark.python2
1428     def test_docstring_reformat_for_py27(self) -> None:
1429         """
1430         Check that stripping trailing whitespace from Python 2 docstrings
1431         doesn't trigger a "not equivalent to source" error
1432         """
1433         source = (
1434             b'def foo():\r\n    """Testing\r\n    Testing """\r\n    print "Foo"\r\n'
1435         )
1436         expected = 'def foo():\n    """Testing\n    Testing"""\n    print "Foo"\n'
1437
1438         result = BlackRunner().invoke(
1439             black.main,
1440             ["-", "-q", "--target-version=py27"],
1441             input=BytesIO(source),
1442         )
1443
1444         self.assertEqual(result.exit_code, 0)
1445         actual = result.stdout
1446         self.assertFormatEqual(actual, expected)
1447
1448     @staticmethod
1449     def compare_results(
1450         result: click.testing.Result, expected_value: str, expected_exit_code: int
1451     ) -> None:
1452         """Helper method to test the value and exit code of a click Result."""
1453         assert (
1454             result.output == expected_value
1455         ), "The output did not match the expected value."
1456         assert result.exit_code == expected_exit_code, "The exit code is incorrect."
1457
1458     def test_code_option(self) -> None:
1459         """Test the code option with no changes."""
1460         code = 'print("Hello world")\n'
1461         args = ["--code", code]
1462         result = CliRunner().invoke(black.main, args)
1463
1464         self.compare_results(result, code, 0)
1465
1466     def test_code_option_changed(self) -> None:
1467         """Test the code option when changes are required."""
1468         code = "print('hello world')"
1469         formatted = black.format_str(code, mode=DEFAULT_MODE)
1470
1471         args = ["--code", code]
1472         result = CliRunner().invoke(black.main, args)
1473
1474         self.compare_results(result, formatted, 0)
1475
1476     def test_code_option_check(self) -> None:
1477         """Test the code option when check is passed."""
1478         args = ["--check", "--code", 'print("Hello world")\n']
1479         result = CliRunner().invoke(black.main, args)
1480         self.compare_results(result, "", 0)
1481
1482     def test_code_option_check_changed(self) -> None:
1483         """Test the code option when changes are required, and check is passed."""
1484         args = ["--check", "--code", "print('hello world')"]
1485         result = CliRunner().invoke(black.main, args)
1486         self.compare_results(result, "", 1)
1487
1488     def test_code_option_diff(self) -> None:
1489         """Test the code option when diff is passed."""
1490         code = "print('hello world')"
1491         formatted = black.format_str(code, mode=DEFAULT_MODE)
1492         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1493
1494         args = ["--diff", "--code", code]
1495         result = CliRunner().invoke(black.main, args)
1496
1497         # Remove time from diff
1498         output = DIFF_TIME.sub("", result.output)
1499
1500         assert output == result_diff, "The output did not match the expected value."
1501         assert result.exit_code == 0, "The exit code is incorrect."
1502
1503     def test_code_option_color_diff(self) -> None:
1504         """Test the code option when color and diff are passed."""
1505         code = "print('hello world')"
1506         formatted = black.format_str(code, mode=DEFAULT_MODE)
1507
1508         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1509         result_diff = color_diff(result_diff)
1510
1511         args = ["--diff", "--color", "--code", code]
1512         result = CliRunner().invoke(black.main, args)
1513
1514         # Remove time from diff
1515         output = DIFF_TIME.sub("", result.output)
1516
1517         assert output == result_diff, "The output did not match the expected value."
1518         assert result.exit_code == 0, "The exit code is incorrect."
1519
1520     @pytest.mark.incompatible_with_mypyc
1521     def test_code_option_safe(self) -> None:
1522         """Test that the code option throws an error when the sanity checks fail."""
1523         # Patch black.assert_equivalent to ensure the sanity checks fail
1524         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1525             code = 'print("Hello world")'
1526             error_msg = f"{code}\nerror: cannot format <string>: \n"
1527
1528             args = ["--safe", "--code", code]
1529             result = CliRunner().invoke(black.main, args)
1530
1531             self.compare_results(result, error_msg, 123)
1532
1533     def test_code_option_fast(self) -> None:
1534         """Test that the code option ignores errors when the sanity checks fail."""
1535         # Patch black.assert_equivalent to ensure the sanity checks fail
1536         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1537             code = 'print("Hello world")'
1538             formatted = black.format_str(code, mode=DEFAULT_MODE)
1539
1540             args = ["--fast", "--code", code]
1541             result = CliRunner().invoke(black.main, args)
1542
1543             self.compare_results(result, formatted, 0)
1544
1545     @pytest.mark.incompatible_with_mypyc
1546     def test_code_option_config(self) -> None:
1547         """
1548         Test that the code option finds the pyproject.toml in the current directory.
1549         """
1550         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1551             args = ["--code", "print"]
1552             # This is the only directory known to contain a pyproject.toml
1553             with change_directory(PROJECT_ROOT):
1554                 CliRunner().invoke(black.main, args)
1555                 pyproject_path = Path(Path.cwd(), "pyproject.toml").resolve()
1556
1557             assert (
1558                 len(parse.mock_calls) >= 1
1559             ), "Expected config parse to be called with the current directory."
1560
1561             _, call_args, _ = parse.mock_calls[0]
1562             assert (
1563                 call_args[0].lower() == str(pyproject_path).lower()
1564             ), "Incorrect config loaded."
1565
1566     @pytest.mark.incompatible_with_mypyc
1567     def test_code_option_parent_config(self) -> None:
1568         """
1569         Test that the code option finds the pyproject.toml in the parent directory.
1570         """
1571         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1572             with change_directory(THIS_DIR):
1573                 args = ["--code", "print"]
1574                 CliRunner().invoke(black.main, args)
1575
1576                 pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
1577                 assert (
1578                     len(parse.mock_calls) >= 1
1579                 ), "Expected config parse to be called with the current directory."
1580
1581                 _, call_args, _ = parse.mock_calls[0]
1582                 assert (
1583                     call_args[0].lower() == str(pyproject_path).lower()
1584                 ), "Incorrect config loaded."
1585
1586     def test_for_handled_unexpected_eof_error(self) -> None:
1587         """
1588         Test that an unexpected EOF SyntaxError is nicely presented.
1589         """
1590         with pytest.raises(black.parsing.InvalidInput) as exc_info:
1591             black.lib2to3_parse("print(", {})
1592
1593         exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")
1594
1595     def test_equivalency_ast_parse_failure_includes_error(self) -> None:
1596         with pytest.raises(AssertionError) as err:
1597             black.assert_equivalent("a«»a  = 1", "a«»a  = 1")
1598
1599         err.match("--safe")
1600         # Unfortunately the SyntaxError message has changed in newer versions so we
1601         # can't match it directly.
1602         err.match("invalid character")
1603         err.match(r"\(<unknown>, line 1\)")
1604
1605
class TestCaching:
    """Tests for black's per-mode formatting cache."""

    def test_cache_broken_file(self) -> None:
        """A corrupt cache reads as empty and is rewritten by the next run."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            get_cache_file(mode).write_text("this is not a pickle")
            assert black.read_cache(mode) == {}

            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            invokeBlack([str(src)])
            assert str(src) in black.read_cache(mode)

    def test_cache_single_file_already_cached(self) -> None:
        """A file already recorded in the cache is left untouched."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            black.write_cache({}, [src], mode)
            invokeBlack([str(src)])
            # Still single-quoted: the cached file was not reformatted.
            assert src.read_text() == "print('hello')"

    @event_loop()
    def test_cache_multiple_files(self) -> None:
        """Only the uncached file is reformatted; afterwards both are cached."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace, patch(
            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            one = (workspace / "one.py").resolve()
            one.write_text("print('hello')")
            two = (workspace / "two.py").resolve()
            two.write_text("print('hello')")
            black.write_cache({}, [one], mode)

            invokeBlack([str(workspace)])

            assert one.read_text() == "print('hello')"
            assert two.read_text() == 'print("hello")\n'
            cache = black.read_cache(mode)
            assert str(one) in cache
            assert str(two) in cache

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    def test_no_cache_when_writeback_diff(self, color: bool) -> None:
        """--diff (with or without --color) neither reads nor writes the cache."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            with patch("black.read_cache") as read_cache, patch(
                "black.write_cache"
            ) as write_cache:
                cmd = [str(src), "--diff"] + (["--color"] if color else [])
                invokeBlack(cmd)
                assert get_cache_file(mode).exists() is False
                write_cache.assert_not_called()
                read_cache.assert_not_called()

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    @event_loop()
    def test_output_locking_when_writeback_diff(self, color: bool) -> None:
        """Diffing several files goes through multiprocessing.Manager."""
        with cache_dir() as workspace:
            for tag in range(4):
                (workspace / f"test{tag}.py").resolve().write_text("print('hello')")
            with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
                cmd = ["--diff", str(workspace)]
                if color:
                    cmd.append("--color")
                invokeBlack(cmd, exit_code=0)
                # this isn't quite doing what we want, but if it _isn't_
                # called then we cannot be using the lock it provides
                mgr.assert_called()

    def test_no_cache_when_stdin(self) -> None:
        """Formatting from stdin does not create a cache file."""
        mode = DEFAULT_MODE
        with cache_dir():
            outcome = CliRunner().invoke(
                black.main, ["-"], input=BytesIO(b"print('hello')")
            )
            assert not outcome.exit_code
            assert not get_cache_file(mode).exists()

    def test_read_cache_no_cachefile(self) -> None:
        """A missing cache file reads as an empty cache."""
        with cache_dir():
            assert black.read_cache(DEFAULT_MODE) == {}

    def test_write_cache_read_cache(self) -> None:
        """An entry written to the cache is read back with its cache info."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.touch()
            black.write_cache({}, [src], mode)
            cache = black.read_cache(mode)
            key = str(src)
            assert key in cache
            assert cache[key] == black.get_cache_info(src)

    def test_filter_cached(self) -> None:
        """filter_cached splits sources into (todo, done) by cache freshness."""
        with TemporaryDirectory() as tmp:
            base = Path(tmp)
            uncached = (base / "uncached").resolve()
            cached = (base / "cached").resolve()
            cached_but_changed = (base / "changed").resolve()
            for entry in (uncached, cached, cached_but_changed):
                entry.touch()
            cache = {
                str(cached): black.get_cache_info(cached),
                # A stale entry that should not match the freshly touched file.
                str(cached_but_changed): (0.0, 0),
            }
            todo, done = black.filter_cached(
                cache, {uncached, cached, cached_but_changed}
            )
            assert todo == {uncached, cached_but_changed}
            assert done == {cached}

    def test_write_cache_creates_directory_if_needed(self) -> None:
        """write_cache creates the cache directory when it is missing."""
        with cache_dir(exists=False) as workspace:
            assert not workspace.exists()
            black.write_cache({}, [], DEFAULT_MODE)
            assert workspace.exists()

    @event_loop()
    def test_failed_formatting_does_not_get_cached(self) -> None:
        """A file that fails to format stays out of the cache; a clean one is in."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace, patch(
            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            failing = (workspace / "failing.py").resolve()
            failing.write_text("not actually python")
            clean = (workspace / "clean.py").resolve()
            clean.write_text('print("hello")\n')
            invokeBlack([str(workspace)], exit_code=123)
            cache = black.read_cache(mode)
            assert str(failing) not in cache
            assert str(clean) in cache

    def test_write_cache_write_fail(self) -> None:
        """write_cache must not raise while Path.open is patched to fail."""
        with cache_dir(), patch.object(Path, "open", side_effect=OSError):
            black.write_cache({}, [], DEFAULT_MODE)

    def test_read_cache_line_lengths(self) -> None:
        """Caches for different modes (here: line lengths) are kept separate."""
        mode = DEFAULT_MODE
        short_mode = replace(DEFAULT_MODE, line_length=1)
        with cache_dir() as workspace:
            path = (workspace / "file.py").resolve()
            path.touch()
            black.write_cache({}, [path], mode)
            assert str(path) in black.read_cache(mode)
            assert str(path) not in black.read_cache(short_mode)
1771
1772
def assert_collected_sources(
    src: Sequence[Union[str, Path]],
    expected: Sequence[Union[str, Path]],
    *,
    exclude: Optional[str] = None,
    include: Optional[str] = None,
    extend_exclude: Optional[str] = None,
    force_exclude: Optional[str] = None,
    stdin_filename: Optional[str] = None,
) -> None:
    """Assert that ``black.get_sources`` collects exactly ``expected`` from ``src``.

    Each regex argument is compiled when given and passed as ``None`` otherwise,
    except ``include``, which falls back to ``DEFAULT_INCLUDE``. The comparison
    ignores ordering.
    """
    src_args = tuple(str(Path(entry)) for entry in src)
    wanted = sorted(Path(entry) for entry in expected)
    include_re = DEFAULT_INCLUDE if include is None else compile_pattern(include)
    exclude_re, extend_exclude_re, force_exclude_re = (
        None if pattern is None else compile_pattern(pattern)
        for pattern in (exclude, extend_exclude, force_exclude)
    )

    collected = black.get_sources(
        ctx=FakeContext(),
        src=src_args,
        quiet=False,
        verbose=False,
        include=include_re,
        exclude=exclude_re,
        extend_exclude=extend_exclude_re,
        force_exclude=force_exclude_re,
        report=black.Report(),
        stdin_filename=stdin_filename,
    )
    assert sorted(collected) == wanted
1804
1805
class TestFileCollection:
    """Tests for source discovery: include/exclude regexes, .gitignore, stdin."""

    def test_include_exclude(self) -> None:
        """Explicit include/exclude regexes filter recursively discovered files."""
        path = THIS_DIR / "data" / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        assert_collected_sources(
            src,
            expected,
            include=r"\.pyi?$",
            exclude=r"/exclude/|/\.definitely_exclude/",
        )

    def test_gitignore_used_as_default(self) -> None:
        """Without an explicit exclude, the directory's .gitignore is honored."""
        base = DATA_DIR / "include_exclude_tests"
        expected = [
            base / "b/.definitely_exclude/a.py",
            base / "b/.definitely_exclude/a.pyi",
        ]
        src = [base / "b/"]
        assert_collected_sources(src, expected, extend_exclude=r"/exclude/")

    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
    def test_exclude_for_issue_1572(self) -> None:
        # Exclude shouldn't touch files that were explicitly given to Black through the
        # CLI. Exclude is supposed to only apply to the recursive discovery of files.
        # https://github.com/psf/black/issues/1572
        path = DATA_DIR / "include_exclude_tests"
        src = [path / "b/exclude/a.py"]
        expected = [path / "b/exclude/a.py"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    def test_gitignore_exclude(self) -> None:
        """A PathSpec passed to gen_python_files excludes matching paths."""
        path = THIS_DIR / "data" / "include_exclude_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        report = black.Report()
        gitignore = PathSpec.from_lines(
            "gitwildmatch", ["exclude/", ".definitely_exclude"]
        )
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        this_abs = THIS_DIR.resolve()
        sources: List[Path] = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                gitignore,
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_nested_gitignore(self) -> None:
        """.gitignore files in subdirectories apply on top of the root one."""
        path = THIS_DIR / "data" / "nested_gitignore_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        root_gitignore = black.files.get_gitignore(path)
        report = black.Report()
        expected: List[Path] = [
            path / "x.py",
            path / "root/b.py",
            path / "root/c.py",
            path / "root/child/c.py",
        ]
        this_abs = THIS_DIR.resolve()
        sources = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                root_gitignore,
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_invalid_gitignore(self) -> None:
        """An unparsable root .gitignore fails the run with a clear message."""
        path = THIS_DIR / "data" / "invalid_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_invalid_nested_gitignore(self) -> None:
        """An unparsable nested .gitignore fails the run with a clear message."""
        path = THIS_DIR / "data" / "invalid_nested_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / "a" / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_empty_include(self) -> None:
        """An empty include regex collects every file."""
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.pie",
            path / "b/exclude/a.py",
            path / "b/exclude/a.pyi",
            path / "b/dont_exclude/a.pie",
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
            path / "b/.definitely_exclude/a.pie",
            path / "b/.definitely_exclude/a.py",
            path / "b/.definitely_exclude/a.pyi",
            path / ".gitignore",
            path / "pyproject.toml",
        ]
        # Setting exclude explicitly to an empty string to block .gitignore usage.
        assert_collected_sources(src, expected, include="", exclude="")

    def test_extend_exclude(self) -> None:
        """extend_exclude removes files on top of what exclude already removes."""
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.py",
            path / "b/dont_exclude/a.py",
        ]
        assert_collected_sources(
            src, expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude"
        )

    @pytest.mark.incompatible_with_mypyc
    def test_symlink_out_of_root_directory(self) -> None:
        """Entries resolving outside the root are skipped only if they are symlinks."""
        path = MagicMock()
        root = THIS_DIR.resolve()
        child = MagicMock()
        include = re.compile(black.DEFAULT_INCLUDES)
        exclude = re.compile(black.DEFAULT_EXCLUDES)
        report = black.Report()
        gitignore = PathSpec.from_lines("gitwildmatch", [])
        # `child` should behave like a symlink which resolved path is clearly
        # outside of the `root` directory.
        path.iterdir.return_value = [child]
        child.resolve.return_value = Path("/a/b/c")
        child.as_posix.return_value = "/a/b/c"
        child.is_symlink.return_value = True
        try:
            list(
                black.gen_python_files(
                    path.iterdir(),
                    root,
                    include,
                    exclude,
                    None,
                    None,
                    report,
                    gitignore,
                    verbose=False,
                    quiet=False,
                )
            )
        except ValueError as ve:
            pytest.fail(f"`gen_python_files()` failed: {ve}")
        path.iterdir.assert_called_once()
        child.resolve.assert_called_once()
        child.is_symlink.assert_called_once()
        # `child` should behave like a strange file which resolved path is clearly
        # outside of the `root` directory.
        child.is_symlink.return_value = False
        with pytest.raises(ValueError):
            list(
                black.gen_python_files(
                    path.iterdir(),
                    root,
                    include,
                    exclude,
                    None,
                    None,
                    report,
                    gitignore,
                    verbose=False,
                    quiet=False,
                )
            )
        path.iterdir.assert_called()
        assert path.iterdir.call_count == 2
        child.resolve.assert_called()
        assert child.resolve.call_count == 2
        child.is_symlink.assert_called()
        assert child.is_symlink.call_count == 2

    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
    def test_get_sources_with_stdin(self) -> None:
        """A bare "-" source is collected as-is."""
        src = ["-"]
        expected = ["-"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
    def test_get_sources_with_stdin_filename(self) -> None:
        """With stdin_filename, "-" is collected under a marked stdin name."""
        src = ["-"]
        stdin_filename = str(THIS_DIR / "data/collections.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
    def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
        # Exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        path = DATA_DIR / "include_exclude_tests"
        src = ["-"]
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
    def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
        # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        src = ["-"]
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            extend_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
    def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
        # Force exclude should exclude the file when passing it through
        # stdin_filename
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        assert_collected_sources(
            src=["-"],
            expected=[],
            force_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )
2075
2076
@pytest.mark.python2
@pytest.mark.parametrize("explicit", [True, False], ids=["explicit", "autodetection"])
def test_python_2_deprecation_with_target_version(explicit: bool) -> None:
    """Checking Python 2 code prints the Python 2 deprecation notice on stderr.

    Covers both an explicit ``--target-version=py27`` and auto-detection.
    """
    args = [
        "--config",
        str(THIS_DIR / "empty.toml"),
        str(DATA_DIR / "python2.py"),
        "--check",
        *(["--target-version=py27"] if explicit else []),
    ]
    with cache_dir():
        outcome = BlackRunner().invoke(black.main, args)
    assert "DEPRECATION: Python 2 support will be removed" in outcome.stderr
2091
2092
@pytest.mark.python2
def test_python_2_deprecation_autodetection_extended() -> None:
    """Snippets in the first section are detected as py27-only; others are not.

    Both sections of the ``python2_detection`` data file hold ``###``-separated
    snippets.
    """
    # this test has a similar construction to test_get_features_used_decorator
    py2_cases, other_cases = read_data("python2_detection")
    py27_only = {TargetVersion.PY27}
    for snippet in py2_cases.split("###"):
        detected = black.detect_target_versions(black.lib2to3_parse(snippet))
        assert detected == py27_only, snippet
    for snippet in other_cases.split("###"):
        detected = black.detect_target_versions(black.lib2to3_parse(snippet))
        assert detected != py27_only, snippet
2105
2106
# Lines of black's package source, used by ``tracefunc`` below to print the
# signature line of each function entered. Decoding fails with a
# UnicodeDecodeError when black is compiled (black.__file__ then presumably
# points at a binary module); only in that case is the error tolerated, leaving
# ``black_source_lines`` undefined.
try:
    with Path(black.__file__).open("r", encoding="utf-8") as _source_file:
        black_source_lines = _source_file.readlines()
except UnicodeDecodeError:
    if not black.COMPILED:
        raise
2113
2114
def tracefunc(
    frame: types.FrameType, event: str, arg: Any
) -> Callable[[types.FrameType, str, Any], Any]:
    """Show function calls `from black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    For each "call" event whose code lives in ``black/__init__.py``, prints
    ``<lineno>:<signature line>`` indented by ``2 * (stack depth - 19)``
    spaces. The 19 is presumably an offset for the test harness's own frames.
    Events from other files are ignored. Always returns ``tracefunc`` itself
    so tracing continues.
    """
    if event != "call":
        return tracefunc

    filename = frame.f_code.co_filename
    # Filter on the file first: ``black_source_lines`` only describes
    # black/__init__.py, so indexing it with a line number from another file
    # could raise IndexError (and the name is undefined when black is
    # compiled). Backslashes are normalized so the check also works on Windows.
    if "black/__init__.py" not in filename.replace("\\", "/"):
        return tracefunc

    # Only the depth is needed; context=0 avoids reading source for every frame.
    stack = (len(inspect.stack(0)) - 19) * 2
    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1
    funcname = black_source_lines[func_sig_lineno].strip()
    # Step past decorator lines to the signature, without running off the end.
    while funcname.startswith("@") and func_sig_lineno + 1 < len(black_source_lines):
        func_sig_lineno += 1
        funcname = black_source_lines[func_sig_lineno].strip()
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc
2136     return tracefunc