git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Bump regex dependency to 2021.4.4 to fix import of Pattern class (#2621)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import inspect
5 import io
6 import logging
7 import multiprocessing
8 import os
9 import sys
10 import types
11 import unittest
12 from concurrent.futures import ThreadPoolExecutor
13 from contextlib import contextmanager
14 from dataclasses import replace
15 from io import BytesIO
16 from pathlib import Path
17 from platform import system
18 from tempfile import TemporaryDirectory
19 from typing import (
20     Any,
21     Callable,
22     Dict,
23     Iterator,
24     List,
25     Optional,
26     Sequence,
27     TypeVar,
28     Union,
29 )
30 from unittest.mock import MagicMock, patch
31
32 import click
33 import pytest
34 import regex as re
35 from click import unstyle
36 from click.testing import CliRunner
37 from pathspec import PathSpec
38
39 import black
40 import black.files
41 from black import Feature, TargetVersion
42 from black import re_compile_maybe_verbose as compile_pattern
43 from black.cache import get_cache_file
44 from black.debug import DebugVisitor
45 from black.output import color_diff, diff
46 from black.report import Report
47
48 # Import other test classes
49 from tests.util import (
50     DATA_DIR,
51     DEFAULT_MODE,
52     DETERMINISTIC_HEADER,
53     PROJECT_ROOT,
54     PY36_VERSIONS,
55     THIS_DIR,
56     BlackBaseTestCase,
57     assert_format,
58     change_directory,
59     dump_to_stderr,
60     ff,
61     fs,
62     read_data,
63 )
64
THIS_FILE = Path(__file__)
# One ``--target-version=...`` CLI flag per Python 3.6+ target version.
PY36_ARGS = [
    "--target-version=" + target.name.lower() for target in PY36_VERSIONS
]
# Black's default include/exclude patterns, compiled the same way the CLI does.
DEFAULT_EXCLUDE = compile_pattern(black.const.DEFAULT_EXCLUDES)
DEFAULT_INCLUDE = compile_pattern(black.const.DEFAULT_INCLUDES)
# Generic type variables for typed helpers.
T = TypeVar("T")
R = TypeVar("R")
71
# Match the time output in a diff, but nothing else. The hyphen inside the
# character class is escaped on purpose: the unescaped form ``[\d-:]`` uses a
# class escape as a range endpoint, which the stdlib ``re`` module rejects as a
# bad character range.
DIFF_TIME = re.compile(r"\t[\d\-:+\. ]+")
74
75
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point ``black.cache.CACHE_DIR`` at a throwaway directory for a test.

    Yields the path that was patched in. When *exists* is true this is a fresh
    temporary directory. Otherwise it is a ``new`` subdirectory of it that this
    helper never creates. Everything is removed when the ``with`` block exits.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
84
85
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a brand-new asyncio event loop as the current loop.

    The loop comes from the active event loop policy and is closed when the
    ``with`` block exits, whether normally or through an exception.

    NOTE: nothing is restored afterwards; the closed loop stays installed as
    the current event loop.
    """
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        fresh_loop.close()
96
97
class FakeContext(click.Context):
    """Stand-in ``click.Context`` for helpers that only need a context object.

    ``click.Context.__init__`` is deliberately bypassed. The only state set up
    is an empty ``default_map``.
    """

    def __init__(self) -> None:
        # No super().__init__() call on purpose: no click command is required.
        self.default_map: Dict[str, Any] = dict()
103
104
class FakeParameter(click.Parameter):
    """Stand-in ``click.Parameter`` for helpers that require a parameter object.

    The base initializer is intentionally bypassed, so the instance carries no
    parameter state of its own.
    """

    def __init__(self) -> None:
        """Skip ``click.Parameter.__init__``; nothing needs configuring."""
110
111
class BlackRunner(CliRunner):
    """``CliRunner`` that captures Black's stdout and stderr as separate streams."""

    def __init__(self) -> None:
        # mix_stderr=False keeps stderr out of ``result.output`` so tests can
        # inspect ``stdout_bytes`` and ``stderr_bytes`` independently.
        super().__init__(mix_stderr=False)
117
118
def invokeBlack(
    args: List[str], exit_code: int = 0, ignore_config: bool = True
) -> None:
    """Run ``black.main`` in-process and assert that it exits with *exit_code*.

    When *ignore_config* is true, ``--verbose`` and ``--config`` pointing at
    ``empty.toml`` in the tests directory are prepended to *args*. The run then
    uses that file instead of any discovered configuration. Exceptions raised
    by Black are not caught and propagate to the caller. On an unexpected exit
    code, the assertion message contains the arguments, both captured streams,
    and any exception recorded on the result.
    """
    cli_args = (
        ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
        if ignore_config
        else args
    )
    result = BlackRunner().invoke(black.main, cli_args, catch_exceptions=False)
    out_bytes, err_bytes = result.stdout_bytes, result.stderr_bytes
    assert out_bytes is not None
    assert err_bytes is not None
    failure_msg = "\n".join(
        (
            f"Failed with args: {cli_args}",
            f"stdout: {out_bytes.decode()!r}",
            f"stderr: {err_bytes.decode()!r}",
            f"exception: {result.exception}",
        )
    )
    assert result.exit_code == exit_code, failure_msg
135
136
137 class BlackTestCase(BlackBaseTestCase):
138     invokeBlack = staticmethod(invokeBlack)
139
140     def test_empty_ff(self) -> None:
141         expected = ""
142         tmp_file = Path(black.dump_to_file())
143         try:
144             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
145             with open(tmp_file, encoding="utf8") as f:
146                 actual = f.read()
147         finally:
148             os.unlink(tmp_file)
149         self.assertFormatEqual(expected, actual)
150
151     def test_piping(self) -> None:
152         source, expected = read_data("src/black/__init__", data=False)
153         result = BlackRunner().invoke(
154             black.main,
155             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
156             input=BytesIO(source.encode("utf8")),
157         )
158         self.assertEqual(result.exit_code, 0)
159         self.assertFormatEqual(expected, result.output)
160         if source != result.output:
161             black.assert_equivalent(source, result.output)
162             black.assert_stable(source, result.output, DEFAULT_MODE)
163
164     def test_piping_diff(self) -> None:
165         diff_header = re.compile(
166             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
167             r"\+\d\d\d\d"
168         )
169         source, _ = read_data("expression.py")
170         expected, _ = read_data("expression.diff")
171         config = THIS_DIR / "data" / "empty_pyproject.toml"
172         args = [
173             "-",
174             "--fast",
175             f"--line-length={black.DEFAULT_LINE_LENGTH}",
176             "--diff",
177             f"--config={config}",
178         ]
179         result = BlackRunner().invoke(
180             black.main, args, input=BytesIO(source.encode("utf8"))
181         )
182         self.assertEqual(result.exit_code, 0)
183         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
184         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
185         self.assertEqual(expected, actual)
186
187     def test_piping_diff_with_color(self) -> None:
188         source, _ = read_data("expression.py")
189         config = THIS_DIR / "data" / "empty_pyproject.toml"
190         args = [
191             "-",
192             "--fast",
193             f"--line-length={black.DEFAULT_LINE_LENGTH}",
194             "--diff",
195             "--color",
196             f"--config={config}",
197         ]
198         result = BlackRunner().invoke(
199             black.main, args, input=BytesIO(source.encode("utf8"))
200         )
201         actual = result.output
202         # Again, the contents are checked in a different test, so only look for colors.
203         self.assertIn("\033[1;37m", actual)
204         self.assertIn("\033[36m", actual)
205         self.assertIn("\033[32m", actual)
206         self.assertIn("\033[31m", actual)
207         self.assertIn("\033[0m", actual)
208
209     @patch("black.dump_to_file", dump_to_stderr)
210     def _test_wip(self) -> None:
211         source, expected = read_data("wip")
212         sys.settrace(tracefunc)
213         mode = replace(
214             DEFAULT_MODE,
215             experimental_string_processing=False,
216             target_versions={black.TargetVersion.PY38},
217         )
218         actual = fs(source, mode=mode)
219         sys.settrace(None)
220         self.assertFormatEqual(expected, actual)
221         black.assert_equivalent(source, actual)
222         black.assert_stable(source, actual, black.FileMode())
223
224     @unittest.expectedFailure
225     @patch("black.dump_to_file", dump_to_stderr)
226     def test_trailing_comma_optional_parens_stability1(self) -> None:
227         source, _expected = read_data("trailing_comma_optional_parens1")
228         actual = fs(source)
229         black.assert_stable(source, actual, DEFAULT_MODE)
230
231     @unittest.expectedFailure
232     @patch("black.dump_to_file", dump_to_stderr)
233     def test_trailing_comma_optional_parens_stability2(self) -> None:
234         source, _expected = read_data("trailing_comma_optional_parens2")
235         actual = fs(source)
236         black.assert_stable(source, actual, DEFAULT_MODE)
237
238     @unittest.expectedFailure
239     @patch("black.dump_to_file", dump_to_stderr)
240     def test_trailing_comma_optional_parens_stability3(self) -> None:
241         source, _expected = read_data("trailing_comma_optional_parens3")
242         actual = fs(source)
243         black.assert_stable(source, actual, DEFAULT_MODE)
244
245     @patch("black.dump_to_file", dump_to_stderr)
246     def test_trailing_comma_optional_parens_stability1_pass2(self) -> None:
247         source, _expected = read_data("trailing_comma_optional_parens1")
248         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
249         black.assert_stable(source, actual, DEFAULT_MODE)
250
251     @patch("black.dump_to_file", dump_to_stderr)
252     def test_trailing_comma_optional_parens_stability2_pass2(self) -> None:
253         source, _expected = read_data("trailing_comma_optional_parens2")
254         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
255         black.assert_stable(source, actual, DEFAULT_MODE)
256
257     @patch("black.dump_to_file", dump_to_stderr)
258     def test_trailing_comma_optional_parens_stability3_pass2(self) -> None:
259         source, _expected = read_data("trailing_comma_optional_parens3")
260         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
261         black.assert_stable(source, actual, DEFAULT_MODE)
262
263     def test_pep_572_version_detection(self) -> None:
264         source, _ = read_data("pep_572")
265         root = black.lib2to3_parse(source)
266         features = black.get_features_used(root)
267         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
268         versions = black.detect_target_versions(root)
269         self.assertIn(black.TargetVersion.PY38, versions)
270
271     def test_expression_ff(self) -> None:
272         source, expected = read_data("expression")
273         tmp_file = Path(black.dump_to_file(source))
274         try:
275             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
276             with open(tmp_file, encoding="utf8") as f:
277                 actual = f.read()
278         finally:
279             os.unlink(tmp_file)
280         self.assertFormatEqual(expected, actual)
281         with patch("black.dump_to_file", dump_to_stderr):
282             black.assert_equivalent(source, actual)
283             black.assert_stable(source, actual, DEFAULT_MODE)
284
285     def test_expression_diff(self) -> None:
286         source, _ = read_data("expression.py")
287         config = THIS_DIR / "data" / "empty_pyproject.toml"
288         expected, _ = read_data("expression.diff")
289         tmp_file = Path(black.dump_to_file(source))
290         diff_header = re.compile(
291             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
292             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
293         )
294         try:
295             result = BlackRunner().invoke(
296                 black.main, ["--diff", str(tmp_file), f"--config={config}"]
297             )
298             self.assertEqual(result.exit_code, 0)
299         finally:
300             os.unlink(tmp_file)
301         actual = result.output
302         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
303         if expected != actual:
304             dump = black.dump_to_file(actual)
305             msg = (
306                 "Expected diff isn't equal to the actual. If you made changes to"
307                 " expression.py and this is an anticipated difference, overwrite"
308                 f" tests/data/expression.diff with {dump}"
309             )
310             self.assertEqual(expected, actual, msg)
311
312     def test_expression_diff_with_color(self) -> None:
313         source, _ = read_data("expression.py")
314         config = THIS_DIR / "data" / "empty_pyproject.toml"
315         expected, _ = read_data("expression.diff")
316         tmp_file = Path(black.dump_to_file(source))
317         try:
318             result = BlackRunner().invoke(
319                 black.main, ["--diff", "--color", str(tmp_file), f"--config={config}"]
320             )
321         finally:
322             os.unlink(tmp_file)
323         actual = result.output
324         # We check the contents of the diff in `test_expression_diff`. All
325         # we need to check here is that color codes exist in the result.
326         self.assertIn("\033[1;37m", actual)
327         self.assertIn("\033[36m", actual)
328         self.assertIn("\033[32m", actual)
329         self.assertIn("\033[31m", actual)
330         self.assertIn("\033[0m", actual)
331
332     def test_detect_pos_only_arguments(self) -> None:
333         source, _ = read_data("pep_570")
334         root = black.lib2to3_parse(source)
335         features = black.get_features_used(root)
336         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
337         versions = black.detect_target_versions(root)
338         self.assertIn(black.TargetVersion.PY38, versions)
339
340     @patch("black.dump_to_file", dump_to_stderr)
341     def test_string_quotes(self) -> None:
342         source, expected = read_data("string_quotes")
343         mode = black.Mode(experimental_string_processing=True)
344         assert_format(source, expected, mode)
345         mode = replace(mode, string_normalization=False)
346         not_normalized = fs(source, mode=mode)
347         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
348         black.assert_equivalent(source, not_normalized)
349         black.assert_stable(source, not_normalized, mode=mode)
350
351     def test_skip_magic_trailing_comma(self) -> None:
352         source, _ = read_data("expression.py")
353         expected, _ = read_data("expression_skip_magic_trailing_comma.diff")
354         tmp_file = Path(black.dump_to_file(source))
355         diff_header = re.compile(
356             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
357             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
358         )
359         try:
360             result = BlackRunner().invoke(black.main, ["-C", "--diff", str(tmp_file)])
361             self.assertEqual(result.exit_code, 0)
362         finally:
363             os.unlink(tmp_file)
364         actual = result.output
365         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
366         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
367         if expected != actual:
368             dump = black.dump_to_file(actual)
369             msg = (
370                 "Expected diff isn't equal to the actual. If you made changes to"
371                 " expression.py and this is an anticipated difference, overwrite"
372                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
373             )
374             self.assertEqual(expected, actual, msg)
375
376     @patch("black.dump_to_file", dump_to_stderr)
377     def test_async_as_identifier(self) -> None:
378         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
379         source, expected = read_data("async_as_identifier")
380         actual = fs(source)
381         self.assertFormatEqual(expected, actual)
382         major, minor = sys.version_info[:2]
383         if major < 3 or (major <= 3 and minor < 7):
384             black.assert_equivalent(source, actual)
385         black.assert_stable(source, actual, DEFAULT_MODE)
386         # ensure black can parse this when the target is 3.6
387         self.invokeBlack([str(source_path), "--target-version", "py36"])
388         # but not on 3.7, because async/await is no longer an identifier
389         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
390
391     @patch("black.dump_to_file", dump_to_stderr)
392     def test_python37(self) -> None:
393         source_path = (THIS_DIR / "data" / "python37.py").resolve()
394         source, expected = read_data("python37")
395         actual = fs(source)
396         self.assertFormatEqual(expected, actual)
397         major, minor = sys.version_info[:2]
398         if major > 3 or (major == 3 and minor >= 7):
399             black.assert_equivalent(source, actual)
400         black.assert_stable(source, actual, DEFAULT_MODE)
401         # ensure black can parse this when the target is 3.7
402         self.invokeBlack([str(source_path), "--target-version", "py37"])
403         # but not on 3.6, because we use async as a reserved keyword
404         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
405
406     def test_tab_comment_indentation(self) -> None:
407         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
408         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
409         self.assertFormatEqual(contents_spc, fs(contents_spc))
410         self.assertFormatEqual(contents_spc, fs(contents_tab))
411
412         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
413         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
414         self.assertFormatEqual(contents_spc, fs(contents_spc))
415         self.assertFormatEqual(contents_spc, fs(contents_tab))
416
417         # mixed tabs and spaces (valid Python 2 code)
418         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
419         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
420         self.assertFormatEqual(contents_spc, fs(contents_spc))
421         self.assertFormatEqual(contents_spc, fs(contents_tab))
422
423         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
424         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
425         self.assertFormatEqual(contents_spc, fs(contents_spc))
426         self.assertFormatEqual(contents_spc, fs(contents_tab))
427
428     def test_report_verbose(self) -> None:
429         report = Report(verbose=True)
430         out_lines = []
431         err_lines = []
432
433         def out(msg: str, **kwargs: Any) -> None:
434             out_lines.append(msg)
435
436         def err(msg: str, **kwargs: Any) -> None:
437             err_lines.append(msg)
438
439         with patch("black.output._out", out), patch("black.output._err", err):
440             report.done(Path("f1"), black.Changed.NO)
441             self.assertEqual(len(out_lines), 1)
442             self.assertEqual(len(err_lines), 0)
443             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
444             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
445             self.assertEqual(report.return_code, 0)
446             report.done(Path("f2"), black.Changed.YES)
447             self.assertEqual(len(out_lines), 2)
448             self.assertEqual(len(err_lines), 0)
449             self.assertEqual(out_lines[-1], "reformatted f2")
450             self.assertEqual(
451                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
452             )
453             report.done(Path("f3"), black.Changed.CACHED)
454             self.assertEqual(len(out_lines), 3)
455             self.assertEqual(len(err_lines), 0)
456             self.assertEqual(
457                 out_lines[-1], "f3 wasn't modified on disk since last run."
458             )
459             self.assertEqual(
460                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
461             )
462             self.assertEqual(report.return_code, 0)
463             report.check = True
464             self.assertEqual(report.return_code, 1)
465             report.check = False
466             report.failed(Path("e1"), "boom")
467             self.assertEqual(len(out_lines), 3)
468             self.assertEqual(len(err_lines), 1)
469             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
470             self.assertEqual(
471                 unstyle(str(report)),
472                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
473                 " reformat.",
474             )
475             self.assertEqual(report.return_code, 123)
476             report.done(Path("f3"), black.Changed.YES)
477             self.assertEqual(len(out_lines), 4)
478             self.assertEqual(len(err_lines), 1)
479             self.assertEqual(out_lines[-1], "reformatted f3")
480             self.assertEqual(
481                 unstyle(str(report)),
482                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
483                 " reformat.",
484             )
485             self.assertEqual(report.return_code, 123)
486             report.failed(Path("e2"), "boom")
487             self.assertEqual(len(out_lines), 4)
488             self.assertEqual(len(err_lines), 2)
489             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
490             self.assertEqual(
491                 unstyle(str(report)),
492                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
493                 " reformat.",
494             )
495             self.assertEqual(report.return_code, 123)
496             report.path_ignored(Path("wat"), "no match")
497             self.assertEqual(len(out_lines), 5)
498             self.assertEqual(len(err_lines), 2)
499             self.assertEqual(out_lines[-1], "wat ignored: no match")
500             self.assertEqual(
501                 unstyle(str(report)),
502                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
503                 " reformat.",
504             )
505             self.assertEqual(report.return_code, 123)
506             report.done(Path("f4"), black.Changed.NO)
507             self.assertEqual(len(out_lines), 6)
508             self.assertEqual(len(err_lines), 2)
509             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
510             self.assertEqual(
511                 unstyle(str(report)),
512                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
513                 " reformat.",
514             )
515             self.assertEqual(report.return_code, 123)
516             report.check = True
517             self.assertEqual(
518                 unstyle(str(report)),
519                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
520                 " would fail to reformat.",
521             )
522             report.check = False
523             report.diff = True
524             self.assertEqual(
525                 unstyle(str(report)),
526                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
527                 " would fail to reformat.",
528             )
529
530     def test_report_quiet(self) -> None:
531         report = Report(quiet=True)
532         out_lines = []
533         err_lines = []
534
535         def out(msg: str, **kwargs: Any) -> None:
536             out_lines.append(msg)
537
538         def err(msg: str, **kwargs: Any) -> None:
539             err_lines.append(msg)
540
541         with patch("black.output._out", out), patch("black.output._err", err):
542             report.done(Path("f1"), black.Changed.NO)
543             self.assertEqual(len(out_lines), 0)
544             self.assertEqual(len(err_lines), 0)
545             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
546             self.assertEqual(report.return_code, 0)
547             report.done(Path("f2"), black.Changed.YES)
548             self.assertEqual(len(out_lines), 0)
549             self.assertEqual(len(err_lines), 0)
550             self.assertEqual(
551                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
552             )
553             report.done(Path("f3"), black.Changed.CACHED)
554             self.assertEqual(len(out_lines), 0)
555             self.assertEqual(len(err_lines), 0)
556             self.assertEqual(
557                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
558             )
559             self.assertEqual(report.return_code, 0)
560             report.check = True
561             self.assertEqual(report.return_code, 1)
562             report.check = False
563             report.failed(Path("e1"), "boom")
564             self.assertEqual(len(out_lines), 0)
565             self.assertEqual(len(err_lines), 1)
566             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
567             self.assertEqual(
568                 unstyle(str(report)),
569                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
570                 " reformat.",
571             )
572             self.assertEqual(report.return_code, 123)
573             report.done(Path("f3"), black.Changed.YES)
574             self.assertEqual(len(out_lines), 0)
575             self.assertEqual(len(err_lines), 1)
576             self.assertEqual(
577                 unstyle(str(report)),
578                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
579                 " reformat.",
580             )
581             self.assertEqual(report.return_code, 123)
582             report.failed(Path("e2"), "boom")
583             self.assertEqual(len(out_lines), 0)
584             self.assertEqual(len(err_lines), 2)
585             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
586             self.assertEqual(
587                 unstyle(str(report)),
588                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
589                 " reformat.",
590             )
591             self.assertEqual(report.return_code, 123)
592             report.path_ignored(Path("wat"), "no match")
593             self.assertEqual(len(out_lines), 0)
594             self.assertEqual(len(err_lines), 2)
595             self.assertEqual(
596                 unstyle(str(report)),
597                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
598                 " reformat.",
599             )
600             self.assertEqual(report.return_code, 123)
601             report.done(Path("f4"), black.Changed.NO)
602             self.assertEqual(len(out_lines), 0)
603             self.assertEqual(len(err_lines), 2)
604             self.assertEqual(
605                 unstyle(str(report)),
606                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
607                 " reformat.",
608             )
609             self.assertEqual(report.return_code, 123)
610             report.check = True
611             self.assertEqual(
612                 unstyle(str(report)),
613                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
614                 " would fail to reformat.",
615             )
616             report.check = False
617             report.diff = True
618             self.assertEqual(
619                 unstyle(str(report)),
620                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
621                 " would fail to reformat.",
622             )
623
624     def test_report_normal(self) -> None:
625         report = black.Report()
626         out_lines = []
627         err_lines = []
628
629         def out(msg: str, **kwargs: Any) -> None:
630             out_lines.append(msg)
631
632         def err(msg: str, **kwargs: Any) -> None:
633             err_lines.append(msg)
634
635         with patch("black.output._out", out), patch("black.output._err", err):
636             report.done(Path("f1"), black.Changed.NO)
637             self.assertEqual(len(out_lines), 0)
638             self.assertEqual(len(err_lines), 0)
639             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
640             self.assertEqual(report.return_code, 0)
641             report.done(Path("f2"), black.Changed.YES)
642             self.assertEqual(len(out_lines), 1)
643             self.assertEqual(len(err_lines), 0)
644             self.assertEqual(out_lines[-1], "reformatted f2")
645             self.assertEqual(
646                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
647             )
648             report.done(Path("f3"), black.Changed.CACHED)
649             self.assertEqual(len(out_lines), 1)
650             self.assertEqual(len(err_lines), 0)
651             self.assertEqual(out_lines[-1], "reformatted f2")
652             self.assertEqual(
653                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
654             )
655             self.assertEqual(report.return_code, 0)
656             report.check = True
657             self.assertEqual(report.return_code, 1)
658             report.check = False
659             report.failed(Path("e1"), "boom")
660             self.assertEqual(len(out_lines), 1)
661             self.assertEqual(len(err_lines), 1)
662             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
663             self.assertEqual(
664                 unstyle(str(report)),
665                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
666                 " reformat.",
667             )
668             self.assertEqual(report.return_code, 123)
669             report.done(Path("f3"), black.Changed.YES)
670             self.assertEqual(len(out_lines), 2)
671             self.assertEqual(len(err_lines), 1)
672             self.assertEqual(out_lines[-1], "reformatted f3")
673             self.assertEqual(
674                 unstyle(str(report)),
675                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
676                 " reformat.",
677             )
678             self.assertEqual(report.return_code, 123)
679             report.failed(Path("e2"), "boom")
680             self.assertEqual(len(out_lines), 2)
681             self.assertEqual(len(err_lines), 2)
682             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
683             self.assertEqual(
684                 unstyle(str(report)),
685                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
686                 " reformat.",
687             )
688             self.assertEqual(report.return_code, 123)
689             report.path_ignored(Path("wat"), "no match")
690             self.assertEqual(len(out_lines), 2)
691             self.assertEqual(len(err_lines), 2)
692             self.assertEqual(
693                 unstyle(str(report)),
694                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
695                 " reformat.",
696             )
697             self.assertEqual(report.return_code, 123)
698             report.done(Path("f4"), black.Changed.NO)
699             self.assertEqual(len(out_lines), 2)
700             self.assertEqual(len(err_lines), 2)
701             self.assertEqual(
702                 unstyle(str(report)),
703                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
704                 " reformat.",
705             )
706             self.assertEqual(report.return_code, 123)
707             report.check = True
708             self.assertEqual(
709                 unstyle(str(report)),
710                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
711                 " would fail to reformat.",
712             )
713             report.check = False
714             report.diff = True
715             self.assertEqual(
716                 unstyle(str(report)),
717                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
718                 " would fail to reformat.",
719             )
720
721     def test_lib2to3_parse(self) -> None:
722         with self.assertRaises(black.InvalidInput):
723             black.lib2to3_parse("invalid syntax")
724
725         straddling = "x + y"
726         black.lib2to3_parse(straddling)
727         black.lib2to3_parse(straddling, {TargetVersion.PY27})
728         black.lib2to3_parse(straddling, {TargetVersion.PY36})
729         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
730
731         py2_only = "print x"
732         black.lib2to3_parse(py2_only)
733         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
734         with self.assertRaises(black.InvalidInput):
735             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
736         with self.assertRaises(black.InvalidInput):
737             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
738
739         py3_only = "exec(x, end=y)"
740         black.lib2to3_parse(py3_only)
741         with self.assertRaises(black.InvalidInput):
742             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
743         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
744         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
745
746     def test_get_features_used_decorator(self) -> None:
747         # Test the feature detection of new decorator syntax
748         # since this makes some test cases of test_get_features_used()
749         # fails if it fails, this is tested first so that a useful case
750         # is identified
751         simples, relaxed = read_data("decorators")
752         # skip explanation comments at the top of the file
753         for simple_test in simples.split("##")[1:]:
754             node = black.lib2to3_parse(simple_test)
755             decorator = str(node.children[0].children[0]).strip()
756             self.assertNotIn(
757                 Feature.RELAXED_DECORATORS,
758                 black.get_features_used(node),
759                 msg=(
760                     f"decorator '{decorator}' follows python<=3.8 syntax"
761                     "but is detected as 3.9+"
762                     # f"The full node is\n{node!r}"
763                 ),
764             )
765         # skip the '# output' comment at the top of the output part
766         for relaxed_test in relaxed.split("##")[1:]:
767             node = black.lib2to3_parse(relaxed_test)
768             decorator = str(node.children[0].children[0]).strip()
769             self.assertIn(
770                 Feature.RELAXED_DECORATORS,
771                 black.get_features_used(node),
772                 msg=(
773                     f"decorator '{decorator}' uses python3.9+ syntax"
774                     "but is detected as python<=3.8"
775                     # f"The full node is\n{node!r}"
776                 ),
777             )
778
779     def test_get_features_used(self) -> None:
780         node = black.lib2to3_parse("def f(*, arg): ...\n")
781         self.assertEqual(black.get_features_used(node), set())
782         node = black.lib2to3_parse("def f(*, arg,): ...\n")
783         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
784         node = black.lib2to3_parse("f(*arg,)\n")
785         self.assertEqual(
786             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
787         )
788         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
789         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
790         node = black.lib2to3_parse("123_456\n")
791         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
792         node = black.lib2to3_parse("123456\n")
793         self.assertEqual(black.get_features_used(node), set())
794         source, expected = read_data("function")
795         node = black.lib2to3_parse(source)
796         expected_features = {
797             Feature.TRAILING_COMMA_IN_CALL,
798             Feature.TRAILING_COMMA_IN_DEF,
799             Feature.F_STRINGS,
800         }
801         self.assertEqual(black.get_features_used(node), expected_features)
802         node = black.lib2to3_parse(expected)
803         self.assertEqual(black.get_features_used(node), expected_features)
804         source, expected = read_data("expression")
805         node = black.lib2to3_parse(source)
806         self.assertEqual(black.get_features_used(node), set())
807         node = black.lib2to3_parse(expected)
808         self.assertEqual(black.get_features_used(node), set())
809         node = black.lib2to3_parse("lambda a, /, b: ...")
810         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
811         node = black.lib2to3_parse("def fn(a, /, b): ...")
812         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
813
814     def test_get_future_imports(self) -> None:
815         node = black.lib2to3_parse("\n")
816         self.assertEqual(set(), black.get_future_imports(node))
817         node = black.lib2to3_parse("from __future__ import black\n")
818         self.assertEqual({"black"}, black.get_future_imports(node))
819         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
820         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
821         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
822         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
823         node = black.lib2to3_parse(
824             "from __future__ import multiple\nfrom __future__ import imports\n"
825         )
826         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
827         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
828         self.assertEqual({"black"}, black.get_future_imports(node))
829         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
830         self.assertEqual({"black"}, black.get_future_imports(node))
831         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
832         self.assertEqual(set(), black.get_future_imports(node))
833         node = black.lib2to3_parse("from some.module import black\n")
834         self.assertEqual(set(), black.get_future_imports(node))
835         node = black.lib2to3_parse(
836             "from __future__ import unicode_literals as _unicode_literals"
837         )
838         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
839         node = black.lib2to3_parse(
840             "from __future__ import unicode_literals as _lol, print"
841         )
842         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
843
844     @pytest.mark.incompatible_with_mypyc
845     def test_debug_visitor(self) -> None:
846         source, _ = read_data("debug_visitor.py")
847         expected, _ = read_data("debug_visitor.out")
848         out_lines = []
849         err_lines = []
850
851         def out(msg: str, **kwargs: Any) -> None:
852             out_lines.append(msg)
853
854         def err(msg: str, **kwargs: Any) -> None:
855             err_lines.append(msg)
856
857         with patch("black.debug.out", out):
858             DebugVisitor.show(source)
859         actual = "\n".join(out_lines) + "\n"
860         log_name = ""
861         if expected != actual:
862             log_name = black.dump_to_file(*out_lines)
863         self.assertEqual(
864             expected,
865             actual,
866             f"AST print out is different. Actual version dumped to {log_name}",
867         )
868
869     def test_format_file_contents(self) -> None:
870         empty = ""
871         mode = DEFAULT_MODE
872         with self.assertRaises(black.NothingChanged):
873             black.format_file_contents(empty, mode=mode, fast=False)
874         just_nl = "\n"
875         with self.assertRaises(black.NothingChanged):
876             black.format_file_contents(just_nl, mode=mode, fast=False)
877         same = "j = [1, 2, 3]\n"
878         with self.assertRaises(black.NothingChanged):
879             black.format_file_contents(same, mode=mode, fast=False)
880         different = "j = [1,2,3]"
881         expected = same
882         actual = black.format_file_contents(different, mode=mode, fast=False)
883         self.assertEqual(expected, actual)
884         invalid = "return if you can"
885         with self.assertRaises(black.InvalidInput) as e:
886             black.format_file_contents(invalid, mode=mode, fast=False)
887         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
888
889     def test_endmarker(self) -> None:
890         n = black.lib2to3_parse("\n")
891         self.assertEqual(n.type, black.syms.file_input)
892         self.assertEqual(len(n.children), 1)
893         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
894
895     @pytest.mark.incompatible_with_mypyc
896     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
897     def test_assertFormatEqual(self) -> None:
898         out_lines = []
899         err_lines = []
900
901         def out(msg: str, **kwargs: Any) -> None:
902             out_lines.append(msg)
903
904         def err(msg: str, **kwargs: Any) -> None:
905             err_lines.append(msg)
906
907         with patch("black.output._out", out), patch("black.output._err", err):
908             with self.assertRaises(AssertionError):
909                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
910
911         out_str = "".join(out_lines)
912         self.assertTrue("Expected tree:" in out_str)
913         self.assertTrue("Actual tree:" in out_str)
914         self.assertEqual("".join(err_lines), "")
915
916     @event_loop()
917     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
918     def test_works_in_mono_process_only_environment(self) -> None:
919         with cache_dir() as workspace:
920             for f in [
921                 (workspace / "one.py").resolve(),
922                 (workspace / "two.py").resolve(),
923             ]:
924                 f.write_text('print("hello")\n')
925             self.invokeBlack([str(workspace)])
926
927     @event_loop()
928     def test_check_diff_use_together(self) -> None:
929         with cache_dir():
930             # Files which will be reformatted.
931             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
932             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
933             # Files which will not be reformatted.
934             src2 = (THIS_DIR / "data" / "composition.py").resolve()
935             self.invokeBlack([str(src2), "--diff", "--check"])
936             # Multi file command.
937             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
938
939     def test_no_files(self) -> None:
940         with cache_dir():
941             # Without an argument, black exits with error code 0.
942             self.invokeBlack([])
943
944     def test_broken_symlink(self) -> None:
945         with cache_dir() as workspace:
946             symlink = workspace / "broken_link.py"
947             try:
948                 symlink.symlink_to("nonexistent.py")
949             except (OSError, NotImplementedError) as e:
950                 self.skipTest(f"Can't create symlinks: {e}")
951             self.invokeBlack([str(workspace.resolve())])
952
953     def test_single_file_force_pyi(self) -> None:
954         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
955         contents, expected = read_data("force_pyi")
956         with cache_dir() as workspace:
957             path = (workspace / "file.py").resolve()
958             with open(path, "w") as fh:
959                 fh.write(contents)
960             self.invokeBlack([str(path), "--pyi"])
961             with open(path, "r") as fh:
962                 actual = fh.read()
963             # verify cache with --pyi is separate
964             pyi_cache = black.read_cache(pyi_mode)
965             self.assertIn(str(path), pyi_cache)
966             normal_cache = black.read_cache(DEFAULT_MODE)
967             self.assertNotIn(str(path), normal_cache)
968         self.assertFormatEqual(expected, actual)
969         black.assert_equivalent(contents, actual)
970         black.assert_stable(contents, actual, pyi_mode)
971
972     @event_loop()
973     def test_multi_file_force_pyi(self) -> None:
974         reg_mode = DEFAULT_MODE
975         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
976         contents, expected = read_data("force_pyi")
977         with cache_dir() as workspace:
978             paths = [
979                 (workspace / "file1.py").resolve(),
980                 (workspace / "file2.py").resolve(),
981             ]
982             for path in paths:
983                 with open(path, "w") as fh:
984                     fh.write(contents)
985             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
986             for path in paths:
987                 with open(path, "r") as fh:
988                     actual = fh.read()
989                 self.assertEqual(actual, expected)
990             # verify cache with --pyi is separate
991             pyi_cache = black.read_cache(pyi_mode)
992             normal_cache = black.read_cache(reg_mode)
993             for path in paths:
994                 self.assertIn(str(path), pyi_cache)
995                 self.assertNotIn(str(path), normal_cache)
996
997     def test_pipe_force_pyi(self) -> None:
998         source, expected = read_data("force_pyi")
999         result = CliRunner().invoke(
1000             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1001         )
1002         self.assertEqual(result.exit_code, 0)
1003         actual = result.output
1004         self.assertFormatEqual(actual, expected)
1005
1006     def test_single_file_force_py36(self) -> None:
1007         reg_mode = DEFAULT_MODE
1008         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1009         source, expected = read_data("force_py36")
1010         with cache_dir() as workspace:
1011             path = (workspace / "file.py").resolve()
1012             with open(path, "w") as fh:
1013                 fh.write(source)
1014             self.invokeBlack([str(path), *PY36_ARGS])
1015             with open(path, "r") as fh:
1016                 actual = fh.read()
1017             # verify cache with --target-version is separate
1018             py36_cache = black.read_cache(py36_mode)
1019             self.assertIn(str(path), py36_cache)
1020             normal_cache = black.read_cache(reg_mode)
1021             self.assertNotIn(str(path), normal_cache)
1022         self.assertEqual(actual, expected)
1023
1024     @event_loop()
1025     def test_multi_file_force_py36(self) -> None:
1026         reg_mode = DEFAULT_MODE
1027         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1028         source, expected = read_data("force_py36")
1029         with cache_dir() as workspace:
1030             paths = [
1031                 (workspace / "file1.py").resolve(),
1032                 (workspace / "file2.py").resolve(),
1033             ]
1034             for path in paths:
1035                 with open(path, "w") as fh:
1036                     fh.write(source)
1037             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1038             for path in paths:
1039                 with open(path, "r") as fh:
1040                     actual = fh.read()
1041                 self.assertEqual(actual, expected)
1042             # verify cache with --target-version is separate
1043             pyi_cache = black.read_cache(py36_mode)
1044             normal_cache = black.read_cache(reg_mode)
1045             for path in paths:
1046                 self.assertIn(str(path), pyi_cache)
1047                 self.assertNotIn(str(path), normal_cache)
1048
1049     def test_pipe_force_py36(self) -> None:
1050         source, expected = read_data("force_py36")
1051         result = CliRunner().invoke(
1052             black.main,
1053             ["-", "-q", "--target-version=py36"],
1054             input=BytesIO(source.encode("utf8")),
1055         )
1056         self.assertEqual(result.exit_code, 0)
1057         actual = result.output
1058         self.assertFormatEqual(actual, expected)
1059
1060     @pytest.mark.incompatible_with_mypyc
1061     def test_reformat_one_with_stdin(self) -> None:
1062         with patch(
1063             "black.format_stdin_to_stdout",
1064             return_value=lambda *args, **kwargs: black.Changed.YES,
1065         ) as fsts:
1066             report = MagicMock()
1067             path = Path("-")
1068             black.reformat_one(
1069                 path,
1070                 fast=True,
1071                 write_back=black.WriteBack.YES,
1072                 mode=DEFAULT_MODE,
1073                 report=report,
1074             )
1075             fsts.assert_called_once()
1076             report.done.assert_called_with(path, black.Changed.YES)
1077
1078     @pytest.mark.incompatible_with_mypyc
1079     def test_reformat_one_with_stdin_filename(self) -> None:
1080         with patch(
1081             "black.format_stdin_to_stdout",
1082             return_value=lambda *args, **kwargs: black.Changed.YES,
1083         ) as fsts:
1084             report = MagicMock()
1085             p = "foo.py"
1086             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1087             expected = Path(p)
1088             black.reformat_one(
1089                 path,
1090                 fast=True,
1091                 write_back=black.WriteBack.YES,
1092                 mode=DEFAULT_MODE,
1093                 report=report,
1094             )
1095             fsts.assert_called_once_with(
1096                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1097             )
1098             # __BLACK_STDIN_FILENAME__ should have been stripped
1099             report.done.assert_called_with(expected, black.Changed.YES)
1100
1101     @pytest.mark.incompatible_with_mypyc
1102     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1103         with patch(
1104             "black.format_stdin_to_stdout",
1105             return_value=lambda *args, **kwargs: black.Changed.YES,
1106         ) as fsts:
1107             report = MagicMock()
1108             p = "foo.pyi"
1109             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1110             expected = Path(p)
1111             black.reformat_one(
1112                 path,
1113                 fast=True,
1114                 write_back=black.WriteBack.YES,
1115                 mode=DEFAULT_MODE,
1116                 report=report,
1117             )
1118             fsts.assert_called_once_with(
1119                 fast=True,
1120                 write_back=black.WriteBack.YES,
1121                 mode=replace(DEFAULT_MODE, is_pyi=True),
1122             )
1123             # __BLACK_STDIN_FILENAME__ should have been stripped
1124             report.done.assert_called_with(expected, black.Changed.YES)
1125
1126     @pytest.mark.incompatible_with_mypyc
1127     def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
1128         with patch(
1129             "black.format_stdin_to_stdout",
1130             return_value=lambda *args, **kwargs: black.Changed.YES,
1131         ) as fsts:
1132             report = MagicMock()
1133             p = "foo.ipynb"
1134             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1135             expected = Path(p)
1136             black.reformat_one(
1137                 path,
1138                 fast=True,
1139                 write_back=black.WriteBack.YES,
1140                 mode=DEFAULT_MODE,
1141                 report=report,
1142             )
1143             fsts.assert_called_once_with(
1144                 fast=True,
1145                 write_back=black.WriteBack.YES,
1146                 mode=replace(DEFAULT_MODE, is_ipynb=True),
1147             )
1148             # __BLACK_STDIN_FILENAME__ should have been stripped
1149             report.done.assert_called_with(expected, black.Changed.YES)
1150
1151     @pytest.mark.incompatible_with_mypyc
1152     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1153         with patch(
1154             "black.format_stdin_to_stdout",
1155             return_value=lambda *args, **kwargs: black.Changed.YES,
1156         ) as fsts:
1157             report = MagicMock()
1158             # Even with an existing file, since we are forcing stdin, black
1159             # should output to stdout and not modify the file inplace
1160             p = Path(str(THIS_DIR / "data/collections.py"))
1161             # Make sure is_file actually returns True
1162             self.assertTrue(p.is_file())
1163             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1164             expected = Path(p)
1165             black.reformat_one(
1166                 path,
1167                 fast=True,
1168                 write_back=black.WriteBack.YES,
1169                 mode=DEFAULT_MODE,
1170                 report=report,
1171             )
1172             fsts.assert_called_once()
1173             # __BLACK_STDIN_FILENAME__ should have been stripped
1174             report.done.assert_called_with(expected, black.Changed.YES)
1175
1176     def test_reformat_one_with_stdin_empty(self) -> None:
1177         output = io.StringIO()
1178         with patch("io.TextIOWrapper", lambda *args, **kwargs: output):
1179             try:
1180                 black.format_stdin_to_stdout(
1181                     fast=True,
1182                     content="",
1183                     write_back=black.WriteBack.YES,
1184                     mode=DEFAULT_MODE,
1185                 )
1186             except io.UnsupportedOperation:
1187                 pass  # StringIO does not support detach
1188             assert output.getvalue() == ""
1189
1190     def test_invalid_cli_regex(self) -> None:
1191         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1192             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1193
1194     def test_required_version_matches_version(self) -> None:
1195         self.invokeBlack(
1196             ["--required-version", black.__version__], exit_code=0, ignore_config=True
1197         )
1198
1199     def test_required_version_does_not_match_version(self) -> None:
1200         self.invokeBlack(
1201             ["--required-version", "20.99b"], exit_code=1, ignore_config=True
1202         )
1203
1204     def test_preserves_line_endings(self) -> None:
1205         with TemporaryDirectory() as workspace:
1206             test_file = Path(workspace) / "test.py"
1207             for nl in ["\n", "\r\n"]:
1208                 contents = nl.join(["def f(  ):", "    pass"])
1209                 test_file.write_bytes(contents.encode())
1210                 ff(test_file, write_back=black.WriteBack.YES)
1211                 updated_contents: bytes = test_file.read_bytes()
1212                 self.assertIn(nl.encode(), updated_contents)
1213                 if nl == "\n":
1214                     self.assertNotIn(b"\r\n", updated_contents)
1215
1216     def test_preserves_line_endings_via_stdin(self) -> None:
1217         for nl in ["\n", "\r\n"]:
1218             contents = nl.join(["def f(  ):", "    pass"])
1219             runner = BlackRunner()
1220             result = runner.invoke(
1221                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1222             )
1223             self.assertEqual(result.exit_code, 0)
1224             output = result.stdout_bytes
1225             self.assertIn(nl.encode("utf8"), output)
1226             if nl == "\n":
1227                 self.assertNotIn(b"\r\n", output)
1228
1229     def test_assert_equivalent_different_asts(self) -> None:
1230         with self.assertRaises(AssertionError):
1231             black.assert_equivalent("{}", "None")
1232
1233     def test_shhh_click(self) -> None:
1234         try:
1235             from click import _unicodefun
1236         except ModuleNotFoundError:
1237             self.skipTest("Incompatible Click version")
1238         if not hasattr(_unicodefun, "_verify_python3_env"):
1239             self.skipTest("Incompatible Click version")
1240         # First, let's see if Click is crashing with a preferred ASCII charset.
1241         with patch("locale.getpreferredencoding") as gpe:
1242             gpe.return_value = "ASCII"
1243             with self.assertRaises(RuntimeError):
1244                 _unicodefun._verify_python3_env()  # type: ignore
1245         # Now, let's silence Click...
1246         black.patch_click()
1247         # ...and confirm it's silent.
1248         with patch("locale.getpreferredencoding") as gpe:
1249             gpe.return_value = "ASCII"
1250             try:
1251                 _unicodefun._verify_python3_env()  # type: ignore
1252             except RuntimeError as re:
1253                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1254
1255     def test_root_logger_not_used_directly(self) -> None:
1256         def fail(*args: Any, **kwargs: Any) -> None:
1257             self.fail("Record created with root logger")
1258
1259         with patch.multiple(
1260             logging.root,
1261             debug=fail,
1262             info=fail,
1263             warning=fail,
1264             error=fail,
1265             critical=fail,
1266             log=fail,
1267         ):
1268             ff(THIS_DIR / "util.py")
1269
1270     def test_invalid_config_return_code(self) -> None:
1271         tmp_file = Path(black.dump_to_file())
1272         try:
1273             tmp_config = Path(black.dump_to_file())
1274             tmp_config.unlink()
1275             args = ["--config", str(tmp_config), str(tmp_file)]
1276             self.invokeBlack(args, exit_code=2, ignore_config=False)
1277         finally:
1278             tmp_file.unlink()
1279
1280     def test_parse_pyproject_toml(self) -> None:
1281         test_toml_file = THIS_DIR / "test.toml"
1282         config = black.parse_pyproject_toml(str(test_toml_file))
1283         self.assertEqual(config["verbose"], 1)
1284         self.assertEqual(config["check"], "no")
1285         self.assertEqual(config["diff"], "y")
1286         self.assertEqual(config["color"], True)
1287         self.assertEqual(config["line_length"], 79)
1288         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1289         self.assertEqual(config["exclude"], r"\.pyi?$")
1290         self.assertEqual(config["include"], r"\.py?$")
1291
1292     def test_read_pyproject_toml(self) -> None:
1293         test_toml_file = THIS_DIR / "test.toml"
1294         fake_ctx = FakeContext()
1295         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1296         config = fake_ctx.default_map
1297         self.assertEqual(config["verbose"], "1")
1298         self.assertEqual(config["check"], "no")
1299         self.assertEqual(config["diff"], "y")
1300         self.assertEqual(config["color"], "True")
1301         self.assertEqual(config["line_length"], "79")
1302         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1303         self.assertEqual(config["exclude"], r"\.pyi?$")
1304         self.assertEqual(config["include"], r"\.py?$")
1305
1306     @pytest.mark.incompatible_with_mypyc
1307     def test_find_project_root(self) -> None:
1308         with TemporaryDirectory() as workspace:
1309             root = Path(workspace)
1310             test_dir = root / "test"
1311             test_dir.mkdir()
1312
1313             src_dir = root / "src"
1314             src_dir.mkdir()
1315
1316             root_pyproject = root / "pyproject.toml"
1317             root_pyproject.touch()
1318             src_pyproject = src_dir / "pyproject.toml"
1319             src_pyproject.touch()
1320             src_python = src_dir / "foo.py"
1321             src_python.touch()
1322
1323             self.assertEqual(
1324                 black.find_project_root((src_dir, test_dir)), root.resolve()
1325             )
1326             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
1327             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
1328
1329     @patch(
1330         "black.files.find_user_pyproject_toml",
1331         black.files.find_user_pyproject_toml.__wrapped__,
1332     )
1333     def test_find_user_pyproject_toml_linux(self) -> None:
1334         if system() == "Windows":
1335             return
1336
1337         # Test if XDG_CONFIG_HOME is checked
1338         with TemporaryDirectory() as workspace:
1339             tmp_user_config = Path(workspace) / "black"
1340             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1341                 self.assertEqual(
1342                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
1343                 )
1344
1345         # Test fallback for XDG_CONFIG_HOME
1346         with patch.dict("os.environ"):
1347             os.environ.pop("XDG_CONFIG_HOME", None)
1348             fallback_user_config = Path("~/.config").expanduser() / "black"
1349             self.assertEqual(
1350                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
1351             )
1352
1353     def test_find_user_pyproject_toml_windows(self) -> None:
1354         if system() != "Windows":
1355             return
1356
1357         user_config_path = Path.home() / ".black"
1358         self.assertEqual(
1359             black.files.find_user_pyproject_toml(), user_config_path.resolve()
1360         )
1361
1362     def test_bpo_33660_workaround(self) -> None:
1363         if system() == "Windows":
1364             return
1365
1366         # https://bugs.python.org/issue33660
1367         root = Path("/")
1368         with change_directory(root):
1369             path = Path("workspace") / "project"
1370             report = black.Report(verbose=True)
1371             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
1372             self.assertEqual(normalized_path, "workspace/project")
1373
1374     def test_newline_comment_interaction(self) -> None:
1375         source = "class A:\\\r\n# type: ignore\n pass\n"
1376         output = black.format_str(source, mode=DEFAULT_MODE)
1377         black.assert_stable(source, output, mode=DEFAULT_MODE)
1378
1379     def test_bpo_2142_workaround(self) -> None:
1380
1381         # https://bugs.python.org/issue2142
1382
1383         source, _ = read_data("missing_final_newline.py")
1384         # read_data adds a trailing newline
1385         source = source.rstrip()
1386         expected, _ = read_data("missing_final_newline.diff")
1387         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
1388         diff_header = re.compile(
1389             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
1390             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
1391         )
1392         try:
1393             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
1394             self.assertEqual(result.exit_code, 0)
1395         finally:
1396             os.unlink(tmp_file)
1397         actual = result.output
1398         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
1399         self.assertEqual(actual, expected)
1400
1401     @pytest.mark.python2
1402     def test_docstring_reformat_for_py27(self) -> None:
1403         """
1404         Check that stripping trailing whitespace from Python 2 docstrings
1405         doesn't trigger a "not equivalent to source" error
1406         """
1407         source = (
1408             b'def foo():\r\n    """Testing\r\n    Testing """\r\n    print "Foo"\r\n'
1409         )
1410         expected = 'def foo():\n    """Testing\n    Testing"""\n    print "Foo"\n'
1411
1412         result = BlackRunner().invoke(
1413             black.main,
1414             ["-", "-q", "--target-version=py27"],
1415             input=BytesIO(source),
1416         )
1417
1418         self.assertEqual(result.exit_code, 0)
1419         actual = result.stdout
1420         self.assertFormatEqual(actual, expected)
1421
1422     @staticmethod
1423     def compare_results(
1424         result: click.testing.Result, expected_value: str, expected_exit_code: int
1425     ) -> None:
1426         """Helper method to test the value and exit code of a click Result."""
1427         assert (
1428             result.output == expected_value
1429         ), "The output did not match the expected value."
1430         assert result.exit_code == expected_exit_code, "The exit code is incorrect."
1431
1432     def test_code_option(self) -> None:
1433         """Test the code option with no changes."""
1434         code = 'print("Hello world")\n'
1435         args = ["--code", code]
1436         result = CliRunner().invoke(black.main, args)
1437
1438         self.compare_results(result, code, 0)
1439
1440     def test_code_option_changed(self) -> None:
1441         """Test the code option when changes are required."""
1442         code = "print('hello world')"
1443         formatted = black.format_str(code, mode=DEFAULT_MODE)
1444
1445         args = ["--code", code]
1446         result = CliRunner().invoke(black.main, args)
1447
1448         self.compare_results(result, formatted, 0)
1449
1450     def test_code_option_check(self) -> None:
1451         """Test the code option when check is passed."""
1452         args = ["--check", "--code", 'print("Hello world")\n']
1453         result = CliRunner().invoke(black.main, args)
1454         self.compare_results(result, "", 0)
1455
1456     def test_code_option_check_changed(self) -> None:
1457         """Test the code option when changes are required, and check is passed."""
1458         args = ["--check", "--code", "print('hello world')"]
1459         result = CliRunner().invoke(black.main, args)
1460         self.compare_results(result, "", 1)
1461
1462     def test_code_option_diff(self) -> None:
1463         """Test the code option when diff is passed."""
1464         code = "print('hello world')"
1465         formatted = black.format_str(code, mode=DEFAULT_MODE)
1466         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1467
1468         args = ["--diff", "--code", code]
1469         result = CliRunner().invoke(black.main, args)
1470
1471         # Remove time from diff
1472         output = DIFF_TIME.sub("", result.output)
1473
1474         assert output == result_diff, "The output did not match the expected value."
1475         assert result.exit_code == 0, "The exit code is incorrect."
1476
1477     def test_code_option_color_diff(self) -> None:
1478         """Test the code option when color and diff are passed."""
1479         code = "print('hello world')"
1480         formatted = black.format_str(code, mode=DEFAULT_MODE)
1481
1482         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1483         result_diff = color_diff(result_diff)
1484
1485         args = ["--diff", "--color", "--code", code]
1486         result = CliRunner().invoke(black.main, args)
1487
1488         # Remove time from diff
1489         output = DIFF_TIME.sub("", result.output)
1490
1491         assert output == result_diff, "The output did not match the expected value."
1492         assert result.exit_code == 0, "The exit code is incorrect."
1493
1494     @pytest.mark.incompatible_with_mypyc
1495     def test_code_option_safe(self) -> None:
1496         """Test that the code option throws an error when the sanity checks fail."""
1497         # Patch black.assert_equivalent to ensure the sanity checks fail
1498         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1499             code = 'print("Hello world")'
1500             error_msg = f"{code}\nerror: cannot format <string>: \n"
1501
1502             args = ["--safe", "--code", code]
1503             result = CliRunner().invoke(black.main, args)
1504
1505             self.compare_results(result, error_msg, 123)
1506
1507     def test_code_option_fast(self) -> None:
1508         """Test that the code option ignores errors when the sanity checks fail."""
1509         # Patch black.assert_equivalent to ensure the sanity checks fail
1510         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1511             code = 'print("Hello world")'
1512             formatted = black.format_str(code, mode=DEFAULT_MODE)
1513
1514             args = ["--fast", "--code", code]
1515             result = CliRunner().invoke(black.main, args)
1516
1517             self.compare_results(result, formatted, 0)
1518
1519     @pytest.mark.incompatible_with_mypyc
1520     def test_code_option_config(self) -> None:
1521         """
1522         Test that the code option finds the pyproject.toml in the current directory.
1523         """
1524         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1525             args = ["--code", "print"]
1526             # This is the only directory known to contain a pyproject.toml
1527             with change_directory(PROJECT_ROOT):
1528                 CliRunner().invoke(black.main, args)
1529                 pyproject_path = Path(Path.cwd(), "pyproject.toml").resolve()
1530
1531             assert (
1532                 len(parse.mock_calls) >= 1
1533             ), "Expected config parse to be called with the current directory."
1534
1535             _, call_args, _ = parse.mock_calls[0]
1536             assert (
1537                 call_args[0].lower() == str(pyproject_path).lower()
1538             ), "Incorrect config loaded."
1539
1540     @pytest.mark.incompatible_with_mypyc
1541     def test_code_option_parent_config(self) -> None:
1542         """
1543         Test that the code option finds the pyproject.toml in the parent directory.
1544         """
1545         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1546             with change_directory(THIS_DIR):
1547                 args = ["--code", "print"]
1548                 CliRunner().invoke(black.main, args)
1549
1550                 pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
1551                 assert (
1552                     len(parse.mock_calls) >= 1
1553                 ), "Expected config parse to be called with the current directory."
1554
1555                 _, call_args, _ = parse.mock_calls[0]
1556                 assert (
1557                     call_args[0].lower() == str(pyproject_path).lower()
1558                 ), "Incorrect config loaded."
1559
1560
class TestCaching:
    """Tests for reading, writing and honouring Black's formatting cache."""

    def test_cache_broken_file(self) -> None:
        """A cache file that is not a pickle reads as empty and gets rewritten."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            cache_file = get_cache_file(mode)
            cache_file.write_text("this is not a pickle")
            assert black.read_cache(mode) == {}
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            invokeBlack([str(src)])
            cache = black.read_cache(mode)
            assert str(src) in cache

    def test_cache_single_file_already_cached(self) -> None:
        """A file already present in the cache is not reformatted."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            black.write_cache({}, [src], mode)
            invokeBlack([str(src)])
            assert src.read_text() == "print('hello')"

    @event_loop()
    def test_cache_multiple_files(self) -> None:
        """Only the uncached file is reformatted; afterwards both are cached."""
        mode = DEFAULT_MODE
        # Swap the process pool for threads so formatting runs in-process.
        with cache_dir() as workspace, patch(
            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            one = (workspace / "one.py").resolve()
            one.write_text("print('hello')")
            two = (workspace / "two.py").resolve()
            two.write_text("print('hello')")
            black.write_cache({}, [one], mode)
            invokeBlack([str(workspace)])
            assert one.read_text() == "print('hello')"
            assert two.read_text() == 'print("hello")\n'
            cache = black.read_cache(mode)
            assert str(one) in cache
            assert str(two) in cache

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    def test_no_cache_when_writeback_diff(self, color: bool) -> None:
        """``--diff`` (with or without ``--color``) neither reads nor writes the cache."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.write_text("print('hello')")
            with patch("black.read_cache") as read_cache, patch(
                "black.write_cache"
            ) as write_cache:
                cmd = [str(src), "--diff"]
                if color:
                    cmd.append("--color")
                invokeBlack(cmd)
                cache_file = get_cache_file(mode)
                assert not cache_file.exists()
                write_cache.assert_not_called()
                read_cache.assert_not_called()

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    @event_loop()
    def test_output_locking_when_writeback_diff(self, color: bool) -> None:
        """Diffing several files goes through ``black.Manager`` (for its lock)."""
        with cache_dir() as workspace:
            for tag in range(4):
                src = (workspace / f"test{tag}.py").resolve()
                src.write_text("print('hello')")
            with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
                cmd = ["--diff", str(workspace)]
                if color:
                    cmd.append("--color")
                invokeBlack(cmd, exit_code=0)
                # this isn't quite doing what we want, but if it _isn't_
                # called then we cannot be using the lock it provides
                mgr.assert_called()

    def test_no_cache_when_stdin(self) -> None:
        """Formatting from stdin does not create a cache file."""
        mode = DEFAULT_MODE
        with cache_dir():
            result = CliRunner().invoke(
                black.main, ["-"], input=BytesIO(b"print('hello')")
            )
            assert not result.exit_code
            cache_file = get_cache_file(mode)
            assert not cache_file.exists()

    def test_read_cache_no_cachefile(self) -> None:
        """Reading a cache that was never written yields an empty dict."""
        mode = DEFAULT_MODE
        with cache_dir():
            assert black.read_cache(mode) == {}

    def test_write_cache_read_cache(self) -> None:
        """A written entry reads back with the file's current cache info."""
        mode = DEFAULT_MODE
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.touch()
            black.write_cache({}, [src], mode)
            cache = black.read_cache(mode)
            assert str(src) in cache
            assert cache[str(src)] == black.get_cache_info(src)

    def test_filter_cached(self) -> None:
        """Uncached and stale entries are todo; only up-to-date ones are done."""
        with TemporaryDirectory() as workspace:
            path = Path(workspace)
            uncached = (path / "uncached").resolve()
            cached = (path / "cached").resolve()
            cached_but_changed = (path / "changed").resolve()
            uncached.touch()
            cached.touch()
            cached_but_changed.touch()
            cache = {
                str(cached): black.get_cache_info(cached),
                # Deliberately stale info so the file counts as changed.
                str(cached_but_changed): (0.0, 0),
            }
            todo, done = black.filter_cached(
                cache, {uncached, cached, cached_but_changed}
            )
            assert todo == {uncached, cached_but_changed}
            assert done == {cached}

    def test_write_cache_creates_directory_if_needed(self) -> None:
        """Writing the cache creates its directory when it does not exist yet."""
        mode = DEFAULT_MODE
        with cache_dir(exists=False) as workspace:
            assert not workspace.exists()
            black.write_cache({}, [], mode)
            assert workspace.exists()

    @event_loop()
    def test_failed_formatting_does_not_get_cached(self) -> None:
        """A file that fails to format is kept out of the cache; clean files are not."""
        mode = DEFAULT_MODE
        # Swap the process pool for threads so formatting runs in-process.
        with cache_dir() as workspace, patch(
            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            failing = (workspace / "failing.py").resolve()
            failing.write_text("not actually python")
            clean = (workspace / "clean.py").resolve()
            clean.write_text('print("hello")\n')
            invokeBlack([str(workspace)], exit_code=123)
            cache = black.read_cache(mode)
            assert str(failing) not in cache
            assert str(clean) in cache

    def test_write_cache_write_fail(self) -> None:
        """``write_cache`` must not propagate an OSError raised by ``Path.open``."""
        mode = DEFAULT_MODE
        with cache_dir(), patch.object(Path, "open") as mock:
            mock.side_effect = OSError
            black.write_cache({}, [], mode)

    def test_read_cache_line_lengths(self) -> None:
        """Cache entries are per mode: a different line length sees none of them."""
        mode = DEFAULT_MODE
        short_mode = replace(DEFAULT_MODE, line_length=1)
        with cache_dir() as workspace:
            path = (workspace / "file.py").resolve()
            path.touch()
            black.write_cache({}, [path], mode)
            one = black.read_cache(mode)
            assert str(path) in one
            two = black.read_cache(short_mode)
            assert str(path) not in two
1726
1727
def assert_collected_sources(
    src: Sequence[Union[str, Path]],
    expected: Sequence[Union[str, Path]],
    *,
    exclude: Optional[str] = None,
    include: Optional[str] = None,
    extend_exclude: Optional[str] = None,
    force_exclude: Optional[str] = None,
    stdin_filename: Optional[str] = None,
) -> None:
    """Assert that ``black.get_sources`` collects exactly ``expected`` from ``src``.

    Pattern arguments are regex strings compiled with ``compile_pattern``.
    An omitted ``include`` falls back to ``DEFAULT_INCLUDE``. The other patterns
    are passed through as ``None`` when omitted. Order is ignored: both sides
    are sorted before they are compared.
    """
    gs_src = tuple(str(Path(s)) for s in src)
    gs_expected = [Path(s) for s in expected]
    gs_exclude = None if exclude is None else compile_pattern(exclude)
    gs_include = DEFAULT_INCLUDE if include is None else compile_pattern(include)
    gs_extend_exclude = (
        None if extend_exclude is None else compile_pattern(extend_exclude)
    )
    gs_force_exclude = None if force_exclude is None else compile_pattern(force_exclude)
    collected = black.get_sources(
        ctx=FakeContext(),
        src=gs_src,
        quiet=False,
        verbose=False,
        include=gs_include,
        exclude=gs_exclude,
        extend_exclude=gs_extend_exclude,
        force_exclude=gs_force_exclude,
        report=black.Report(),
        stdin_filename=stdin_filename,
    )
    # sorted() accepts any iterable; no intermediate list is needed.
    assert sorted(collected) == sorted(gs_expected)
1759
1760
class TestFileCollection:
    """Tests for source discovery: include/exclude patterns, .gitignore and stdin."""

    def test_include_exclude(self) -> None:
        """Explicit include/exclude patterns filter recursively discovered files."""
        path = THIS_DIR / "data" / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        assert_collected_sources(
            src,
            expected,
            include=r"\.pyi?$",
            exclude=r"/exclude/|/\.definitely_exclude/",
        )

    def test_gitignore_used_as_default(self) -> None:
        """Without an explicit ``exclude``, .gitignore rules still apply."""
        base = DATA_DIR / "include_exclude_tests"
        expected = [
            base / "b/.definitely_exclude/a.py",
            base / "b/.definitely_exclude/a.pyi",
        ]
        src = [base / "b/"]
        assert_collected_sources(src, expected, extend_exclude=r"/exclude/")

    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
    def test_exclude_for_issue_1572(self) -> None:
        # Exclude shouldn't touch files that were explicitly given to Black through the
        # CLI. Exclude is supposed to only apply to the recursive discovery of files.
        # https://github.com/psf/black/issues/1572
        path = DATA_DIR / "include_exclude_tests"
        src = [path / "b/exclude/a.py"]
        expected = [path / "b/exclude/a.py"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    def test_gitignore_exclude(self) -> None:
        """A gitignore PathSpec passed to gen_python_files excludes matching paths."""
        path = THIS_DIR / "data" / "include_exclude_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        report = black.Report()
        gitignore = PathSpec.from_lines(
            "gitwildmatch", ["exclude/", ".definitely_exclude"]
        )
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        this_abs = THIS_DIR.resolve()
        sources: List[Path] = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                gitignore,
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_nested_gitignore(self) -> None:
        """.gitignore files in subdirectories are honoured along with the root one."""
        path = THIS_DIR / "data" / "nested_gitignore_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        root_gitignore = black.files.get_gitignore(path)
        report = black.Report()
        expected: List[Path] = [
            path / "x.py",
            path / "root/b.py",
            path / "root/c.py",
            path / "root/child/c.py",
        ]
        this_abs = THIS_DIR.resolve()
        sources = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                root_gitignore,
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_invalid_gitignore(self) -> None:
        """An unparsable root .gitignore makes Black exit 1 and name the file."""
        path = THIS_DIR / "data" / "invalid_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_invalid_nested_gitignore(self) -> None:
        """An unparsable nested .gitignore makes Black exit 1 and name the file."""
        path = THIS_DIR / "data" / "invalid_nested_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / "a" / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_empty_include(self) -> None:
        """An empty include pattern collects every file, not just Python sources."""
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.pie",
            path / "b/exclude/a.py",
            path / "b/exclude/a.pyi",
            path / "b/dont_exclude/a.pie",
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
            path / "b/.definitely_exclude/a.pie",
            path / "b/.definitely_exclude/a.py",
            path / "b/.definitely_exclude/a.pyi",
            path / ".gitignore",
            path / "pyproject.toml",
        ]
        # Setting exclude explicitly to an empty string to block .gitignore usage.
        assert_collected_sources(src, expected, include="", exclude="")

    def test_extend_exclude(self) -> None:
        """``extend_exclude`` removes files on top of what ``exclude`` removes."""
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.py",
            path / "b/dont_exclude/a.py",
        ]
        assert_collected_sources(
            src, expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude"
        )

    @pytest.mark.incompatible_with_mypyc
    def test_symlink_out_of_root_directory(self) -> None:
        """Out-of-root symlinks are skipped, but other out-of-root paths raise."""
        path = MagicMock()
        root = THIS_DIR.resolve()
        child = MagicMock()
        include = re.compile(black.DEFAULT_INCLUDES)
        exclude = re.compile(black.DEFAULT_EXCLUDES)
        report = black.Report()
        gitignore = PathSpec.from_lines("gitwildmatch", [])
        # `child` should behave like a symlink which resolved path is clearly
        # outside of the `root` directory.
        path.iterdir.return_value = [child]
        child.resolve.return_value = Path("/a/b/c")
        child.as_posix.return_value = "/a/b/c"
        child.is_symlink.return_value = True
        try:
            list(
                black.gen_python_files(
                    path.iterdir(),
                    root,
                    include,
                    exclude,
                    None,
                    None,
                    report,
                    gitignore,
                    verbose=False,
                    quiet=False,
                )
            )
        except ValueError as ve:
            pytest.fail(f"`gen_python_files()` failed: {ve}")
        path.iterdir.assert_called_once()
        child.resolve.assert_called_once()
        child.is_symlink.assert_called_once()
        # `child` should behave like a strange file which resolved path is clearly
        # outside of the `root` directory.
        child.is_symlink.return_value = False
        with pytest.raises(ValueError):
            list(
                black.gen_python_files(
                    path.iterdir(),
                    root,
                    include,
                    exclude,
                    None,
                    None,
                    report,
                    gitignore,
                    verbose=False,
                    quiet=False,
                )
            )
        path.iterdir.assert_called()
        assert path.iterdir.call_count == 2
        child.resolve.assert_called()
        assert child.resolve.call_count == 2
        child.is_symlink.assert_called()
        assert child.is_symlink.call_count == 2

    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
    def test_get_sources_with_stdin(self) -> None:
        """``-`` is collected as-is, regardless of include/exclude patterns."""
        src = ["-"]
        expected = ["-"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
    def test_get_sources_with_stdin_filename(self) -> None:
        """With ``stdin_filename``, stdin is collected under a tagged filename."""
        src = ["-"]
        stdin_filename = str(THIS_DIR / "data/collections.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
    def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
        # Exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        path = DATA_DIR / "include_exclude_tests"
        src = ["-"]
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
    def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
        # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        src = ["-"]
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            extend_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
    def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
        # Force exclude should exclude the file when passing it through
        # stdin_filename
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        assert_collected_sources(
            src=["-"],
            expected=[],
            force_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )
2030
2031
@pytest.mark.python2
@pytest.mark.parametrize("explicit", [True, False], ids=["explicit", "autodetection"])
def test_python_2_deprecation_with_target_version(explicit: bool) -> None:
    """Python 2 input triggers the deprecation warning on stderr.

    This holds both when py27 is requested explicitly and when it is autodetected.
    """
    target_flags = ["--target-version=py27"] if explicit else []
    cli_args = [
        "--config",
        str(THIS_DIR / "empty.toml"),
        str(DATA_DIR / "python2.py"),
        "--check",
        *target_flags,
    ]
    with cache_dir():
        outcome = BlackRunner().invoke(black.main, cli_args)
    assert "DEPRECATION: Python 2 support will be removed" in outcome.stderr
2046
2047
@pytest.mark.python2
def test_python_2_deprecation_autodetection_extended() -> None:
    """Python 2 samples detect as exactly {PY27}; the other samples do not."""
    # Built like test_get_features_used_decorator: cases are separated by "###".
    py2_cases, other_cases = read_data("python2_detection")
    for snippet in py2_cases.split("###"):
        detected = black.detect_target_versions(black.lib2to3_parse(snippet))
        assert detected == {TargetVersion.PY27}, snippet
    for snippet in other_cases.split("###"):
        detected = black.detect_target_versions(black.lib2to3_parse(snippet))
        assert detected != {TargetVersion.PY27}, snippet
2060
2061
# Lines of black/__init__.py, read so that ``tracefunc`` can print the signature
# line of each called function. A compiled (mypyc) Black may not expose a
# UTF-8-decodable source file; in that case fall back to an empty list so the
# name is always bound.
try:
    with open(black.__file__, "r", encoding="utf-8") as _bf:
        black_source_lines = _bf.readlines()
except UnicodeDecodeError:
    if not black.COMPILED:
        raise
    black_source_lines = []
2069
def tracefunc(
    frame: types.FrameType, event: str, arg: Any
) -> Callable[[types.FrameType, str, Any], Any]:
    """Show function calls `from black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    Only "call" events for frames whose code lives in black/__init__.py are
    printed, as ``<indent><lineno>:<signature line>``. Every other event or frame
    is ignored. The function always returns itself so tracing continues.
    """
    if event != "call":
        return tracefunc

    # Filter on the filename *before* touching black_source_lines: line numbers
    # from other files are meaningless there and could run past its end.
    # Normalise separators so the check also matches Windows paths.
    filename = frame.f_code.co_filename
    if "black/__init__.py" not in filename.replace("\\", "/"):
        return tracefunc

    # inspect.stack() is expensive, so compute indentation only for frames we print.
    # NOTE(review): 19 presumably discounts the test-runner frames below the
    # traced code — TODO confirm.
    stack = (len(inspect.stack()) - 19) * 2
    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1  # f_lineno is 1-based; the list is 0-based.
    funcname = ""
    # Skip decorator lines to reach the `def` line; stay within the list bounds.
    while 0 <= func_sig_lineno < len(black_source_lines):
        funcname = black_source_lines[func_sig_lineno].strip()
        if not funcname.startswith("@"):
            break
        func_sig_lineno += 1
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc