git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Drop upper version bounds on dependencies (GH-2718)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import inspect
5 import io
6 import logging
7 import multiprocessing
8 import os
9 import sys
10 import types
11 import unittest
12 from concurrent.futures import ThreadPoolExecutor
13 from contextlib import contextmanager
14 from dataclasses import replace
15 from io import BytesIO
16 from pathlib import Path
17 from platform import system
18 from tempfile import TemporaryDirectory
19 from typing import (
20     Any,
21     Callable,
22     Dict,
23     Iterator,
24     List,
25     Optional,
26     Sequence,
27     TypeVar,
28     Union,
29 )
30 from unittest.mock import MagicMock, patch
31
32 import click
33 import pytest
34 import re
35 from click import unstyle
36 from click.testing import CliRunner
37 from pathspec import PathSpec
38
39 import black
40 import black.files
41 from black import Feature, TargetVersion
42 from black import re_compile_maybe_verbose as compile_pattern
43 from black.cache import get_cache_file
44 from black.debug import DebugVisitor
45 from black.output import color_diff, diff
46 from black.report import Report
47
48 # Import other test classes
49 from tests.util import (
50     DATA_DIR,
51     DEFAULT_MODE,
52     DETERMINISTIC_HEADER,
53     PROJECT_ROOT,
54     PY36_VERSIONS,
55     THIS_DIR,
56     BlackBaseTestCase,
57     assert_format,
58     change_directory,
59     dump_to_stderr,
60     ff,
61     fs,
62     read_data,
63 )
64
# Path of this test module.
THIS_FILE = Path(__file__)
# One "--target-version=<name>" CLI flag for each entry in PY36_VERSIONS.
PY36_ARGS = [f"--target-version={tv.name.lower()}" for tv in PY36_VERSIONS]
# Black's default include/exclude file patterns, compiled.
DEFAULT_INCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_INCLUDES)
DEFAULT_EXCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_EXCLUDES)
# Module-level type variables.
R = TypeVar("R")
T = TypeVar("T")

# Match the time output in a diff, but nothing else
DIFF_TIME = re.compile(r"\t[\d\-:+\. ]+")
74
75
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Temporarily point ``black.cache.CACHE_DIR`` at a throwaway location.

    With ``exists=True`` the yielded path is a fresh temporary directory.
    With ``exists=False`` it is a ``new`` subdirectory of that temporary
    directory which this helper does not create.  The temporary directory is
    removed and the patch undone when the block exits.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
84
85
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a brand-new asyncio event loop for the duration of the block.

    The loop is created by the current event-loop policy, made the current
    loop, and closed when the block exits, even if the block raises.
    """
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        fresh_loop.close()
96
97
class FakeContext(click.Context):
    """Stand-in click Context for calling functions that need one.

    ``click.Context.__init__`` is deliberately bypassed; the instance only
    carries an empty ``default_map``.
    """

    def __init__(self) -> None:
        # Intentionally no super().__init__(): expose just an empty default_map.
        empty: Dict[str, Any] = {}
        self.default_map = empty
103
104
class FakeParameter(click.Parameter):
    """Stand-in click Parameter for calling functions that need one."""

    def __init__(self) -> None:
        """Do nothing: ``click.Parameter.__init__`` is intentionally skipped."""
110
111
class BlackRunner(CliRunner):
    """CliRunner that keeps Black's STDOUT and STDERR in separate buffers."""

    def __init__(self) -> None:
        # mix_stderr=False lets tests read stdout_bytes and stderr_bytes
        # independently instead of receiving one interleaved stream.
        CliRunner.__init__(self, mix_stderr=False)
117
118
def invokeBlack(
    args: List[str], exit_code: int = 0, ignore_config: bool = True
) -> None:
    """Run Black's CLI with *args* and assert it exits with *exit_code*.

    When *ignore_config* is true, ``--verbose`` and ``--config <THIS_DIR/empty.toml>``
    are prepended to *args* (presumably so no other configuration is picked up).
    Exceptions are not caught by the runner.  On an unexpected exit code the
    assertion message lists the arguments, captured stdout/stderr and the
    recorded exception.
    """
    if ignore_config:
        empty_config = str(THIS_DIR / "empty.toml")
        args = ["--verbose", "--config", empty_config, *args]
    result = BlackRunner().invoke(black.main, args, catch_exceptions=False)
    stdout, stderr = result.stdout_bytes, result.stderr_bytes
    assert stdout is not None
    assert stderr is not None
    details = "\n".join(
        [
            f"Failed with args: {args}",
            f"stdout: {stdout.decode()!r}",
            f"stderr: {stderr.decode()!r}",
            f"exception: {result.exception}",
        ]
    )
    assert result.exit_code == exit_code, details
135
136
137 class BlackTestCase(BlackBaseTestCase):
    # Expose the module-level helper on the test case as self.invokeBlack(...);
    # staticmethod keeps the instance from being passed as the first argument.
    invokeBlack = staticmethod(invokeBlack)
139
140     def test_empty_ff(self) -> None:
141         expected = ""
142         tmp_file = Path(black.dump_to_file())
143         try:
144             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
145             with open(tmp_file, encoding="utf8") as f:
146                 actual = f.read()
147         finally:
148             os.unlink(tmp_file)
149         self.assertFormatEqual(expected, actual)
150
151     def test_piping(self) -> None:
152         source, expected = read_data("src/black/__init__", data=False)
153         result = BlackRunner().invoke(
154             black.main,
155             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
156             input=BytesIO(source.encode("utf8")),
157         )
158         self.assertEqual(result.exit_code, 0)
159         self.assertFormatEqual(expected, result.output)
160         if source != result.output:
161             black.assert_equivalent(source, result.output)
162             black.assert_stable(source, result.output, DEFAULT_MODE)
163
164     def test_piping_diff(self) -> None:
165         diff_header = re.compile(
166             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
167             r"\+\d\d\d\d"
168         )
169         source, _ = read_data("expression.py")
170         expected, _ = read_data("expression.diff")
171         config = THIS_DIR / "data" / "empty_pyproject.toml"
172         args = [
173             "-",
174             "--fast",
175             f"--line-length={black.DEFAULT_LINE_LENGTH}",
176             "--diff",
177             f"--config={config}",
178         ]
179         result = BlackRunner().invoke(
180             black.main, args, input=BytesIO(source.encode("utf8"))
181         )
182         self.assertEqual(result.exit_code, 0)
183         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
184         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
185         self.assertEqual(expected, actual)
186
187     def test_piping_diff_with_color(self) -> None:
188         source, _ = read_data("expression.py")
189         config = THIS_DIR / "data" / "empty_pyproject.toml"
190         args = [
191             "-",
192             "--fast",
193             f"--line-length={black.DEFAULT_LINE_LENGTH}",
194             "--diff",
195             "--color",
196             f"--config={config}",
197         ]
198         result = BlackRunner().invoke(
199             black.main, args, input=BytesIO(source.encode("utf8"))
200         )
201         actual = result.output
202         # Again, the contents are checked in a different test, so only look for colors.
203         self.assertIn("\033[1m", actual)
204         self.assertIn("\033[36m", actual)
205         self.assertIn("\033[32m", actual)
206         self.assertIn("\033[31m", actual)
207         self.assertIn("\033[0m", actual)
208
209     @patch("black.dump_to_file", dump_to_stderr)
210     def _test_wip(self) -> None:
211         source, expected = read_data("wip")
212         sys.settrace(tracefunc)
213         mode = replace(
214             DEFAULT_MODE,
215             experimental_string_processing=False,
216             target_versions={black.TargetVersion.PY38},
217         )
218         actual = fs(source, mode=mode)
219         sys.settrace(None)
220         self.assertFormatEqual(expected, actual)
221         black.assert_equivalent(source, actual)
222         black.assert_stable(source, actual, black.FileMode())
223
224     @unittest.expectedFailure
225     @patch("black.dump_to_file", dump_to_stderr)
226     def test_trailing_comma_optional_parens_stability1(self) -> None:
227         source, _expected = read_data("trailing_comma_optional_parens1")
228         actual = fs(source)
229         black.assert_stable(source, actual, DEFAULT_MODE)
230
231     @unittest.expectedFailure
232     @patch("black.dump_to_file", dump_to_stderr)
233     def test_trailing_comma_optional_parens_stability2(self) -> None:
234         source, _expected = read_data("trailing_comma_optional_parens2")
235         actual = fs(source)
236         black.assert_stable(source, actual, DEFAULT_MODE)
237
238     @unittest.expectedFailure
239     @patch("black.dump_to_file", dump_to_stderr)
240     def test_trailing_comma_optional_parens_stability3(self) -> None:
241         source, _expected = read_data("trailing_comma_optional_parens3")
242         actual = fs(source)
243         black.assert_stable(source, actual, DEFAULT_MODE)
244
245     @patch("black.dump_to_file", dump_to_stderr)
246     def test_trailing_comma_optional_parens_stability1_pass2(self) -> None:
247         source, _expected = read_data("trailing_comma_optional_parens1")
248         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
249         black.assert_stable(source, actual, DEFAULT_MODE)
250
251     @patch("black.dump_to_file", dump_to_stderr)
252     def test_trailing_comma_optional_parens_stability2_pass2(self) -> None:
253         source, _expected = read_data("trailing_comma_optional_parens2")
254         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
255         black.assert_stable(source, actual, DEFAULT_MODE)
256
257     @patch("black.dump_to_file", dump_to_stderr)
258     def test_trailing_comma_optional_parens_stability3_pass2(self) -> None:
259         source, _expected = read_data("trailing_comma_optional_parens3")
260         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
261         black.assert_stable(source, actual, DEFAULT_MODE)
262
263     def test_pep_572_version_detection(self) -> None:
264         source, _ = read_data("pep_572")
265         root = black.lib2to3_parse(source)
266         features = black.get_features_used(root)
267         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
268         versions = black.detect_target_versions(root)
269         self.assertIn(black.TargetVersion.PY38, versions)
270
271     def test_expression_ff(self) -> None:
272         source, expected = read_data("expression")
273         tmp_file = Path(black.dump_to_file(source))
274         try:
275             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
276             with open(tmp_file, encoding="utf8") as f:
277                 actual = f.read()
278         finally:
279             os.unlink(tmp_file)
280         self.assertFormatEqual(expected, actual)
281         with patch("black.dump_to_file", dump_to_stderr):
282             black.assert_equivalent(source, actual)
283             black.assert_stable(source, actual, DEFAULT_MODE)
284
285     def test_expression_diff(self) -> None:
286         source, _ = read_data("expression.py")
287         config = THIS_DIR / "data" / "empty_pyproject.toml"
288         expected, _ = read_data("expression.diff")
289         tmp_file = Path(black.dump_to_file(source))
290         diff_header = re.compile(
291             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
292             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
293         )
294         try:
295             result = BlackRunner().invoke(
296                 black.main, ["--diff", str(tmp_file), f"--config={config}"]
297             )
298             self.assertEqual(result.exit_code, 0)
299         finally:
300             os.unlink(tmp_file)
301         actual = result.output
302         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
303         if expected != actual:
304             dump = black.dump_to_file(actual)
305             msg = (
306                 "Expected diff isn't equal to the actual. If you made changes to"
307                 " expression.py and this is an anticipated difference, overwrite"
308                 f" tests/data/expression.diff with {dump}"
309             )
310             self.assertEqual(expected, actual, msg)
311
312     def test_expression_diff_with_color(self) -> None:
313         source, _ = read_data("expression.py")
314         config = THIS_DIR / "data" / "empty_pyproject.toml"
315         expected, _ = read_data("expression.diff")
316         tmp_file = Path(black.dump_to_file(source))
317         try:
318             result = BlackRunner().invoke(
319                 black.main, ["--diff", "--color", str(tmp_file), f"--config={config}"]
320             )
321         finally:
322             os.unlink(tmp_file)
323         actual = result.output
324         # We check the contents of the diff in `test_expression_diff`. All
325         # we need to check here is that color codes exist in the result.
326         self.assertIn("\033[1m", actual)
327         self.assertIn("\033[36m", actual)
328         self.assertIn("\033[32m", actual)
329         self.assertIn("\033[31m", actual)
330         self.assertIn("\033[0m", actual)
331
332     def test_detect_pos_only_arguments(self) -> None:
333         source, _ = read_data("pep_570")
334         root = black.lib2to3_parse(source)
335         features = black.get_features_used(root)
336         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
337         versions = black.detect_target_versions(root)
338         self.assertIn(black.TargetVersion.PY38, versions)
339
340     @patch("black.dump_to_file", dump_to_stderr)
341     def test_string_quotes(self) -> None:
342         source, expected = read_data("string_quotes")
343         mode = black.Mode(experimental_string_processing=True)
344         assert_format(source, expected, mode)
345         mode = replace(mode, string_normalization=False)
346         not_normalized = fs(source, mode=mode)
347         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
348         black.assert_equivalent(source, not_normalized)
349         black.assert_stable(source, not_normalized, mode=mode)
350
351     def test_skip_magic_trailing_comma(self) -> None:
352         source, _ = read_data("expression.py")
353         expected, _ = read_data("expression_skip_magic_trailing_comma.diff")
354         tmp_file = Path(black.dump_to_file(source))
355         diff_header = re.compile(
356             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
357             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
358         )
359         try:
360             result = BlackRunner().invoke(black.main, ["-C", "--diff", str(tmp_file)])
361             self.assertEqual(result.exit_code, 0)
362         finally:
363             os.unlink(tmp_file)
364         actual = result.output
365         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
366         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
367         if expected != actual:
368             dump = black.dump_to_file(actual)
369             msg = (
370                 "Expected diff isn't equal to the actual. If you made changes to"
371                 " expression.py and this is an anticipated difference, overwrite"
372                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
373             )
374             self.assertEqual(expected, actual, msg)
375
376     @patch("black.dump_to_file", dump_to_stderr)
377     def test_async_as_identifier(self) -> None:
378         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
379         source, expected = read_data("async_as_identifier")
380         actual = fs(source)
381         self.assertFormatEqual(expected, actual)
382         major, minor = sys.version_info[:2]
383         if major < 3 or (major <= 3 and minor < 7):
384             black.assert_equivalent(source, actual)
385         black.assert_stable(source, actual, DEFAULT_MODE)
386         # ensure black can parse this when the target is 3.6
387         self.invokeBlack([str(source_path), "--target-version", "py36"])
388         # but not on 3.7, because async/await is no longer an identifier
389         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
390
391     @patch("black.dump_to_file", dump_to_stderr)
392     def test_python37(self) -> None:
393         source_path = (THIS_DIR / "data" / "python37.py").resolve()
394         source, expected = read_data("python37")
395         actual = fs(source)
396         self.assertFormatEqual(expected, actual)
397         major, minor = sys.version_info[:2]
398         if major > 3 or (major == 3 and minor >= 7):
399             black.assert_equivalent(source, actual)
400         black.assert_stable(source, actual, DEFAULT_MODE)
401         # ensure black can parse this when the target is 3.7
402         self.invokeBlack([str(source_path), "--target-version", "py37"])
403         # but not on 3.6, because we use async as a reserved keyword
404         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
405
406     def test_tab_comment_indentation(self) -> None:
407         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
408         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
409         self.assertFormatEqual(contents_spc, fs(contents_spc))
410         self.assertFormatEqual(contents_spc, fs(contents_tab))
411
412         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
413         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
414         self.assertFormatEqual(contents_spc, fs(contents_spc))
415         self.assertFormatEqual(contents_spc, fs(contents_tab))
416
417         # mixed tabs and spaces (valid Python 2 code)
418         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
419         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
420         self.assertFormatEqual(contents_spc, fs(contents_spc))
421         self.assertFormatEqual(contents_spc, fs(contents_tab))
422
423         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
424         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
425         self.assertFormatEqual(contents_spc, fs(contents_spc))
426         self.assertFormatEqual(contents_spc, fs(contents_tab))
427
    def test_report_verbose(self) -> None:
        """Drive a verbose Report through every outcome and check what it emits.

        After each event the test checks four things: how many messages reached
        the patched ``black.output._out``/``_err`` functions, the last message,
        the ``str(report)`` summary, and ``report.return_code``.  In verbose
        mode, unchanged, cached and ignored files are also printed to ``_out``.
        """
        report = Report(verbose=True)
        # Messages received by the patched output functions, in order.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            # Replacement for black.output._out: record instead of printing.
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            # Replacement for black.output._err: record instead of printing.
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            # Unchanged file: printed in verbose mode.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # Reformatted file.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # Cached file: counted as unchanged and printed in verbose mode.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check=True, a reformatted file makes the return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failure: message goes to _err and the return code becomes 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Report counts events, not distinct paths: f3 again adds a
            # reformatted entry without removing its earlier unchanged one.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Second failure.
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored path: printed in verbose mode but absent from the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Another unchanged file.
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # check and diff each switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
529
    def test_report_quiet(self) -> None:
        """Drive a quiet Report through every outcome and check what it emits.

        In quiet mode nothing reaches the patched ``black.output._out``; only
        failures are written to ``_err``.  The ``str(report)`` summary and
        ``report.return_code`` are checked after each event.
        """
        report = Report(quiet=True)
        # Messages received by the patched output functions, in order.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            # Replacement for black.output._out: record instead of printing.
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            # Replacement for black.output._err: record instead of printing.
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            # Unchanged file: counted but not printed.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # Reformatted file: counted but not printed.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # Cached file: counted as unchanged.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check=True, a reformatted file makes the return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures are still printed (to _err) and set return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # f3 again as reformatted: adds to the reformatted count only.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Second failure.
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored path: neither printed nor counted.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Another unchanged file.
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # check and diff each switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
623
    def test_report_normal(self) -> None:
        """Drive a default Report through every outcome and check what it emits.

        With default verbosity only reformatted files are printed to the
        patched ``black.output._out`` and failures to ``_err``.  Unchanged,
        cached and ignored files produce no output.  The ``str(report)`` summary
        and ``report.return_code`` are checked after each event.
        """
        report = black.Report()
        # Messages received by the patched output functions, in order.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            # Replacement for black.output._out: record instead of printing.
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            # Replacement for black.output._err: record instead of printing.
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            # Unchanged file: counted, not printed.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # Reformatted file: printed.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # Cached file: counted as unchanged; last output is still f2's.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check=True, a reformatted file makes the return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failure: message goes to _err and the return code becomes 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # f3 again as reformatted: printed and added to the reformatted count.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Second failure.
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored path: neither printed nor counted.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Another unchanged file.
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # check and diff each switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
720
721     def test_lib2to3_parse(self) -> None:
722         with self.assertRaises(black.InvalidInput):
723             black.lib2to3_parse("invalid syntax")
724
725         straddling = "x + y"
726         black.lib2to3_parse(straddling)
727         black.lib2to3_parse(straddling, {TargetVersion.PY27})
728         black.lib2to3_parse(straddling, {TargetVersion.PY36})
729         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
730
731         py2_only = "print x"
732         black.lib2to3_parse(py2_only)
733         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
734         with self.assertRaises(black.InvalidInput):
735             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
736         with self.assertRaises(black.InvalidInput):
737             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
738
739         py3_only = "exec(x, end=y)"
740         black.lib2to3_parse(py3_only)
741         with self.assertRaises(black.InvalidInput):
742             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
743         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
744         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
745
746     def test_get_features_used_decorator(self) -> None:
747         # Test the feature detection of new decorator syntax
748         # since this makes some test cases of test_get_features_used()
749         # fails if it fails, this is tested first so that a useful case
750         # is identified
751         simples, relaxed = read_data("decorators")
752         # skip explanation comments at the top of the file
753         for simple_test in simples.split("##")[1:]:
754             node = black.lib2to3_parse(simple_test)
755             decorator = str(node.children[0].children[0]).strip()
756             self.assertNotIn(
757                 Feature.RELAXED_DECORATORS,
758                 black.get_features_used(node),
759                 msg=(
760                     f"decorator '{decorator}' follows python<=3.8 syntax"
761                     "but is detected as 3.9+"
762                     # f"The full node is\n{node!r}"
763                 ),
764             )
765         # skip the '# output' comment at the top of the output part
766         for relaxed_test in relaxed.split("##")[1:]:
767             node = black.lib2to3_parse(relaxed_test)
768             decorator = str(node.children[0].children[0]).strip()
769             self.assertIn(
770                 Feature.RELAXED_DECORATORS,
771                 black.get_features_used(node),
772                 msg=(
773                     f"decorator '{decorator}' uses python3.9+ syntax"
774                     "but is detected as python<=3.8"
775                     # f"The full node is\n{node!r}"
776                 ),
777             )
778
779     def test_get_features_used(self) -> None:
780         node = black.lib2to3_parse("def f(*, arg): ...\n")
781         self.assertEqual(black.get_features_used(node), set())
782         node = black.lib2to3_parse("def f(*, arg,): ...\n")
783         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
784         node = black.lib2to3_parse("f(*arg,)\n")
785         self.assertEqual(
786             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
787         )
788         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
789         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
790         node = black.lib2to3_parse("123_456\n")
791         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
792         node = black.lib2to3_parse("123456\n")
793         self.assertEqual(black.get_features_used(node), set())
794         source, expected = read_data("function")
795         node = black.lib2to3_parse(source)
796         expected_features = {
797             Feature.TRAILING_COMMA_IN_CALL,
798             Feature.TRAILING_COMMA_IN_DEF,
799             Feature.F_STRINGS,
800         }
801         self.assertEqual(black.get_features_used(node), expected_features)
802         node = black.lib2to3_parse(expected)
803         self.assertEqual(black.get_features_used(node), expected_features)
804         source, expected = read_data("expression")
805         node = black.lib2to3_parse(source)
806         self.assertEqual(black.get_features_used(node), set())
807         node = black.lib2to3_parse(expected)
808         self.assertEqual(black.get_features_used(node), set())
809         node = black.lib2to3_parse("lambda a, /, b: ...")
810         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
811         node = black.lib2to3_parse("def fn(a, /, b): ...")
812         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
813         node = black.lib2to3_parse("def fn(): yield a, b")
814         self.assertEqual(black.get_features_used(node), set())
815         node = black.lib2to3_parse("def fn(): return a, b")
816         self.assertEqual(black.get_features_used(node), set())
817         node = black.lib2to3_parse("def fn(): yield *b, c")
818         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
819         node = black.lib2to3_parse("def fn(): return a, *b, c")
820         self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
821         node = black.lib2to3_parse("x = a, *b, c")
822         self.assertEqual(black.get_features_used(node), set())
823         node = black.lib2to3_parse("x: Any = regular")
824         self.assertEqual(black.get_features_used(node), set())
825         node = black.lib2to3_parse("x: Any = (regular, regular)")
826         self.assertEqual(black.get_features_used(node), set())
827         node = black.lib2to3_parse("x: Any = Complex(Type(1))[something]")
828         self.assertEqual(black.get_features_used(node), set())
829         node = black.lib2to3_parse("x: Tuple[int, ...] = a, b, c")
830         self.assertEqual(
831             black.get_features_used(node), {Feature.ANN_ASSIGN_EXTENDED_RHS}
832         )
833
834     def test_get_features_used_for_future_flags(self) -> None:
835         for src, features in [
836             ("from __future__ import annotations", {Feature.FUTURE_ANNOTATIONS}),
837             (
838                 "from __future__ import (other, annotations)",
839                 {Feature.FUTURE_ANNOTATIONS},
840             ),
841             ("a = 1 + 2\nfrom something import annotations", set()),
842             ("from __future__ import x, y", set()),
843         ]:
844             with self.subTest(src=src, features=features):
845                 node = black.lib2to3_parse(src)
846                 future_imports = black.get_future_imports(node)
847                 self.assertEqual(
848                     black.get_features_used(node, future_imports=future_imports),
849                     features,
850                 )
851
852     def test_get_future_imports(self) -> None:
853         node = black.lib2to3_parse("\n")
854         self.assertEqual(set(), black.get_future_imports(node))
855         node = black.lib2to3_parse("from __future__ import black\n")
856         self.assertEqual({"black"}, black.get_future_imports(node))
857         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
858         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
859         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
860         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
861         node = black.lib2to3_parse(
862             "from __future__ import multiple\nfrom __future__ import imports\n"
863         )
864         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
865         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
866         self.assertEqual({"black"}, black.get_future_imports(node))
867         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
868         self.assertEqual({"black"}, black.get_future_imports(node))
869         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
870         self.assertEqual(set(), black.get_future_imports(node))
871         node = black.lib2to3_parse("from some.module import black\n")
872         self.assertEqual(set(), black.get_future_imports(node))
873         node = black.lib2to3_parse(
874             "from __future__ import unicode_literals as _unicode_literals"
875         )
876         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
877         node = black.lib2to3_parse(
878             "from __future__ import unicode_literals as _lol, print"
879         )
880         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
881
882     @pytest.mark.incompatible_with_mypyc
883     def test_debug_visitor(self) -> None:
884         source, _ = read_data("debug_visitor.py")
885         expected, _ = read_data("debug_visitor.out")
886         out_lines = []
887         err_lines = []
888
889         def out(msg: str, **kwargs: Any) -> None:
890             out_lines.append(msg)
891
892         def err(msg: str, **kwargs: Any) -> None:
893             err_lines.append(msg)
894
895         with patch("black.debug.out", out):
896             DebugVisitor.show(source)
897         actual = "\n".join(out_lines) + "\n"
898         log_name = ""
899         if expected != actual:
900             log_name = black.dump_to_file(*out_lines)
901         self.assertEqual(
902             expected,
903             actual,
904             f"AST print out is different. Actual version dumped to {log_name}",
905         )
906
907     def test_format_file_contents(self) -> None:
908         empty = ""
909         mode = DEFAULT_MODE
910         with self.assertRaises(black.NothingChanged):
911             black.format_file_contents(empty, mode=mode, fast=False)
912         just_nl = "\n"
913         with self.assertRaises(black.NothingChanged):
914             black.format_file_contents(just_nl, mode=mode, fast=False)
915         same = "j = [1, 2, 3]\n"
916         with self.assertRaises(black.NothingChanged):
917             black.format_file_contents(same, mode=mode, fast=False)
918         different = "j = [1,2,3]"
919         expected = same
920         actual = black.format_file_contents(different, mode=mode, fast=False)
921         self.assertEqual(expected, actual)
922         invalid = "return if you can"
923         with self.assertRaises(black.InvalidInput) as e:
924             black.format_file_contents(invalid, mode=mode, fast=False)
925         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
926
927     def test_endmarker(self) -> None:
928         n = black.lib2to3_parse("\n")
929         self.assertEqual(n.type, black.syms.file_input)
930         self.assertEqual(len(n.children), 1)
931         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
932
933     @pytest.mark.incompatible_with_mypyc
934     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
935     def test_assertFormatEqual(self) -> None:
936         out_lines = []
937         err_lines = []
938
939         def out(msg: str, **kwargs: Any) -> None:
940             out_lines.append(msg)
941
942         def err(msg: str, **kwargs: Any) -> None:
943             err_lines.append(msg)
944
945         with patch("black.output._out", out), patch("black.output._err", err):
946             with self.assertRaises(AssertionError):
947                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
948
949         out_str = "".join(out_lines)
950         self.assertTrue("Expected tree:" in out_str)
951         self.assertTrue("Actual tree:" in out_str)
952         self.assertEqual("".join(err_lines), "")
953
954     @event_loop()
955     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
956     def test_works_in_mono_process_only_environment(self) -> None:
957         with cache_dir() as workspace:
958             for f in [
959                 (workspace / "one.py").resolve(),
960                 (workspace / "two.py").resolve(),
961             ]:
962                 f.write_text('print("hello")\n')
963             self.invokeBlack([str(workspace)])
964
965     @event_loop()
966     def test_check_diff_use_together(self) -> None:
967         with cache_dir():
968             # Files which will be reformatted.
969             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
970             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
971             # Files which will not be reformatted.
972             src2 = (THIS_DIR / "data" / "composition.py").resolve()
973             self.invokeBlack([str(src2), "--diff", "--check"])
974             # Multi file command.
975             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
976
977     def test_no_files(self) -> None:
978         with cache_dir():
979             # Without an argument, black exits with error code 0.
980             self.invokeBlack([])
981
982     def test_broken_symlink(self) -> None:
983         with cache_dir() as workspace:
984             symlink = workspace / "broken_link.py"
985             try:
986                 symlink.symlink_to("nonexistent.py")
987             except (OSError, NotImplementedError) as e:
988                 self.skipTest(f"Can't create symlinks: {e}")
989             self.invokeBlack([str(workspace.resolve())])
990
991     def test_single_file_force_pyi(self) -> None:
992         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
993         contents, expected = read_data("force_pyi")
994         with cache_dir() as workspace:
995             path = (workspace / "file.py").resolve()
996             with open(path, "w") as fh:
997                 fh.write(contents)
998             self.invokeBlack([str(path), "--pyi"])
999             with open(path, "r") as fh:
1000                 actual = fh.read()
1001             # verify cache with --pyi is separate
1002             pyi_cache = black.read_cache(pyi_mode)
1003             self.assertIn(str(path), pyi_cache)
1004             normal_cache = black.read_cache(DEFAULT_MODE)
1005             self.assertNotIn(str(path), normal_cache)
1006         self.assertFormatEqual(expected, actual)
1007         black.assert_equivalent(contents, actual)
1008         black.assert_stable(contents, actual, pyi_mode)
1009
1010     @event_loop()
1011     def test_multi_file_force_pyi(self) -> None:
1012         reg_mode = DEFAULT_MODE
1013         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1014         contents, expected = read_data("force_pyi")
1015         with cache_dir() as workspace:
1016             paths = [
1017                 (workspace / "file1.py").resolve(),
1018                 (workspace / "file2.py").resolve(),
1019             ]
1020             for path in paths:
1021                 with open(path, "w") as fh:
1022                     fh.write(contents)
1023             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1024             for path in paths:
1025                 with open(path, "r") as fh:
1026                     actual = fh.read()
1027                 self.assertEqual(actual, expected)
1028             # verify cache with --pyi is separate
1029             pyi_cache = black.read_cache(pyi_mode)
1030             normal_cache = black.read_cache(reg_mode)
1031             for path in paths:
1032                 self.assertIn(str(path), pyi_cache)
1033                 self.assertNotIn(str(path), normal_cache)
1034
1035     def test_pipe_force_pyi(self) -> None:
1036         source, expected = read_data("force_pyi")
1037         result = CliRunner().invoke(
1038             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1039         )
1040         self.assertEqual(result.exit_code, 0)
1041         actual = result.output
1042         self.assertFormatEqual(actual, expected)
1043
1044     def test_single_file_force_py36(self) -> None:
1045         reg_mode = DEFAULT_MODE
1046         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1047         source, expected = read_data("force_py36")
1048         with cache_dir() as workspace:
1049             path = (workspace / "file.py").resolve()
1050             with open(path, "w") as fh:
1051                 fh.write(source)
1052             self.invokeBlack([str(path), *PY36_ARGS])
1053             with open(path, "r") as fh:
1054                 actual = fh.read()
1055             # verify cache with --target-version is separate
1056             py36_cache = black.read_cache(py36_mode)
1057             self.assertIn(str(path), py36_cache)
1058             normal_cache = black.read_cache(reg_mode)
1059             self.assertNotIn(str(path), normal_cache)
1060         self.assertEqual(actual, expected)
1061
1062     @event_loop()
1063     def test_multi_file_force_py36(self) -> None:
1064         reg_mode = DEFAULT_MODE
1065         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1066         source, expected = read_data("force_py36")
1067         with cache_dir() as workspace:
1068             paths = [
1069                 (workspace / "file1.py").resolve(),
1070                 (workspace / "file2.py").resolve(),
1071             ]
1072             for path in paths:
1073                 with open(path, "w") as fh:
1074                     fh.write(source)
1075             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1076             for path in paths:
1077                 with open(path, "r") as fh:
1078                     actual = fh.read()
1079                 self.assertEqual(actual, expected)
1080             # verify cache with --target-version is separate
1081             pyi_cache = black.read_cache(py36_mode)
1082             normal_cache = black.read_cache(reg_mode)
1083             for path in paths:
1084                 self.assertIn(str(path), pyi_cache)
1085                 self.assertNotIn(str(path), normal_cache)
1086
1087     def test_pipe_force_py36(self) -> None:
1088         source, expected = read_data("force_py36")
1089         result = CliRunner().invoke(
1090             black.main,
1091             ["-", "-q", "--target-version=py36"],
1092             input=BytesIO(source.encode("utf8")),
1093         )
1094         self.assertEqual(result.exit_code, 0)
1095         actual = result.output
1096         self.assertFormatEqual(actual, expected)
1097
1098     @pytest.mark.incompatible_with_mypyc
1099     def test_reformat_one_with_stdin(self) -> None:
1100         with patch(
1101             "black.format_stdin_to_stdout",
1102             return_value=lambda *args, **kwargs: black.Changed.YES,
1103         ) as fsts:
1104             report = MagicMock()
1105             path = Path("-")
1106             black.reformat_one(
1107                 path,
1108                 fast=True,
1109                 write_back=black.WriteBack.YES,
1110                 mode=DEFAULT_MODE,
1111                 report=report,
1112             )
1113             fsts.assert_called_once()
1114             report.done.assert_called_with(path, black.Changed.YES)
1115
1116     @pytest.mark.incompatible_with_mypyc
1117     def test_reformat_one_with_stdin_filename(self) -> None:
1118         with patch(
1119             "black.format_stdin_to_stdout",
1120             return_value=lambda *args, **kwargs: black.Changed.YES,
1121         ) as fsts:
1122             report = MagicMock()
1123             p = "foo.py"
1124             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1125             expected = Path(p)
1126             black.reformat_one(
1127                 path,
1128                 fast=True,
1129                 write_back=black.WriteBack.YES,
1130                 mode=DEFAULT_MODE,
1131                 report=report,
1132             )
1133             fsts.assert_called_once_with(
1134                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1135             )
1136             # __BLACK_STDIN_FILENAME__ should have been stripped
1137             report.done.assert_called_with(expected, black.Changed.YES)
1138
1139     @pytest.mark.incompatible_with_mypyc
1140     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1141         with patch(
1142             "black.format_stdin_to_stdout",
1143             return_value=lambda *args, **kwargs: black.Changed.YES,
1144         ) as fsts:
1145             report = MagicMock()
1146             p = "foo.pyi"
1147             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1148             expected = Path(p)
1149             black.reformat_one(
1150                 path,
1151                 fast=True,
1152                 write_back=black.WriteBack.YES,
1153                 mode=DEFAULT_MODE,
1154                 report=report,
1155             )
1156             fsts.assert_called_once_with(
1157                 fast=True,
1158                 write_back=black.WriteBack.YES,
1159                 mode=replace(DEFAULT_MODE, is_pyi=True),
1160             )
1161             # __BLACK_STDIN_FILENAME__ should have been stripped
1162             report.done.assert_called_with(expected, black.Changed.YES)
1163
1164     @pytest.mark.incompatible_with_mypyc
1165     def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
1166         with patch(
1167             "black.format_stdin_to_stdout",
1168             return_value=lambda *args, **kwargs: black.Changed.YES,
1169         ) as fsts:
1170             report = MagicMock()
1171             p = "foo.ipynb"
1172             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1173             expected = Path(p)
1174             black.reformat_one(
1175                 path,
1176                 fast=True,
1177                 write_back=black.WriteBack.YES,
1178                 mode=DEFAULT_MODE,
1179                 report=report,
1180             )
1181             fsts.assert_called_once_with(
1182                 fast=True,
1183                 write_back=black.WriteBack.YES,
1184                 mode=replace(DEFAULT_MODE, is_ipynb=True),
1185             )
1186             # __BLACK_STDIN_FILENAME__ should have been stripped
1187             report.done.assert_called_with(expected, black.Changed.YES)
1188
1189     @pytest.mark.incompatible_with_mypyc
1190     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1191         with patch(
1192             "black.format_stdin_to_stdout",
1193             return_value=lambda *args, **kwargs: black.Changed.YES,
1194         ) as fsts:
1195             report = MagicMock()
1196             # Even with an existing file, since we are forcing stdin, black
1197             # should output to stdout and not modify the file inplace
1198             p = Path(str(THIS_DIR / "data/collections.py"))
1199             # Make sure is_file actually returns True
1200             self.assertTrue(p.is_file())
1201             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1202             expected = Path(p)
1203             black.reformat_one(
1204                 path,
1205                 fast=True,
1206                 write_back=black.WriteBack.YES,
1207                 mode=DEFAULT_MODE,
1208                 report=report,
1209             )
1210             fsts.assert_called_once()
1211             # __BLACK_STDIN_FILENAME__ should have been stripped
1212             report.done.assert_called_with(expected, black.Changed.YES)
1213
1214     def test_reformat_one_with_stdin_empty(self) -> None:
1215         output = io.StringIO()
1216         with patch("io.TextIOWrapper", lambda *args, **kwargs: output):
1217             try:
1218                 black.format_stdin_to_stdout(
1219                     fast=True,
1220                     content="",
1221                     write_back=black.WriteBack.YES,
1222                     mode=DEFAULT_MODE,
1223                 )
1224             except io.UnsupportedOperation:
1225                 pass  # StringIO does not support detach
1226             assert output.getvalue() == ""
1227
1228     def test_invalid_cli_regex(self) -> None:
1229         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1230             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1231
1232     def test_required_version_matches_version(self) -> None:
1233         self.invokeBlack(
1234             ["--required-version", black.__version__], exit_code=0, ignore_config=True
1235         )
1236
1237     def test_required_version_does_not_match_version(self) -> None:
1238         self.invokeBlack(
1239             ["--required-version", "20.99b"], exit_code=1, ignore_config=True
1240         )
1241
1242     def test_preserves_line_endings(self) -> None:
1243         with TemporaryDirectory() as workspace:
1244             test_file = Path(workspace) / "test.py"
1245             for nl in ["\n", "\r\n"]:
1246                 contents = nl.join(["def f(  ):", "    pass"])
1247                 test_file.write_bytes(contents.encode())
1248                 ff(test_file, write_back=black.WriteBack.YES)
1249                 updated_contents: bytes = test_file.read_bytes()
1250                 self.assertIn(nl.encode(), updated_contents)
1251                 if nl == "\n":
1252                     self.assertNotIn(b"\r\n", updated_contents)
1253
1254     def test_preserves_line_endings_via_stdin(self) -> None:
1255         for nl in ["\n", "\r\n"]:
1256             contents = nl.join(["def f(  ):", "    pass"])
1257             runner = BlackRunner()
1258             result = runner.invoke(
1259                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1260             )
1261             self.assertEqual(result.exit_code, 0)
1262             output = result.stdout_bytes
1263             self.assertIn(nl.encode("utf8"), output)
1264             if nl == "\n":
1265                 self.assertNotIn(b"\r\n", output)
1266
1267     def test_assert_equivalent_different_asts(self) -> None:
1268         with self.assertRaises(AssertionError):
1269             black.assert_equivalent("{}", "None")
1270
1271     def test_shhh_click(self) -> None:
1272         try:
1273             from click import _unicodefun
1274         except ModuleNotFoundError:
1275             self.skipTest("Incompatible Click version")
1276         if not hasattr(_unicodefun, "_verify_python3_env"):
1277             self.skipTest("Incompatible Click version")
1278         # First, let's see if Click is crashing with a preferred ASCII charset.
1279         with patch("locale.getpreferredencoding") as gpe:
1280             gpe.return_value = "ASCII"
1281             with self.assertRaises(RuntimeError):
1282                 _unicodefun._verify_python3_env()  # type: ignore
1283         # Now, let's silence Click...
1284         black.patch_click()
1285         # ...and confirm it's silent.
1286         with patch("locale.getpreferredencoding") as gpe:
1287             gpe.return_value = "ASCII"
1288             try:
1289                 _unicodefun._verify_python3_env()  # type: ignore
1290             except RuntimeError as re:
1291                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1292
1293     def test_root_logger_not_used_directly(self) -> None:
1294         def fail(*args: Any, **kwargs: Any) -> None:
1295             self.fail("Record created with root logger")
1296
1297         with patch.multiple(
1298             logging.root,
1299             debug=fail,
1300             info=fail,
1301             warning=fail,
1302             error=fail,
1303             critical=fail,
1304             log=fail,
1305         ):
1306             ff(THIS_DIR / "util.py")
1307
1308     def test_invalid_config_return_code(self) -> None:
1309         tmp_file = Path(black.dump_to_file())
1310         try:
1311             tmp_config = Path(black.dump_to_file())
1312             tmp_config.unlink()
1313             args = ["--config", str(tmp_config), str(tmp_file)]
1314             self.invokeBlack(args, exit_code=2, ignore_config=False)
1315         finally:
1316             tmp_file.unlink()
1317
1318     def test_parse_pyproject_toml(self) -> None:
1319         test_toml_file = THIS_DIR / "test.toml"
1320         config = black.parse_pyproject_toml(str(test_toml_file))
1321         self.assertEqual(config["verbose"], 1)
1322         self.assertEqual(config["check"], "no")
1323         self.assertEqual(config["diff"], "y")
1324         self.assertEqual(config["color"], True)
1325         self.assertEqual(config["line_length"], 79)
1326         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1327         self.assertEqual(config["exclude"], r"\.pyi?$")
1328         self.assertEqual(config["include"], r"\.py?$")
1329
1330     def test_read_pyproject_toml(self) -> None:
1331         test_toml_file = THIS_DIR / "test.toml"
1332         fake_ctx = FakeContext()
1333         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1334         config = fake_ctx.default_map
1335         self.assertEqual(config["verbose"], "1")
1336         self.assertEqual(config["check"], "no")
1337         self.assertEqual(config["diff"], "y")
1338         self.assertEqual(config["color"], "True")
1339         self.assertEqual(config["line_length"], "79")
1340         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1341         self.assertEqual(config["exclude"], r"\.pyi?$")
1342         self.assertEqual(config["include"], r"\.py?$")
1343
1344     @pytest.mark.incompatible_with_mypyc
1345     def test_find_project_root(self) -> None:
1346         with TemporaryDirectory() as workspace:
1347             root = Path(workspace)
1348             test_dir = root / "test"
1349             test_dir.mkdir()
1350
1351             src_dir = root / "src"
1352             src_dir.mkdir()
1353
1354             root_pyproject = root / "pyproject.toml"
1355             root_pyproject.touch()
1356             src_pyproject = src_dir / "pyproject.toml"
1357             src_pyproject.touch()
1358             src_python = src_dir / "foo.py"
1359             src_python.touch()
1360
1361             self.assertEqual(
1362                 black.find_project_root((src_dir, test_dir)), root.resolve()
1363             )
1364             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
1365             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
1366
1367     @patch(
1368         "black.files.find_user_pyproject_toml",
1369         black.files.find_user_pyproject_toml.__wrapped__,
1370     )
1371     def test_find_user_pyproject_toml_linux(self) -> None:
1372         if system() == "Windows":
1373             return
1374
1375         # Test if XDG_CONFIG_HOME is checked
1376         with TemporaryDirectory() as workspace:
1377             tmp_user_config = Path(workspace) / "black"
1378             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1379                 self.assertEqual(
1380                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
1381                 )
1382
1383         # Test fallback for XDG_CONFIG_HOME
1384         with patch.dict("os.environ"):
1385             os.environ.pop("XDG_CONFIG_HOME", None)
1386             fallback_user_config = Path("~/.config").expanduser() / "black"
1387             self.assertEqual(
1388                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
1389             )
1390
1391     def test_find_user_pyproject_toml_windows(self) -> None:
1392         if system() != "Windows":
1393             return
1394
1395         user_config_path = Path.home() / ".black"
1396         self.assertEqual(
1397             black.files.find_user_pyproject_toml(), user_config_path.resolve()
1398         )
1399
1400     def test_bpo_33660_workaround(self) -> None:
1401         if system() == "Windows":
1402             return
1403
1404         # https://bugs.python.org/issue33660
1405         root = Path("/")
1406         with change_directory(root):
1407             path = Path("workspace") / "project"
1408             report = black.Report(verbose=True)
1409             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
1410             self.assertEqual(normalized_path, "workspace/project")
1411
1412     def test_newline_comment_interaction(self) -> None:
1413         source = "class A:\\\r\n# type: ignore\n pass\n"
1414         output = black.format_str(source, mode=DEFAULT_MODE)
1415         black.assert_stable(source, output, mode=DEFAULT_MODE)
1416
1417     def test_bpo_2142_workaround(self) -> None:
1418
1419         # https://bugs.python.org/issue2142
1420
1421         source, _ = read_data("missing_final_newline.py")
1422         # read_data adds a trailing newline
1423         source = source.rstrip()
1424         expected, _ = read_data("missing_final_newline.diff")
1425         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
1426         diff_header = re.compile(
1427             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
1428             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
1429         )
1430         try:
1431             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
1432             self.assertEqual(result.exit_code, 0)
1433         finally:
1434             os.unlink(tmp_file)
1435         actual = result.output
1436         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
1437         self.assertEqual(actual, expected)
1438
1439     @pytest.mark.python2
1440     def test_docstring_reformat_for_py27(self) -> None:
1441         """
1442         Check that stripping trailing whitespace from Python 2 docstrings
1443         doesn't trigger a "not equivalent to source" error
1444         """
1445         source = (
1446             b'def foo():\r\n    """Testing\r\n    Testing """\r\n    print "Foo"\r\n'
1447         )
1448         expected = 'def foo():\n    """Testing\n    Testing"""\n    print "Foo"\n'
1449
1450         result = BlackRunner().invoke(
1451             black.main,
1452             ["-", "-q", "--target-version=py27"],
1453             input=BytesIO(source),
1454         )
1455
1456         self.assertEqual(result.exit_code, 0)
1457         actual = result.stdout
1458         self.assertFormatEqual(actual, expected)
1459
1460     @staticmethod
1461     def compare_results(
1462         result: click.testing.Result, expected_value: str, expected_exit_code: int
1463     ) -> None:
1464         """Helper method to test the value and exit code of a click Result."""
1465         assert (
1466             result.output == expected_value
1467         ), "The output did not match the expected value."
1468         assert result.exit_code == expected_exit_code, "The exit code is incorrect."
1469
1470     def test_code_option(self) -> None:
1471         """Test the code option with no changes."""
1472         code = 'print("Hello world")\n'
1473         args = ["--code", code]
1474         result = CliRunner().invoke(black.main, args)
1475
1476         self.compare_results(result, code, 0)
1477
1478     def test_code_option_changed(self) -> None:
1479         """Test the code option when changes are required."""
1480         code = "print('hello world')"
1481         formatted = black.format_str(code, mode=DEFAULT_MODE)
1482
1483         args = ["--code", code]
1484         result = CliRunner().invoke(black.main, args)
1485
1486         self.compare_results(result, formatted, 0)
1487
1488     def test_code_option_check(self) -> None:
1489         """Test the code option when check is passed."""
1490         args = ["--check", "--code", 'print("Hello world")\n']
1491         result = CliRunner().invoke(black.main, args)
1492         self.compare_results(result, "", 0)
1493
1494     def test_code_option_check_changed(self) -> None:
1495         """Test the code option when changes are required, and check is passed."""
1496         args = ["--check", "--code", "print('hello world')"]
1497         result = CliRunner().invoke(black.main, args)
1498         self.compare_results(result, "", 1)
1499
1500     def test_code_option_diff(self) -> None:
1501         """Test the code option when diff is passed."""
1502         code = "print('hello world')"
1503         formatted = black.format_str(code, mode=DEFAULT_MODE)
1504         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1505
1506         args = ["--diff", "--code", code]
1507         result = CliRunner().invoke(black.main, args)
1508
1509         # Remove time from diff
1510         output = DIFF_TIME.sub("", result.output)
1511
1512         assert output == result_diff, "The output did not match the expected value."
1513         assert result.exit_code == 0, "The exit code is incorrect."
1514
1515     def test_code_option_color_diff(self) -> None:
1516         """Test the code option when color and diff are passed."""
1517         code = "print('hello world')"
1518         formatted = black.format_str(code, mode=DEFAULT_MODE)
1519
1520         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1521         result_diff = color_diff(result_diff)
1522
1523         args = ["--diff", "--color", "--code", code]
1524         result = CliRunner().invoke(black.main, args)
1525
1526         # Remove time from diff
1527         output = DIFF_TIME.sub("", result.output)
1528
1529         assert output == result_diff, "The output did not match the expected value."
1530         assert result.exit_code == 0, "The exit code is incorrect."
1531
1532     @pytest.mark.incompatible_with_mypyc
1533     def test_code_option_safe(self) -> None:
1534         """Test that the code option throws an error when the sanity checks fail."""
1535         # Patch black.assert_equivalent to ensure the sanity checks fail
1536         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1537             code = 'print("Hello world")'
1538             error_msg = f"{code}\nerror: cannot format <string>: \n"
1539
1540             args = ["--safe", "--code", code]
1541             result = CliRunner().invoke(black.main, args)
1542
1543             self.compare_results(result, error_msg, 123)
1544
1545     def test_code_option_fast(self) -> None:
1546         """Test that the code option ignores errors when the sanity checks fail."""
1547         # Patch black.assert_equivalent to ensure the sanity checks fail
1548         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1549             code = 'print("Hello world")'
1550             formatted = black.format_str(code, mode=DEFAULT_MODE)
1551
1552             args = ["--fast", "--code", code]
1553             result = CliRunner().invoke(black.main, args)
1554
1555             self.compare_results(result, formatted, 0)
1556
1557     @pytest.mark.incompatible_with_mypyc
1558     def test_code_option_config(self) -> None:
1559         """
1560         Test that the code option finds the pyproject.toml in the current directory.
1561         """
1562         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1563             args = ["--code", "print"]
1564             # This is the only directory known to contain a pyproject.toml
1565             with change_directory(PROJECT_ROOT):
1566                 CliRunner().invoke(black.main, args)
1567                 pyproject_path = Path(Path.cwd(), "pyproject.toml").resolve()
1568
1569             assert (
1570                 len(parse.mock_calls) >= 1
1571             ), "Expected config parse to be called with the current directory."
1572
1573             _, call_args, _ = parse.mock_calls[0]
1574             assert (
1575                 call_args[0].lower() == str(pyproject_path).lower()
1576             ), "Incorrect config loaded."
1577
1578     @pytest.mark.incompatible_with_mypyc
1579     def test_code_option_parent_config(self) -> None:
1580         """
1581         Test that the code option finds the pyproject.toml in the parent directory.
1582         """
1583         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1584             with change_directory(THIS_DIR):
1585                 args = ["--code", "print"]
1586                 CliRunner().invoke(black.main, args)
1587
1588                 pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
1589                 assert (
1590                     len(parse.mock_calls) >= 1
1591                 ), "Expected config parse to be called with the current directory."
1592
1593                 _, call_args, _ = parse.mock_calls[0]
1594                 assert (
1595                     call_args[0].lower() == str(pyproject_path).lower()
1596                 ), "Incorrect config loaded."
1597
1598     def test_for_handled_unexpected_eof_error(self) -> None:
1599         """
1600         Test that an unexpected EOF SyntaxError is nicely presented.
1601         """
1602         with pytest.raises(black.parsing.InvalidInput) as exc_info:
1603             black.lib2to3_parse("print(", {})
1604
1605         exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")
1606
1607     def test_equivalency_ast_parse_failure_includes_error(self) -> None:
1608         with pytest.raises(AssertionError) as err:
1609             black.assert_equivalent("a«»a  = 1", "a«»a  = 1")
1610
1611         err.match("--safe")
1612         # Unfortunately the SyntaxError message has changed in newer versions so we
1613         # can't match it directly.
1614         err.match("invalid character")
1615         err.match(r"\(<unknown>, line 1\)")
1616
1617
class TestCaching:
    """Tests for black's formatting cache: reading, writing and filtering."""

    def test_cache_broken_file(self) -> None:
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            # A corrupt cache file must read back as an empty cache.
            get_cache_file(mode).write_text("this is not a pickle")
            assert black.read_cache(mode) == {}
            target = (workspace / "test.py").resolve()
            target.write_text("print('hello')")
            invokeBlack([str(target)])
            assert str(target) in black.read_cache(mode)

    def test_cache_single_file_already_cached(self) -> None:
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            target = (workspace / "test.py").resolve()
            target.write_text("print('hello')")
            # Pre-seed the cache so the file counts as already formatted.
            black.write_cache({}, [target], mode)
            invokeBlack([str(target)])
            # The unformatted contents must be left untouched.
            assert target.read_text() == "print('hello')"

    @event_loop()
    def test_cache_multiple_files(self) -> None:
        mode = DEFAULT_MODE
        with cache_dir() as workspace, patch(
            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            cached_file = (workspace / "one.py").resolve()
            fresh_file = (workspace / "two.py").resolve()
            for path in (cached_file, fresh_file):
                path.write_text("print('hello')")
            black.write_cache({}, [cached_file], mode)
            invokeBlack([str(workspace)])
            # Only the file missing from the cache gets reformatted...
            assert cached_file.read_text() == "print('hello')"
            assert fresh_file.read_text() == 'print("hello")\n'
            # ...but both end up recorded in the cache.
            cache = black.read_cache(mode)
            assert str(cached_file) in cache
            assert str(fresh_file) in cache

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    def test_no_cache_when_writeback_diff(self, color: bool) -> None:
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            target = (workspace / "test.py").resolve()
            target.write_text("print('hello')")
            with patch("black.read_cache") as read_cache, patch(
                "black.write_cache"
            ) as write_cache:
                invokeBlack([str(target), "--diff"] + (["--color"] if color else []))
                # In --diff mode the cache is neither created, read nor written.
                assert not get_cache_file(mode).exists()
                write_cache.assert_not_called()
                read_cache.assert_not_called()

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    @event_loop()
    def test_output_locking_when_writeback_diff(self, color: bool) -> None:
        with cache_dir() as workspace:
            for index in range(4):
                (workspace / f"test{index}.py").resolve().write_text("print('hello')")
            with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
                cmd = ["--diff", str(workspace)]
                if color:
                    cmd.append("--color")
                invokeBlack(cmd, exit_code=0)
                # Only a weak check: if Manager is never called, the lock it
                # provides cannot be in use.
                mgr.assert_called()

    def test_no_cache_when_stdin(self) -> None:
        mode = DEFAULT_MODE
        with cache_dir():
            outcome = CliRunner().invoke(
                black.main, ["-"], input=BytesIO(b"print('hello')")
            )
            assert not outcome.exit_code
            # Formatting stdin must not create a cache file.
            assert not get_cache_file(mode).exists()

    def test_read_cache_no_cachefile(self) -> None:
        with cache_dir():
            assert black.read_cache(DEFAULT_MODE) == {}

    def test_write_cache_read_cache(self) -> None:
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            target = (workspace / "test.py").resolve()
            target.touch()
            black.write_cache({}, [target], mode)
            cache = black.read_cache(mode)
            assert str(target) in cache
            assert cache[str(target)] == black.get_cache_info(target)

    def test_filter_cached(self) -> None:
        with TemporaryDirectory() as tmp:
            base = Path(tmp)
            uncached = (base / "uncached").resolve()
            cached = (base / "cached").resolve()
            changed = (base / "changed").resolve()
            for path in (uncached, cached, changed):
                path.touch()
            cache = {
                str(cached): black.get_cache_info(cached),
                # A stale entry that no longer matches the file's cache info.
                str(changed): (0.0, 0),
            }
            todo, done = black.filter_cached(cache, {uncached, cached, changed})
            assert todo == {uncached, changed}
            assert done == {cached}

    def test_write_cache_creates_directory_if_needed(self) -> None:
        with cache_dir(exists=False) as workspace:
            assert not workspace.exists()
            black.write_cache({}, [], DEFAULT_MODE)
            assert workspace.exists()

    @event_loop()
    def test_failed_formatting_does_not_get_cached(self) -> None:
        mode = DEFAULT_MODE
        with cache_dir() as workspace, patch(
            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            failing = (workspace / "failing.py").resolve()
            failing.write_text("not actually python")
            clean = (workspace / "clean.py").resolve()
            clean.write_text('print("hello")\n')
            invokeBlack([str(workspace)], exit_code=123)
            cache = black.read_cache(mode)
            assert str(failing) not in cache
            assert str(clean) in cache

    def test_write_cache_write_fail(self) -> None:
        # Every Path.open raises OSError; write_cache must not propagate it.
        with cache_dir(), patch.object(Path, "open", side_effect=OSError):
            black.write_cache({}, [], DEFAULT_MODE)

    def test_read_cache_line_lengths(self) -> None:
        mode = DEFAULT_MODE
        short_mode = replace(DEFAULT_MODE, line_length=1)
        with cache_dir() as workspace:
            target = (workspace / "file.py").resolve()
            target.touch()
            black.write_cache({}, [target], mode)
            assert str(target) in black.read_cache(mode)
            # A mode with a different line length must not see the entry.
            assert str(target) not in black.read_cache(short_mode)
1783
1784
def assert_collected_sources(
    src: Sequence[Union[str, Path]],
    expected: Sequence[Union[str, Path]],
    *,
    exclude: Optional[str] = None,
    include: Optional[str] = None,
    extend_exclude: Optional[str] = None,
    force_exclude: Optional[str] = None,
    stdin_filename: Optional[str] = None,
) -> None:
    """Assert that ``black.get_sources`` collects exactly ``expected`` from ``src``.

    Pattern arguments are regex strings compiled with ``compile_pattern``; ``None``
    means "not given", except that a missing ``include`` falls back to
    ``DEFAULT_INCLUDE``. Ordering is ignored: both sides are sorted first.
    """

    def _compile(pattern: Optional[str]) -> Any:
        # Leave absent patterns as None; compile the rest.
        return None if pattern is None else compile_pattern(pattern)

    collected = black.get_sources(
        ctx=FakeContext(),
        src=tuple(str(Path(s)) for s in src),
        quiet=False,
        verbose=False,
        include=DEFAULT_INCLUDE if include is None else compile_pattern(include),
        exclude=_compile(exclude),
        extend_exclude=_compile(extend_exclude),
        force_exclude=_compile(force_exclude),
        report=black.Report(),
        stdin_filename=stdin_filename,
    )
    assert sorted(collected) == sorted(Path(s) for s in expected)
1816
1817
class TestFileCollection:
    """Source discovery: include/exclude regexes, .gitignore handling, stdin."""

    def test_include_exclude(self) -> None:
        path = THIS_DIR / "data" / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        assert_collected_sources(
            src,
            expected,
            include=r"\.pyi?$",
            exclude=r"/exclude/|/\.definitely_exclude/",
        )

    def test_gitignore_used_as_default(self) -> None:
        # No explicit exclude is given, so the directory's .gitignore applies.
        base = DATA_DIR / "include_exclude_tests"
        expected = [
            base / "b/.definitely_exclude/a.py",
            base / "b/.definitely_exclude/a.pyi",
        ]
        src = [base / "b/"]
        assert_collected_sources(src, expected, extend_exclude=r"/exclude/")

    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
    def test_exclude_for_issue_1572(self) -> None:
        # Exclude shouldn't touch files that were explicitly given to Black through the
        # CLI. Exclude is supposed to only apply to the recursive discovery of files.
        # https://github.com/psf/black/issues/1572
        path = DATA_DIR / "include_exclude_tests"
        src = [path / "b/exclude/a.py"]
        expected = [path / "b/exclude/a.py"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    def test_gitignore_exclude(self) -> None:
        path = THIS_DIR / "data" / "include_exclude_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        report = black.Report()
        gitignore = PathSpec.from_lines(
            "gitwildmatch", ["exclude/", ".definitely_exclude"]
        )
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        this_abs = THIS_DIR.resolve()
        sources: List[Path] = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                gitignore,
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_nested_gitignore(self) -> None:
        path = THIS_DIR / "data" / "nested_gitignore_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        root_gitignore = black.files.get_gitignore(path)
        report = black.Report()
        expected: List[Path] = [
            path / "x.py",
            path / "root/b.py",
            path / "root/c.py",
            path / "root/child/c.py",
        ]
        this_abs = THIS_DIR.resolve()
        sources = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                root_gitignore,
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_invalid_gitignore(self) -> None:
        path = THIS_DIR / "data" / "invalid_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_invalid_nested_gitignore(self) -> None:
        path = THIS_DIR / "data" / "invalid_nested_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / "a" / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_empty_include(self) -> None:
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.pie",
            path / "b/exclude/a.py",
            path / "b/exclude/a.pyi",
            path / "b/dont_exclude/a.pie",
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
            path / "b/.definitely_exclude/a.pie",
            path / "b/.definitely_exclude/a.py",
            path / "b/.definitely_exclude/a.pyi",
            path / ".gitignore",
            path / "pyproject.toml",
        ]
        # Setting exclude explicitly to an empty string to block .gitignore usage.
        assert_collected_sources(src, expected, include="", exclude="")

    def test_extend_exclude(self) -> None:
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.py",
            path / "b/dont_exclude/a.py",
        ]
        assert_collected_sources(
            src, expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude"
        )

    @pytest.mark.incompatible_with_mypyc
    def test_symlink_out_of_root_directory(self) -> None:
        path = MagicMock()
        root = THIS_DIR.resolve()
        child = MagicMock()
        include = re.compile(black.DEFAULT_INCLUDES)
        exclude = re.compile(black.DEFAULT_EXCLUDES)
        report = black.Report()
        gitignore = PathSpec.from_lines("gitwildmatch", [])

        def collect() -> List[Path]:
            # Each call re-invokes path.iterdir(); the call counts checked
            # below rely on that.
            return list(
                black.gen_python_files(
                    path.iterdir(),
                    root,
                    include,
                    exclude,
                    None,
                    None,
                    report,
                    gitignore,
                    verbose=False,
                    quiet=False,
                )
            )

        # `child` should behave like a symlink whose resolved path is clearly
        # outside of the `root` directory.
        path.iterdir.return_value = [child]
        child.resolve.return_value = Path("/a/b/c")
        child.as_posix.return_value = "/a/b/c"
        child.is_symlink.return_value = True
        try:
            collect()
        except ValueError as ve:
            pytest.fail(f"`gen_python_files()` failed: {ve}")
        path.iterdir.assert_called_once()
        child.resolve.assert_called_once()
        child.is_symlink.assert_called_once()
        # `child` should behave like a strange file whose resolved path is clearly
        # outside of the `root` directory.
        child.is_symlink.return_value = False
        with pytest.raises(ValueError):
            collect()
        path.iterdir.assert_called()
        assert path.iterdir.call_count == 2
        child.resolve.assert_called()
        assert child.resolve.call_count == 2
        child.is_symlink.assert_called()
        assert child.is_symlink.call_count == 2

    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
    def test_get_sources_with_stdin(self) -> None:
        src = ["-"]
        expected = ["-"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
    def test_get_sources_with_stdin_filename(self) -> None:
        src = ["-"]
        stdin_filename = str(THIS_DIR / "data/collections.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
    def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
        # Exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        path = DATA_DIR / "include_exclude_tests"
        src = ["-"]
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
    def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
        # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        src = ["-"]
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            extend_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
    def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
        # Force exclude should exclude the file when passing it through
        # stdin_filename
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        assert_collected_sources(
            src=["-"],
            expected=[],
            force_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )
2087
2088
@pytest.mark.python2
@pytest.mark.parametrize("explicit", [True, False], ids=["explicit", "autodetection"])
def test_python_2_deprecation_with_target_version(explicit: bool) -> None:
    """Checking Python 2 code warns about deprecation, explicit or autodetected."""
    target_flags = ["--target-version=py27"] if explicit else []
    cli_args = [
        "--config",
        str(THIS_DIR / "empty.toml"),
        str(DATA_DIR / "python2.py"),
        "--check",
        *target_flags,
    ]
    with cache_dir():
        outcome = BlackRunner().invoke(black.main, cli_args)
    assert "DEPRECATION: Python 2 support will be removed" in outcome.stderr
2103
2104
@pytest.mark.python2
def test_python_2_deprecation_autodetection_extended() -> None:
    """Python 2 snippets detect as exactly {PY27}; the other snippets never do."""
    # this test has a similar construction to test_get_features_used_decorator
    py2_cases, non_py2_cases = read_data("python2_detection")
    only_py27 = {TargetVersion.PY27}
    for case in py2_cases.split("###"):
        detected = black.detect_target_versions(black.lib2to3_parse(case))
        assert detected == only_py27, case
    for case in non_py2_cases.split("###"):
        detected = black.detect_target_versions(black.lib2to3_parse(case))
        assert detected != only_py27, case
2117
2118
# Source lines of black/__init__.py, consumed by the tracefunc debugging helper.
try:
    with open(black.__file__, "r", encoding="utf-8") as _bf:
        black_source_lines = _bf.readlines()
except UnicodeDecodeError:
    # Decoding failures are only tolerated for a compiled black (presumably
    # __file__ then points at a binary extension module); otherwise re-raise.
    if not black.COMPILED:
        raise
    # Keep the name bound so later references do not raise NameError.
    black_source_lines = []
2125
2126
def tracefunc(
    frame: types.FrameType, event: str, arg: Any
) -> Callable[[types.FrameType, str, Any], Any]:
    """Show function calls from `black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    Only "call" events whose code lives in ``black/__init__.py`` are printed,
    as ``<lineno>:<signature line>`` indented by the current stack depth. The
    signature text is taken from ``black_source_lines``. The function always
    returns itself, so tracing continues into nested scopes.
    """
    if event != "call":
        return tracefunc

    filename = frame.f_code.co_filename
    # Filter by file first: a frame's line number from any other module is
    # meaningless as an index into black's source and may be out of range.
    # Separators are normalized so the check also matches Windows paths.
    if "black/__init__.py" not in filename.replace(os.sep, "/"):
        return tracefunc

    # 19 is an empirical offset, presumably the depth of the harness frames;
    # each remaining frame adds two spaces of indentation.
    stack = (len(inspect.stack()) - 19) * 2
    lineno = frame.f_lineno
    # f_lineno is 1-based; on a "call" event it presumably points at the def
    # line or its first decorator, so skip forward past any "@" lines.
    func_sig_lineno = lineno - 1
    funcname = black_source_lines[func_sig_lineno].strip()
    while funcname.startswith("@"):
        func_sig_lineno += 1
        funcname = black_source_lines[func_sig_lineno].strip()
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc