git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Do not use gitignore if explicitly passing excludes (#2170)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import multiprocessing
3 import asyncio
4 import logging
5 from concurrent.futures import ThreadPoolExecutor
6 from contextlib import contextmanager
7 from dataclasses import replace
8 import inspect
9 from io import BytesIO, TextIOWrapper
10 import os
11 from pathlib import Path
12 from platform import system
13 import regex as re
14 import sys
15 from tempfile import TemporaryDirectory
16 import types
17 from typing import (
18     Any,
19     BinaryIO,
20     Callable,
21     Dict,
22     Generator,
23     List,
24     Iterator,
25     TypeVar,
26 )
27 import pytest
28 import unittest
29 from unittest.mock import patch, MagicMock
30
31 import click
32 from click import unstyle
33 from click.testing import CliRunner
34
35 import black
36 from black import Feature, TargetVersion
37
38 from pathspec import PathSpec
39
40 # Import other test classes
41 from tests.util import (
42     THIS_DIR,
43     read_data,
44     DETERMINISTIC_HEADER,
45     BlackBaseTestCase,
46     DEFAULT_MODE,
47     fs,
48     ff,
49     dump_to_stderr,
50 )
51 from .test_primer import PrimerCLITests  # noqa: F401
52
53
THIS_FILE = Path(__file__)
# Target versions that all support Python 3.6+ syntax.
PY36_VERSIONS = {
    TargetVersion.PY36,
    TargetVersion.PY37,
    TargetVersion.PY38,
    TargetVersion.PY39,
}
# CLI flags selecting every version in PY36_VERSIONS. Built from a sorted view:
# iterating the set directly would give a run-to-run varying argument order.
PY36_ARGS = [
    f"--target-version={name}"
    for name in sorted(version.name.lower() for version in PY36_VERSIONS)
]
T = TypeVar("T")
R = TypeVar("R")
64
65
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point ``black.CACHE_DIR`` at a throwaway temporary directory.

    When ``exists`` is False, the patched path is a ``new`` subdirectory of the
    temporary directory that is never created. The patched path is yielded, and
    the temporary directory is removed on exit.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
74
75
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a fresh asyncio event loop as the current loop for the block.

    On exit the loop is closed and uninstalled (the current loop is reset to
    ``None``, as ``asyncio.run`` does), so later code cannot accidentally pick
    up a closed loop.
    """
    policy = asyncio.get_event_loop_policy()
    loop = policy.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        yield
    finally:
        loop.close()
        # Without this, the closed loop would remain the thread's current loop.
        asyncio.set_event_loop(None)
86
87
class FakeContext(click.Context):
    """Stand-in click Context for calling functions that require one.

    ``click.Context.__init__`` is intentionally not called; the only state
    provided is an empty ``default_map``.
    """

    def __init__(self) -> None:
        self.default_map: Dict[str, Any] = dict()
93
94
class FakeParameter(click.Parameter):
    """Stand-in click Parameter for calling functions that require one.

    ``click.Parameter.__init__`` is intentionally skipped, so attributes it
    would normally set are absent.
    """

    def __init__(self) -> None:
        """Do nothing; in particular, do not run the base initializer."""
100
101
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    This is a hack that can be removed once we depend on Click 7.x"""

    def __init__(self) -> None:
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        # Snapshots of the captured streams, filled in when isolation() exits.
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Run CliRunner's isolation, capturing stderr separately.

        While active, ``sys.stderr`` writes go to ``self.stderrbuf``. On exit,
        the captured bytes are stored in ``self.stdout_bytes`` and
        ``self.stderr_bytes``, and the previous ``sys.stderr`` is restored.
        """
        with super().isolation(*args, **kwargs) as output:
            # Saved before the ``try`` so the ``finally`` can always restore it.
            hold_stderr = sys.stderr
            stderr_wrapper = TextIOWrapper(self.stderrbuf, encoding=self.charset)
            sys.stderr = stderr_wrapper
            try:
                yield output
            finally:
                sys.stderr = hold_stderr
                # Text wrappers buffer writes; flush them so the underlying
                # byte buffers hold everything written before we snapshot them.
                stderr_wrapper.flush()
                sys.stdout.flush()
                self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                # Read our own buffer rather than whatever sys.stderr now is.
                self.stderr_bytes = self.stderrbuf.getvalue()
125
126
127 class BlackTestCase(BlackBaseTestCase):
128     def invokeBlack(
129         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
130     ) -> None:
131         runner = BlackRunner()
132         if ignore_config:
133             args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
134         result = runner.invoke(black.main, args)
135         self.assertEqual(
136             result.exit_code,
137             exit_code,
138             msg=(
139                 f"Failed with args: {args}\n"
140                 f"stdout: {runner.stdout_bytes.decode()!r}\n"
141                 f"stderr: {runner.stderr_bytes.decode()!r}\n"
142                 f"exception: {result.exception}"
143             ),
144         )
145
146     @patch("black.dump_to_file", dump_to_stderr)
147     def test_empty(self) -> None:
148         source = expected = ""
149         actual = fs(source)
150         self.assertFormatEqual(expected, actual)
151         black.assert_equivalent(source, actual)
152         black.assert_stable(source, actual, DEFAULT_MODE)
153
154     def test_empty_ff(self) -> None:
155         expected = ""
156         tmp_file = Path(black.dump_to_file())
157         try:
158             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
159             with open(tmp_file, encoding="utf8") as f:
160                 actual = f.read()
161         finally:
162             os.unlink(tmp_file)
163         self.assertFormatEqual(expected, actual)
164
165     def test_piping(self) -> None:
166         source, expected = read_data("src/black/__init__", data=False)
167         result = BlackRunner().invoke(
168             black.main,
169             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
170             input=BytesIO(source.encode("utf8")),
171         )
172         self.assertEqual(result.exit_code, 0)
173         self.assertFormatEqual(expected, result.output)
174         black.assert_equivalent(source, result.output)
175         black.assert_stable(source, result.output, DEFAULT_MODE)
176
177     def test_piping_diff(self) -> None:
178         diff_header = re.compile(
179             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
180             r"\+\d\d\d\d"
181         )
182         source, _ = read_data("expression.py")
183         expected, _ = read_data("expression.diff")
184         config = THIS_DIR / "data" / "empty_pyproject.toml"
185         args = [
186             "-",
187             "--fast",
188             f"--line-length={black.DEFAULT_LINE_LENGTH}",
189             "--diff",
190             f"--config={config}",
191         ]
192         result = BlackRunner().invoke(
193             black.main, args, input=BytesIO(source.encode("utf8"))
194         )
195         self.assertEqual(result.exit_code, 0)
196         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
197         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
198         self.assertEqual(expected, actual)
199
200     def test_piping_diff_with_color(self) -> None:
201         source, _ = read_data("expression.py")
202         config = THIS_DIR / "data" / "empty_pyproject.toml"
203         args = [
204             "-",
205             "--fast",
206             f"--line-length={black.DEFAULT_LINE_LENGTH}",
207             "--diff",
208             "--color",
209             f"--config={config}",
210         ]
211         result = BlackRunner().invoke(
212             black.main, args, input=BytesIO(source.encode("utf8"))
213         )
214         actual = result.output
215         # Again, the contents are checked in a different test, so only look for colors.
216         self.assertIn("\033[1;37m", actual)
217         self.assertIn("\033[36m", actual)
218         self.assertIn("\033[32m", actual)
219         self.assertIn("\033[31m", actual)
220         self.assertIn("\033[0m", actual)
221
222     @patch("black.dump_to_file", dump_to_stderr)
223     def _test_wip(self) -> None:
224         source, expected = read_data("wip")
225         sys.settrace(tracefunc)
226         mode = replace(
227             DEFAULT_MODE,
228             experimental_string_processing=False,
229             target_versions={black.TargetVersion.PY38},
230         )
231         actual = fs(source, mode=mode)
232         sys.settrace(None)
233         self.assertFormatEqual(expected, actual)
234         black.assert_equivalent(source, actual)
235         black.assert_stable(source, actual, black.FileMode())
236
237     @unittest.expectedFailure
238     @patch("black.dump_to_file", dump_to_stderr)
239     def test_trailing_comma_optional_parens_stability1(self) -> None:
240         source, _expected = read_data("trailing_comma_optional_parens1")
241         actual = fs(source)
242         black.assert_stable(source, actual, DEFAULT_MODE)
243
244     @unittest.expectedFailure
245     @patch("black.dump_to_file", dump_to_stderr)
246     def test_trailing_comma_optional_parens_stability2(self) -> None:
247         source, _expected = read_data("trailing_comma_optional_parens2")
248         actual = fs(source)
249         black.assert_stable(source, actual, DEFAULT_MODE)
250
251     @unittest.expectedFailure
252     @patch("black.dump_to_file", dump_to_stderr)
253     def test_trailing_comma_optional_parens_stability3(self) -> None:
254         source, _expected = read_data("trailing_comma_optional_parens3")
255         actual = fs(source)
256         black.assert_stable(source, actual, DEFAULT_MODE)
257
258     @patch("black.dump_to_file", dump_to_stderr)
259     def test_trailing_comma_optional_parens_stability1_pass2(self) -> None:
260         source, _expected = read_data("trailing_comma_optional_parens1")
261         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
262         black.assert_stable(source, actual, DEFAULT_MODE)
263
264     @patch("black.dump_to_file", dump_to_stderr)
265     def test_trailing_comma_optional_parens_stability2_pass2(self) -> None:
266         source, _expected = read_data("trailing_comma_optional_parens2")
267         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
268         black.assert_stable(source, actual, DEFAULT_MODE)
269
270     @patch("black.dump_to_file", dump_to_stderr)
271     def test_trailing_comma_optional_parens_stability3_pass2(self) -> None:
272         source, _expected = read_data("trailing_comma_optional_parens3")
273         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
274         black.assert_stable(source, actual, DEFAULT_MODE)
275
276     @patch("black.dump_to_file", dump_to_stderr)
277     def test_pep_572(self) -> None:
278         source, expected = read_data("pep_572")
279         actual = fs(source)
280         self.assertFormatEqual(expected, actual)
281         black.assert_stable(source, actual, DEFAULT_MODE)
282         if sys.version_info >= (3, 8):
283             black.assert_equivalent(source, actual)
284
285     @patch("black.dump_to_file", dump_to_stderr)
286     def test_pep_572_remove_parens(self) -> None:
287         source, expected = read_data("pep_572_remove_parens")
288         actual = fs(source)
289         self.assertFormatEqual(expected, actual)
290         black.assert_stable(source, actual, DEFAULT_MODE)
291         if sys.version_info >= (3, 8):
292             black.assert_equivalent(source, actual)
293
294     @patch("black.dump_to_file", dump_to_stderr)
295     def test_pep_572_do_not_remove_parens(self) -> None:
296         source, expected = read_data("pep_572_do_not_remove_parens")
297         # the AST safety checks will fail, but that's expected, just make sure no
298         # parentheses are touched
299         actual = black.format_str(source, mode=DEFAULT_MODE)
300         self.assertFormatEqual(expected, actual)
301
302     def test_pep_572_version_detection(self) -> None:
303         source, _ = read_data("pep_572")
304         root = black.lib2to3_parse(source)
305         features = black.get_features_used(root)
306         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
307         versions = black.detect_target_versions(root)
308         self.assertIn(black.TargetVersion.PY38, versions)
309
310     def test_expression_ff(self) -> None:
311         source, expected = read_data("expression")
312         tmp_file = Path(black.dump_to_file(source))
313         try:
314             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
315             with open(tmp_file, encoding="utf8") as f:
316                 actual = f.read()
317         finally:
318             os.unlink(tmp_file)
319         self.assertFormatEqual(expected, actual)
320         with patch("black.dump_to_file", dump_to_stderr):
321             black.assert_equivalent(source, actual)
322             black.assert_stable(source, actual, DEFAULT_MODE)
323
324     def test_expression_diff(self) -> None:
325         source, _ = read_data("expression.py")
326         config = THIS_DIR / "data" / "empty_pyproject.toml"
327         expected, _ = read_data("expression.diff")
328         tmp_file = Path(black.dump_to_file(source))
329         diff_header = re.compile(
330             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
331             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
332         )
333         try:
334             result = BlackRunner().invoke(
335                 black.main, ["--diff", str(tmp_file), f"--config={config}"]
336             )
337             self.assertEqual(result.exit_code, 0)
338         finally:
339             os.unlink(tmp_file)
340         actual = result.output
341         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
342         if expected != actual:
343             dump = black.dump_to_file(actual)
344             msg = (
345                 "Expected diff isn't equal to the actual. If you made changes to"
346                 " expression.py and this is an anticipated difference, overwrite"
347                 f" tests/data/expression.diff with {dump}"
348             )
349             self.assertEqual(expected, actual, msg)
350
351     def test_expression_diff_with_color(self) -> None:
352         source, _ = read_data("expression.py")
353         config = THIS_DIR / "data" / "empty_pyproject.toml"
354         expected, _ = read_data("expression.diff")
355         tmp_file = Path(black.dump_to_file(source))
356         try:
357             result = BlackRunner().invoke(
358                 black.main, ["--diff", "--color", str(tmp_file), f"--config={config}"]
359             )
360         finally:
361             os.unlink(tmp_file)
362         actual = result.output
363         # We check the contents of the diff in `test_expression_diff`. All
364         # we need to check here is that color codes exist in the result.
365         self.assertIn("\033[1;37m", actual)
366         self.assertIn("\033[36m", actual)
367         self.assertIn("\033[32m", actual)
368         self.assertIn("\033[31m", actual)
369         self.assertIn("\033[0m", actual)
370
371     @patch("black.dump_to_file", dump_to_stderr)
372     def test_pep_570(self) -> None:
373         source, expected = read_data("pep_570")
374         actual = fs(source)
375         self.assertFormatEqual(expected, actual)
376         black.assert_stable(source, actual, DEFAULT_MODE)
377         if sys.version_info >= (3, 8):
378             black.assert_equivalent(source, actual)
379
380     def test_detect_pos_only_arguments(self) -> None:
381         source, _ = read_data("pep_570")
382         root = black.lib2to3_parse(source)
383         features = black.get_features_used(root)
384         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
385         versions = black.detect_target_versions(root)
386         self.assertIn(black.TargetVersion.PY38, versions)
387
388     @patch("black.dump_to_file", dump_to_stderr)
389     def test_string_quotes(self) -> None:
390         source, expected = read_data("string_quotes")
391         mode = black.Mode(experimental_string_processing=True)
392         actual = fs(source, mode=mode)
393         self.assertFormatEqual(expected, actual)
394         black.assert_equivalent(source, actual)
395         black.assert_stable(source, actual, mode)
396         mode = replace(mode, string_normalization=False)
397         not_normalized = fs(source, mode=mode)
398         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
399         black.assert_equivalent(source, not_normalized)
400         black.assert_stable(source, not_normalized, mode=mode)
401
402     @patch("black.dump_to_file", dump_to_stderr)
403     def test_docstring_no_string_normalization(self) -> None:
404         """Like test_docstring but with string normalization off."""
405         source, expected = read_data("docstring_no_string_normalization")
406         mode = replace(DEFAULT_MODE, string_normalization=False)
407         actual = fs(source, mode=mode)
408         self.assertFormatEqual(expected, actual)
409         black.assert_equivalent(source, actual)
410         black.assert_stable(source, actual, mode)
411
412     def test_long_strings_flag_disabled(self) -> None:
413         """Tests for turning off the string processing logic."""
414         source, expected = read_data("long_strings_flag_disabled")
415         mode = replace(DEFAULT_MODE, experimental_string_processing=False)
416         actual = fs(source, mode=mode)
417         self.assertFormatEqual(expected, actual)
418         black.assert_stable(expected, actual, mode)
419
420     @patch("black.dump_to_file", dump_to_stderr)
421     def test_numeric_literals(self) -> None:
422         source, expected = read_data("numeric_literals")
423         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
424         actual = fs(source, mode=mode)
425         self.assertFormatEqual(expected, actual)
426         black.assert_equivalent(source, actual)
427         black.assert_stable(source, actual, mode)
428
429     @patch("black.dump_to_file", dump_to_stderr)
430     def test_numeric_literals_ignoring_underscores(self) -> None:
431         source, expected = read_data("numeric_literals_skip_underscores")
432         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
433         actual = fs(source, mode=mode)
434         self.assertFormatEqual(expected, actual)
435         black.assert_equivalent(source, actual)
436         black.assert_stable(source, actual, mode)
437
438     def test_skip_magic_trailing_comma(self) -> None:
439         source, _ = read_data("expression.py")
440         expected, _ = read_data("expression_skip_magic_trailing_comma.diff")
441         tmp_file = Path(black.dump_to_file(source))
442         diff_header = re.compile(
443             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
444             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
445         )
446         try:
447             result = BlackRunner().invoke(black.main, ["-C", "--diff", str(tmp_file)])
448             self.assertEqual(result.exit_code, 0)
449         finally:
450             os.unlink(tmp_file)
451         actual = result.output
452         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
453         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
454         if expected != actual:
455             dump = black.dump_to_file(actual)
456             msg = (
457                 "Expected diff isn't equal to the actual. If you made changes to"
458                 " expression.py and this is an anticipated difference, overwrite"
459                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
460             )
461             self.assertEqual(expected, actual, msg)
462
463     @pytest.mark.without_python2
464     def test_python2_should_fail_without_optional_install(self) -> None:
465         # python 3.7 and below will install typed-ast and will be able to parse Python 2
466         if sys.version_info < (3, 8):
467             return
468         source = "x = 1234l"
469         tmp_file = Path(black.dump_to_file(source))
470         try:
471             runner = BlackRunner()
472             result = runner.invoke(black.main, [str(tmp_file)])
473             self.assertEqual(result.exit_code, 123)
474         finally:
475             os.unlink(tmp_file)
476         actual = (
477             runner.stderr_bytes.decode()
478             .replace("\n", "")
479             .replace("\\n", "")
480             .replace("\\r", "")
481             .replace("\r", "")
482         )
483         msg = (
484             "The requested source code has invalid Python 3 syntax."
485             "If you are trying to format Python 2 files please reinstall Black"
486             " with the 'python2' extra: `python3 -m pip install black[python2]`."
487         )
488         self.assertIn(msg, actual)
489
490     @pytest.mark.python2
491     @patch("black.dump_to_file", dump_to_stderr)
492     def test_python2_print_function(self) -> None:
493         source, expected = read_data("python2_print_function")
494         mode = replace(DEFAULT_MODE, target_versions={TargetVersion.PY27})
495         actual = fs(source, mode=mode)
496         self.assertFormatEqual(expected, actual)
497         black.assert_equivalent(source, actual)
498         black.assert_stable(source, actual, mode)
499
500     @patch("black.dump_to_file", dump_to_stderr)
501     def test_stub(self) -> None:
502         mode = replace(DEFAULT_MODE, is_pyi=True)
503         source, expected = read_data("stub.pyi")
504         actual = fs(source, mode=mode)
505         self.assertFormatEqual(expected, actual)
506         black.assert_stable(source, actual, mode)
507
508     @patch("black.dump_to_file", dump_to_stderr)
509     def test_async_as_identifier(self) -> None:
510         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
511         source, expected = read_data("async_as_identifier")
512         actual = fs(source)
513         self.assertFormatEqual(expected, actual)
514         major, minor = sys.version_info[:2]
515         if major < 3 or (major <= 3 and minor < 7):
516             black.assert_equivalent(source, actual)
517         black.assert_stable(source, actual, DEFAULT_MODE)
518         # ensure black can parse this when the target is 3.6
519         self.invokeBlack([str(source_path), "--target-version", "py36"])
520         # but not on 3.7, because async/await is no longer an identifier
521         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
522
523     @patch("black.dump_to_file", dump_to_stderr)
524     def test_python37(self) -> None:
525         source_path = (THIS_DIR / "data" / "python37.py").resolve()
526         source, expected = read_data("python37")
527         actual = fs(source)
528         self.assertFormatEqual(expected, actual)
529         major, minor = sys.version_info[:2]
530         if major > 3 or (major == 3 and minor >= 7):
531             black.assert_equivalent(source, actual)
532         black.assert_stable(source, actual, DEFAULT_MODE)
533         # ensure black can parse this when the target is 3.7
534         self.invokeBlack([str(source_path), "--target-version", "py37"])
535         # but not on 3.6, because we use async as a reserved keyword
536         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
537
538     @patch("black.dump_to_file", dump_to_stderr)
539     def test_python38(self) -> None:
540         source, expected = read_data("python38")
541         actual = fs(source)
542         self.assertFormatEqual(expected, actual)
543         major, minor = sys.version_info[:2]
544         if major > 3 or (major == 3 and minor >= 8):
545             black.assert_equivalent(source, actual)
546         black.assert_stable(source, actual, DEFAULT_MODE)
547
548     @patch("black.dump_to_file", dump_to_stderr)
549     def test_python39(self) -> None:
550         source, expected = read_data("python39")
551         actual = fs(source)
552         self.assertFormatEqual(expected, actual)
553         major, minor = sys.version_info[:2]
554         if major > 3 or (major == 3 and minor >= 9):
555             black.assert_equivalent(source, actual)
556         black.assert_stable(source, actual, DEFAULT_MODE)
557
558     def test_tab_comment_indentation(self) -> None:
559         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
560         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
561         self.assertFormatEqual(contents_spc, fs(contents_spc))
562         self.assertFormatEqual(contents_spc, fs(contents_tab))
563
564         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
565         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
566         self.assertFormatEqual(contents_spc, fs(contents_spc))
567         self.assertFormatEqual(contents_spc, fs(contents_tab))
568
569         # mixed tabs and spaces (valid Python 2 code)
570         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
571         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
572         self.assertFormatEqual(contents_spc, fs(contents_spc))
573         self.assertFormatEqual(contents_spc, fs(contents_tab))
574
575         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
576         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
577         self.assertFormatEqual(contents_spc, fs(contents_spc))
578         self.assertFormatEqual(contents_spc, fs(contents_tab))
579
580     def test_report_verbose(self) -> None:
581         report = black.Report(verbose=True)
582         out_lines = []
583         err_lines = []
584
585         def out(msg: str, **kwargs: Any) -> None:
586             out_lines.append(msg)
587
588         def err(msg: str, **kwargs: Any) -> None:
589             err_lines.append(msg)
590
591         with patch("black.out", out), patch("black.err", err):
592             report.done(Path("f1"), black.Changed.NO)
593             self.assertEqual(len(out_lines), 1)
594             self.assertEqual(len(err_lines), 0)
595             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
596             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
597             self.assertEqual(report.return_code, 0)
598             report.done(Path("f2"), black.Changed.YES)
599             self.assertEqual(len(out_lines), 2)
600             self.assertEqual(len(err_lines), 0)
601             self.assertEqual(out_lines[-1], "reformatted f2")
602             self.assertEqual(
603                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
604             )
605             report.done(Path("f3"), black.Changed.CACHED)
606             self.assertEqual(len(out_lines), 3)
607             self.assertEqual(len(err_lines), 0)
608             self.assertEqual(
609                 out_lines[-1], "f3 wasn't modified on disk since last run."
610             )
611             self.assertEqual(
612                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
613             )
614             self.assertEqual(report.return_code, 0)
615             report.check = True
616             self.assertEqual(report.return_code, 1)
617             report.check = False
618             report.failed(Path("e1"), "boom")
619             self.assertEqual(len(out_lines), 3)
620             self.assertEqual(len(err_lines), 1)
621             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
622             self.assertEqual(
623                 unstyle(str(report)),
624                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
625                 " reformat.",
626             )
627             self.assertEqual(report.return_code, 123)
628             report.done(Path("f3"), black.Changed.YES)
629             self.assertEqual(len(out_lines), 4)
630             self.assertEqual(len(err_lines), 1)
631             self.assertEqual(out_lines[-1], "reformatted f3")
632             self.assertEqual(
633                 unstyle(str(report)),
634                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
635                 " reformat.",
636             )
637             self.assertEqual(report.return_code, 123)
638             report.failed(Path("e2"), "boom")
639             self.assertEqual(len(out_lines), 4)
640             self.assertEqual(len(err_lines), 2)
641             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
642             self.assertEqual(
643                 unstyle(str(report)),
644                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
645                 " reformat.",
646             )
647             self.assertEqual(report.return_code, 123)
648             report.path_ignored(Path("wat"), "no match")
649             self.assertEqual(len(out_lines), 5)
650             self.assertEqual(len(err_lines), 2)
651             self.assertEqual(out_lines[-1], "wat ignored: no match")
652             self.assertEqual(
653                 unstyle(str(report)),
654                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
655                 " reformat.",
656             )
657             self.assertEqual(report.return_code, 123)
658             report.done(Path("f4"), black.Changed.NO)
659             self.assertEqual(len(out_lines), 6)
660             self.assertEqual(len(err_lines), 2)
661             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
662             self.assertEqual(
663                 unstyle(str(report)),
664                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
665                 " reformat.",
666             )
667             self.assertEqual(report.return_code, 123)
668             report.check = True
669             self.assertEqual(
670                 unstyle(str(report)),
671                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
672                 " would fail to reformat.",
673             )
674             report.check = False
675             report.diff = True
676             self.assertEqual(
677                 unstyle(str(report)),
678                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
679                 " would fail to reformat.",
680             )
681
682     def test_report_quiet(self) -> None:
683         report = black.Report(quiet=True)
684         out_lines = []
685         err_lines = []
686
687         def out(msg: str, **kwargs: Any) -> None:
688             out_lines.append(msg)
689
690         def err(msg: str, **kwargs: Any) -> None:
691             err_lines.append(msg)
692
693         with patch("black.out", out), patch("black.err", err):
694             report.done(Path("f1"), black.Changed.NO)
695             self.assertEqual(len(out_lines), 0)
696             self.assertEqual(len(err_lines), 0)
697             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
698             self.assertEqual(report.return_code, 0)
699             report.done(Path("f2"), black.Changed.YES)
700             self.assertEqual(len(out_lines), 0)
701             self.assertEqual(len(err_lines), 0)
702             self.assertEqual(
703                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
704             )
705             report.done(Path("f3"), black.Changed.CACHED)
706             self.assertEqual(len(out_lines), 0)
707             self.assertEqual(len(err_lines), 0)
708             self.assertEqual(
709                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
710             )
711             self.assertEqual(report.return_code, 0)
712             report.check = True
713             self.assertEqual(report.return_code, 1)
714             report.check = False
715             report.failed(Path("e1"), "boom")
716             self.assertEqual(len(out_lines), 0)
717             self.assertEqual(len(err_lines), 1)
718             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
719             self.assertEqual(
720                 unstyle(str(report)),
721                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
722                 " reformat.",
723             )
724             self.assertEqual(report.return_code, 123)
725             report.done(Path("f3"), black.Changed.YES)
726             self.assertEqual(len(out_lines), 0)
727             self.assertEqual(len(err_lines), 1)
728             self.assertEqual(
729                 unstyle(str(report)),
730                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
731                 " reformat.",
732             )
733             self.assertEqual(report.return_code, 123)
734             report.failed(Path("e2"), "boom")
735             self.assertEqual(len(out_lines), 0)
736             self.assertEqual(len(err_lines), 2)
737             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
738             self.assertEqual(
739                 unstyle(str(report)),
740                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
741                 " reformat.",
742             )
743             self.assertEqual(report.return_code, 123)
744             report.path_ignored(Path("wat"), "no match")
745             self.assertEqual(len(out_lines), 0)
746             self.assertEqual(len(err_lines), 2)
747             self.assertEqual(
748                 unstyle(str(report)),
749                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
750                 " reformat.",
751             )
752             self.assertEqual(report.return_code, 123)
753             report.done(Path("f4"), black.Changed.NO)
754             self.assertEqual(len(out_lines), 0)
755             self.assertEqual(len(err_lines), 2)
756             self.assertEqual(
757                 unstyle(str(report)),
758                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
759                 " reformat.",
760             )
761             self.assertEqual(report.return_code, 123)
762             report.check = True
763             self.assertEqual(
764                 unstyle(str(report)),
765                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
766                 " would fail to reformat.",
767             )
768             report.check = False
769             report.diff = True
770             self.assertEqual(
771                 unstyle(str(report)),
772                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
773                 " would fail to reformat.",
774             )
775
776     def test_report_normal(self) -> None:
777         report = black.Report()
778         out_lines = []
779         err_lines = []
780
781         def out(msg: str, **kwargs: Any) -> None:
782             out_lines.append(msg)
783
784         def err(msg: str, **kwargs: Any) -> None:
785             err_lines.append(msg)
786
787         with patch("black.out", out), patch("black.err", err):
788             report.done(Path("f1"), black.Changed.NO)
789             self.assertEqual(len(out_lines), 0)
790             self.assertEqual(len(err_lines), 0)
791             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
792             self.assertEqual(report.return_code, 0)
793             report.done(Path("f2"), black.Changed.YES)
794             self.assertEqual(len(out_lines), 1)
795             self.assertEqual(len(err_lines), 0)
796             self.assertEqual(out_lines[-1], "reformatted f2")
797             self.assertEqual(
798                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
799             )
800             report.done(Path("f3"), black.Changed.CACHED)
801             self.assertEqual(len(out_lines), 1)
802             self.assertEqual(len(err_lines), 0)
803             self.assertEqual(out_lines[-1], "reformatted f2")
804             self.assertEqual(
805                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
806             )
807             self.assertEqual(report.return_code, 0)
808             report.check = True
809             self.assertEqual(report.return_code, 1)
810             report.check = False
811             report.failed(Path("e1"), "boom")
812             self.assertEqual(len(out_lines), 1)
813             self.assertEqual(len(err_lines), 1)
814             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
815             self.assertEqual(
816                 unstyle(str(report)),
817                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
818                 " reformat.",
819             )
820             self.assertEqual(report.return_code, 123)
821             report.done(Path("f3"), black.Changed.YES)
822             self.assertEqual(len(out_lines), 2)
823             self.assertEqual(len(err_lines), 1)
824             self.assertEqual(out_lines[-1], "reformatted f3")
825             self.assertEqual(
826                 unstyle(str(report)),
827                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
828                 " reformat.",
829             )
830             self.assertEqual(report.return_code, 123)
831             report.failed(Path("e2"), "boom")
832             self.assertEqual(len(out_lines), 2)
833             self.assertEqual(len(err_lines), 2)
834             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
835             self.assertEqual(
836                 unstyle(str(report)),
837                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
838                 " reformat.",
839             )
840             self.assertEqual(report.return_code, 123)
841             report.path_ignored(Path("wat"), "no match")
842             self.assertEqual(len(out_lines), 2)
843             self.assertEqual(len(err_lines), 2)
844             self.assertEqual(
845                 unstyle(str(report)),
846                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
847                 " reformat.",
848             )
849             self.assertEqual(report.return_code, 123)
850             report.done(Path("f4"), black.Changed.NO)
851             self.assertEqual(len(out_lines), 2)
852             self.assertEqual(len(err_lines), 2)
853             self.assertEqual(
854                 unstyle(str(report)),
855                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
856                 " reformat.",
857             )
858             self.assertEqual(report.return_code, 123)
859             report.check = True
860             self.assertEqual(
861                 unstyle(str(report)),
862                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
863                 " would fail to reformat.",
864             )
865             report.check = False
866             report.diff = True
867             self.assertEqual(
868                 unstyle(str(report)),
869                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
870                 " would fail to reformat.",
871             )
872
873     def test_lib2to3_parse(self) -> None:
874         with self.assertRaises(black.InvalidInput):
875             black.lib2to3_parse("invalid syntax")
876
877         straddling = "x + y"
878         black.lib2to3_parse(straddling)
879         black.lib2to3_parse(straddling, {TargetVersion.PY27})
880         black.lib2to3_parse(straddling, {TargetVersion.PY36})
881         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
882
883         py2_only = "print x"
884         black.lib2to3_parse(py2_only)
885         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
886         with self.assertRaises(black.InvalidInput):
887             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
888         with self.assertRaises(black.InvalidInput):
889             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
890
891         py3_only = "exec(x, end=y)"
892         black.lib2to3_parse(py3_only)
893         with self.assertRaises(black.InvalidInput):
894             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
895         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
896         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
897
898     def test_get_features_used_decorator(self) -> None:
899         # Test the feature detection of new decorator syntax
900         # since this makes some test cases of test_get_features_used()
901         # fails if it fails, this is tested first so that a useful case
902         # is identified
903         simples, relaxed = read_data("decorators")
904         # skip explanation comments at the top of the file
905         for simple_test in simples.split("##")[1:]:
906             node = black.lib2to3_parse(simple_test)
907             decorator = str(node.children[0].children[0]).strip()
908             self.assertNotIn(
909                 Feature.RELAXED_DECORATORS,
910                 black.get_features_used(node),
911                 msg=(
912                     f"decorator '{decorator}' follows python<=3.8 syntax"
913                     "but is detected as 3.9+"
914                     # f"The full node is\n{node!r}"
915                 ),
916             )
917         # skip the '# output' comment at the top of the output part
918         for relaxed_test in relaxed.split("##")[1:]:
919             node = black.lib2to3_parse(relaxed_test)
920             decorator = str(node.children[0].children[0]).strip()
921             self.assertIn(
922                 Feature.RELAXED_DECORATORS,
923                 black.get_features_used(node),
924                 msg=(
925                     f"decorator '{decorator}' uses python3.9+ syntax"
926                     "but is detected as python<=3.8"
927                     # f"The full node is\n{node!r}"
928                 ),
929             )
930
931     def test_get_features_used(self) -> None:
932         node = black.lib2to3_parse("def f(*, arg): ...\n")
933         self.assertEqual(black.get_features_used(node), set())
934         node = black.lib2to3_parse("def f(*, arg,): ...\n")
935         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
936         node = black.lib2to3_parse("f(*arg,)\n")
937         self.assertEqual(
938             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
939         )
940         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
941         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
942         node = black.lib2to3_parse("123_456\n")
943         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
944         node = black.lib2to3_parse("123456\n")
945         self.assertEqual(black.get_features_used(node), set())
946         source, expected = read_data("function")
947         node = black.lib2to3_parse(source)
948         expected_features = {
949             Feature.TRAILING_COMMA_IN_CALL,
950             Feature.TRAILING_COMMA_IN_DEF,
951             Feature.F_STRINGS,
952         }
953         self.assertEqual(black.get_features_used(node), expected_features)
954         node = black.lib2to3_parse(expected)
955         self.assertEqual(black.get_features_used(node), expected_features)
956         source, expected = read_data("expression")
957         node = black.lib2to3_parse(source)
958         self.assertEqual(black.get_features_used(node), set())
959         node = black.lib2to3_parse(expected)
960         self.assertEqual(black.get_features_used(node), set())
961
962     def test_get_future_imports(self) -> None:
963         node = black.lib2to3_parse("\n")
964         self.assertEqual(set(), black.get_future_imports(node))
965         node = black.lib2to3_parse("from __future__ import black\n")
966         self.assertEqual({"black"}, black.get_future_imports(node))
967         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
968         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
969         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
970         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
971         node = black.lib2to3_parse(
972             "from __future__ import multiple\nfrom __future__ import imports\n"
973         )
974         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
975         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
976         self.assertEqual({"black"}, black.get_future_imports(node))
977         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
978         self.assertEqual({"black"}, black.get_future_imports(node))
979         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
980         self.assertEqual(set(), black.get_future_imports(node))
981         node = black.lib2to3_parse("from some.module import black\n")
982         self.assertEqual(set(), black.get_future_imports(node))
983         node = black.lib2to3_parse(
984             "from __future__ import unicode_literals as _unicode_literals"
985         )
986         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
987         node = black.lib2to3_parse(
988             "from __future__ import unicode_literals as _lol, print"
989         )
990         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
991
992     def test_debug_visitor(self) -> None:
993         source, _ = read_data("debug_visitor.py")
994         expected, _ = read_data("debug_visitor.out")
995         out_lines = []
996         err_lines = []
997
998         def out(msg: str, **kwargs: Any) -> None:
999             out_lines.append(msg)
1000
1001         def err(msg: str, **kwargs: Any) -> None:
1002             err_lines.append(msg)
1003
1004         with patch("black.out", out), patch("black.err", err):
1005             black.DebugVisitor.show(source)
1006         actual = "\n".join(out_lines) + "\n"
1007         log_name = ""
1008         if expected != actual:
1009             log_name = black.dump_to_file(*out_lines)
1010         self.assertEqual(
1011             expected,
1012             actual,
1013             f"AST print out is different. Actual version dumped to {log_name}",
1014         )
1015
1016     def test_format_file_contents(self) -> None:
1017         empty = ""
1018         mode = DEFAULT_MODE
1019         with self.assertRaises(black.NothingChanged):
1020             black.format_file_contents(empty, mode=mode, fast=False)
1021         just_nl = "\n"
1022         with self.assertRaises(black.NothingChanged):
1023             black.format_file_contents(just_nl, mode=mode, fast=False)
1024         same = "j = [1, 2, 3]\n"
1025         with self.assertRaises(black.NothingChanged):
1026             black.format_file_contents(same, mode=mode, fast=False)
1027         different = "j = [1,2,3]"
1028         expected = same
1029         actual = black.format_file_contents(different, mode=mode, fast=False)
1030         self.assertEqual(expected, actual)
1031         invalid = "return if you can"
1032         with self.assertRaises(black.InvalidInput) as e:
1033             black.format_file_contents(invalid, mode=mode, fast=False)
1034         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1035
1036     def test_endmarker(self) -> None:
1037         n = black.lib2to3_parse("\n")
1038         self.assertEqual(n.type, black.syms.file_input)
1039         self.assertEqual(len(n.children), 1)
1040         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1041
1042     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1043     def test_assertFormatEqual(self) -> None:
1044         out_lines = []
1045         err_lines = []
1046
1047         def out(msg: str, **kwargs: Any) -> None:
1048             out_lines.append(msg)
1049
1050         def err(msg: str, **kwargs: Any) -> None:
1051             err_lines.append(msg)
1052
1053         with patch("black.out", out), patch("black.err", err):
1054             with self.assertRaises(AssertionError):
1055                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1056
1057         out_str = "".join(out_lines)
1058         self.assertTrue("Expected tree:" in out_str)
1059         self.assertTrue("Actual tree:" in out_str)
1060         self.assertEqual("".join(err_lines), "")
1061
1062     def test_cache_broken_file(self) -> None:
1063         mode = DEFAULT_MODE
1064         with cache_dir() as workspace:
1065             cache_file = black.get_cache_file(mode)
1066             with cache_file.open("w") as fobj:
1067                 fobj.write("this is not a pickle")
1068             self.assertEqual(black.read_cache(mode), {})
1069             src = (workspace / "test.py").resolve()
1070             with src.open("w") as fobj:
1071                 fobj.write("print('hello')")
1072             self.invokeBlack([str(src)])
1073             cache = black.read_cache(mode)
1074             self.assertIn(str(src), cache)
1075
1076     def test_cache_single_file_already_cached(self) -> None:
1077         mode = DEFAULT_MODE
1078         with cache_dir() as workspace:
1079             src = (workspace / "test.py").resolve()
1080             with src.open("w") as fobj:
1081                 fobj.write("print('hello')")
1082             black.write_cache({}, [src], mode)
1083             self.invokeBlack([str(src)])
1084             with src.open("r") as fobj:
1085                 self.assertEqual(fobj.read(), "print('hello')")
1086
1087     @event_loop()
1088     def test_cache_multiple_files(self) -> None:
1089         mode = DEFAULT_MODE
1090         with cache_dir() as workspace, patch(
1091             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1092         ):
1093             one = (workspace / "one.py").resolve()
1094             with one.open("w") as fobj:
1095                 fobj.write("print('hello')")
1096             two = (workspace / "two.py").resolve()
1097             with two.open("w") as fobj:
1098                 fobj.write("print('hello')")
1099             black.write_cache({}, [one], mode)
1100             self.invokeBlack([str(workspace)])
1101             with one.open("r") as fobj:
1102                 self.assertEqual(fobj.read(), "print('hello')")
1103             with two.open("r") as fobj:
1104                 self.assertEqual(fobj.read(), 'print("hello")\n')
1105             cache = black.read_cache(mode)
1106             self.assertIn(str(one), cache)
1107             self.assertIn(str(two), cache)
1108
1109     def test_no_cache_when_writeback_diff(self) -> None:
1110         mode = DEFAULT_MODE
1111         with cache_dir() as workspace:
1112             src = (workspace / "test.py").resolve()
1113             with src.open("w") as fobj:
1114                 fobj.write("print('hello')")
1115             with patch("black.read_cache") as read_cache, patch(
1116                 "black.write_cache"
1117             ) as write_cache:
1118                 self.invokeBlack([str(src), "--diff"])
1119                 cache_file = black.get_cache_file(mode)
1120                 self.assertFalse(cache_file.exists())
1121                 write_cache.assert_not_called()
1122                 read_cache.assert_not_called()
1123
1124     def test_no_cache_when_writeback_color_diff(self) -> None:
1125         mode = DEFAULT_MODE
1126         with cache_dir() as workspace:
1127             src = (workspace / "test.py").resolve()
1128             with src.open("w") as fobj:
1129                 fobj.write("print('hello')")
1130             with patch("black.read_cache") as read_cache, patch(
1131                 "black.write_cache"
1132             ) as write_cache:
1133                 self.invokeBlack([str(src), "--diff", "--color"])
1134                 cache_file = black.get_cache_file(mode)
1135                 self.assertFalse(cache_file.exists())
1136                 write_cache.assert_not_called()
1137                 read_cache.assert_not_called()
1138
1139     @event_loop()
1140     def test_output_locking_when_writeback_diff(self) -> None:
1141         with cache_dir() as workspace:
1142             for tag in range(0, 4):
1143                 src = (workspace / f"test{tag}.py").resolve()
1144                 with src.open("w") as fobj:
1145                     fobj.write("print('hello')")
1146             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1147                 self.invokeBlack(["--diff", str(workspace)], exit_code=0)
1148                 # this isn't quite doing what we want, but if it _isn't_
1149                 # called then we cannot be using the lock it provides
1150                 mgr.assert_called()
1151
1152     @event_loop()
1153     def test_output_locking_when_writeback_color_diff(self) -> None:
1154         with cache_dir() as workspace:
1155             for tag in range(0, 4):
1156                 src = (workspace / f"test{tag}.py").resolve()
1157                 with src.open("w") as fobj:
1158                     fobj.write("print('hello')")
1159             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1160                 self.invokeBlack(["--diff", "--color", str(workspace)], exit_code=0)
1161                 # this isn't quite doing what we want, but if it _isn't_
1162                 # called then we cannot be using the lock it provides
1163                 mgr.assert_called()
1164
1165     def test_no_cache_when_stdin(self) -> None:
1166         mode = DEFAULT_MODE
1167         with cache_dir():
1168             result = CliRunner().invoke(
1169                 black.main, ["-"], input=BytesIO(b"print('hello')")
1170             )
1171             self.assertEqual(result.exit_code, 0)
1172             cache_file = black.get_cache_file(mode)
1173             self.assertFalse(cache_file.exists())
1174
1175     def test_read_cache_no_cachefile(self) -> None:
1176         mode = DEFAULT_MODE
1177         with cache_dir():
1178             self.assertEqual(black.read_cache(mode), {})
1179
1180     def test_write_cache_read_cache(self) -> None:
1181         mode = DEFAULT_MODE
1182         with cache_dir() as workspace:
1183             src = (workspace / "test.py").resolve()
1184             src.touch()
1185             black.write_cache({}, [src], mode)
1186             cache = black.read_cache(mode)
1187             self.assertIn(str(src), cache)
1188             self.assertEqual(cache[str(src)], black.get_cache_info(src))
1189
1190     def test_filter_cached(self) -> None:
1191         with TemporaryDirectory() as workspace:
1192             path = Path(workspace)
1193             uncached = (path / "uncached").resolve()
1194             cached = (path / "cached").resolve()
1195             cached_but_changed = (path / "changed").resolve()
1196             uncached.touch()
1197             cached.touch()
1198             cached_but_changed.touch()
1199             cache = {
1200                 str(cached): black.get_cache_info(cached),
1201                 str(cached_but_changed): (0.0, 0),
1202             }
1203             todo, done = black.filter_cached(
1204                 cache, {uncached, cached, cached_but_changed}
1205             )
1206             self.assertEqual(todo, {uncached, cached_but_changed})
1207             self.assertEqual(done, {cached})
1208
1209     def test_write_cache_creates_directory_if_needed(self) -> None:
1210         mode = DEFAULT_MODE
1211         with cache_dir(exists=False) as workspace:
1212             self.assertFalse(workspace.exists())
1213             black.write_cache({}, [], mode)
1214             self.assertTrue(workspace.exists())
1215
1216     @event_loop()
1217     def test_failed_formatting_does_not_get_cached(self) -> None:
1218         mode = DEFAULT_MODE
1219         with cache_dir() as workspace, patch(
1220             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1221         ):
1222             failing = (workspace / "failing.py").resolve()
1223             with failing.open("w") as fobj:
1224                 fobj.write("not actually python")
1225             clean = (workspace / "clean.py").resolve()
1226             with clean.open("w") as fobj:
1227                 fobj.write('print("hello")\n')
1228             self.invokeBlack([str(workspace)], exit_code=123)
1229             cache = black.read_cache(mode)
1230             self.assertNotIn(str(failing), cache)
1231             self.assertIn(str(clean), cache)
1232
1233     def test_write_cache_write_fail(self) -> None:
1234         mode = DEFAULT_MODE
1235         with cache_dir(), patch.object(Path, "open") as mock:
1236             mock.side_effect = OSError
1237             black.write_cache({}, [], mode)
1238
1239     @event_loop()
1240     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
1241     def test_works_in_mono_process_only_environment(self) -> None:
1242         with cache_dir() as workspace:
1243             for f in [
1244                 (workspace / "one.py").resolve(),
1245                 (workspace / "two.py").resolve(),
1246             ]:
1247                 f.write_text('print("hello")\n')
1248             self.invokeBlack([str(workspace)])
1249
1250     @event_loop()
1251     def test_check_diff_use_together(self) -> None:
1252         with cache_dir():
1253             # Files which will be reformatted.
1254             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1255             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1256             # Files which will not be reformatted.
1257             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1258             self.invokeBlack([str(src2), "--diff", "--check"])
1259             # Multi file command.
1260             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1261
1262     def test_no_files(self) -> None:
1263         with cache_dir():
1264             # Without an argument, black exits with error code 0.
1265             self.invokeBlack([])
1266
1267     def test_broken_symlink(self) -> None:
1268         with cache_dir() as workspace:
1269             symlink = workspace / "broken_link.py"
1270             try:
1271                 symlink.symlink_to("nonexistent.py")
1272             except OSError as e:
1273                 self.skipTest(f"Can't create symlinks: {e}")
1274             self.invokeBlack([str(workspace.resolve())])
1275
1276     def test_read_cache_line_lengths(self) -> None:
1277         mode = DEFAULT_MODE
1278         short_mode = replace(DEFAULT_MODE, line_length=1)
1279         with cache_dir() as workspace:
1280             path = (workspace / "file.py").resolve()
1281             path.touch()
1282             black.write_cache({}, [path], mode)
1283             one = black.read_cache(mode)
1284             self.assertIn(str(path), one)
1285             two = black.read_cache(short_mode)
1286             self.assertNotIn(str(path), two)
1287
1288     def test_single_file_force_pyi(self) -> None:
1289         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1290         contents, expected = read_data("force_pyi")
1291         with cache_dir() as workspace:
1292             path = (workspace / "file.py").resolve()
1293             with open(path, "w") as fh:
1294                 fh.write(contents)
1295             self.invokeBlack([str(path), "--pyi"])
1296             with open(path, "r") as fh:
1297                 actual = fh.read()
1298             # verify cache with --pyi is separate
1299             pyi_cache = black.read_cache(pyi_mode)
1300             self.assertIn(str(path), pyi_cache)
1301             normal_cache = black.read_cache(DEFAULT_MODE)
1302             self.assertNotIn(str(path), normal_cache)
1303         self.assertFormatEqual(expected, actual)
1304         black.assert_equivalent(contents, actual)
1305         black.assert_stable(contents, actual, pyi_mode)
1306
1307     @event_loop()
1308     def test_multi_file_force_pyi(self) -> None:
1309         reg_mode = DEFAULT_MODE
1310         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1311         contents, expected = read_data("force_pyi")
1312         with cache_dir() as workspace:
1313             paths = [
1314                 (workspace / "file1.py").resolve(),
1315                 (workspace / "file2.py").resolve(),
1316             ]
1317             for path in paths:
1318                 with open(path, "w") as fh:
1319                     fh.write(contents)
1320             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1321             for path in paths:
1322                 with open(path, "r") as fh:
1323                     actual = fh.read()
1324                 self.assertEqual(actual, expected)
1325             # verify cache with --pyi is separate
1326             pyi_cache = black.read_cache(pyi_mode)
1327             normal_cache = black.read_cache(reg_mode)
1328             for path in paths:
1329                 self.assertIn(str(path), pyi_cache)
1330                 self.assertNotIn(str(path), normal_cache)
1331
1332     def test_pipe_force_pyi(self) -> None:
1333         source, expected = read_data("force_pyi")
1334         result = CliRunner().invoke(
1335             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1336         )
1337         self.assertEqual(result.exit_code, 0)
1338         actual = result.output
1339         self.assertFormatEqual(actual, expected)
1340
1341     def test_single_file_force_py36(self) -> None:
1342         reg_mode = DEFAULT_MODE
1343         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1344         source, expected = read_data("force_py36")
1345         with cache_dir() as workspace:
1346             path = (workspace / "file.py").resolve()
1347             with open(path, "w") as fh:
1348                 fh.write(source)
1349             self.invokeBlack([str(path), *PY36_ARGS])
1350             with open(path, "r") as fh:
1351                 actual = fh.read()
1352             # verify cache with --target-version is separate
1353             py36_cache = black.read_cache(py36_mode)
1354             self.assertIn(str(path), py36_cache)
1355             normal_cache = black.read_cache(reg_mode)
1356             self.assertNotIn(str(path), normal_cache)
1357         self.assertEqual(actual, expected)
1358
1359     @event_loop()
1360     def test_multi_file_force_py36(self) -> None:
1361         reg_mode = DEFAULT_MODE
1362         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1363         source, expected = read_data("force_py36")
1364         with cache_dir() as workspace:
1365             paths = [
1366                 (workspace / "file1.py").resolve(),
1367                 (workspace / "file2.py").resolve(),
1368             ]
1369             for path in paths:
1370                 with open(path, "w") as fh:
1371                     fh.write(source)
1372             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1373             for path in paths:
1374                 with open(path, "r") as fh:
1375                     actual = fh.read()
1376                 self.assertEqual(actual, expected)
1377             # verify cache with --target-version is separate
1378             pyi_cache = black.read_cache(py36_mode)
1379             normal_cache = black.read_cache(reg_mode)
1380             for path in paths:
1381                 self.assertIn(str(path), pyi_cache)
1382                 self.assertNotIn(str(path), normal_cache)
1383
1384     def test_pipe_force_py36(self) -> None:
1385         source, expected = read_data("force_py36")
1386         result = CliRunner().invoke(
1387             black.main,
1388             ["-", "-q", "--target-version=py36"],
1389             input=BytesIO(source.encode("utf8")),
1390         )
1391         self.assertEqual(result.exit_code, 0)
1392         actual = result.output
1393         self.assertFormatEqual(actual, expected)
1394
1395     def test_include_exclude(self) -> None:
1396         path = THIS_DIR / "data" / "include_exclude_tests"
1397         include = re.compile(r"\.pyi?$")
1398         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1399         report = black.Report()
1400         gitignore = PathSpec.from_lines("gitwildmatch", [])
1401         sources: List[Path] = []
1402         expected = [
1403             Path(path / "b/dont_exclude/a.py"),
1404             Path(path / "b/dont_exclude/a.pyi"),
1405         ]
1406         this_abs = THIS_DIR.resolve()
1407         sources.extend(
1408             black.gen_python_files(
1409                 path.iterdir(),
1410                 this_abs,
1411                 include,
1412                 exclude,
1413                 None,
1414                 None,
1415                 report,
1416                 gitignore,
1417             )
1418         )
1419         self.assertEqual(sorted(expected), sorted(sources))
1420
1421     def test_gitingore_used_as_default(self) -> None:
1422         path = Path(THIS_DIR / "data" / "include_exclude_tests")
1423         include = re.compile(r"\.pyi?$")
1424         extend_exclude = re.compile(r"/exclude/")
1425         src = str(path / "b/")
1426         report = black.Report()
1427         expected: List[Path] = [
1428             path / "b/.definitely_exclude/a.py",
1429             path / "b/.definitely_exclude/a.pyi",
1430         ]
1431         sources = list(
1432             black.get_sources(
1433                 ctx=FakeContext(),
1434                 src=(src,),
1435                 quiet=True,
1436                 verbose=False,
1437                 include=include,
1438                 exclude=None,
1439                 extend_exclude=extend_exclude,
1440                 force_exclude=None,
1441                 report=report,
1442                 stdin_filename=None,
1443             )
1444         )
1445         self.assertEqual(sorted(expected), sorted(sources))
1446
1447     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1448     def test_exclude_for_issue_1572(self) -> None:
1449         # Exclude shouldn't touch files that were explicitly given to Black through the
1450         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
1451         # https://github.com/psf/black/issues/1572
1452         path = THIS_DIR / "data" / "include_exclude_tests"
1453         include = ""
1454         exclude = r"/exclude/|a\.py"
1455         src = str(path / "b/exclude/a.py")
1456         report = black.Report()
1457         expected = [Path(path / "b/exclude/a.py")]
1458         sources = list(
1459             black.get_sources(
1460                 ctx=FakeContext(),
1461                 src=(src,),
1462                 quiet=True,
1463                 verbose=False,
1464                 include=re.compile(include),
1465                 exclude=re.compile(exclude),
1466                 extend_exclude=None,
1467                 force_exclude=None,
1468                 report=report,
1469                 stdin_filename=None,
1470             )
1471         )
1472         self.assertEqual(sorted(expected), sorted(sources))
1473
1474     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1475     def test_get_sources_with_stdin(self) -> None:
1476         include = ""
1477         exclude = r"/exclude/|a\.py"
1478         src = "-"
1479         report = black.Report()
1480         expected = [Path("-")]
1481         sources = list(
1482             black.get_sources(
1483                 ctx=FakeContext(),
1484                 src=(src,),
1485                 quiet=True,
1486                 verbose=False,
1487                 include=re.compile(include),
1488                 exclude=re.compile(exclude),
1489                 extend_exclude=None,
1490                 force_exclude=None,
1491                 report=report,
1492                 stdin_filename=None,
1493             )
1494         )
1495         self.assertEqual(sorted(expected), sorted(sources))
1496
1497     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1498     def test_get_sources_with_stdin_filename(self) -> None:
1499         include = ""
1500         exclude = r"/exclude/|a\.py"
1501         src = "-"
1502         report = black.Report()
1503         stdin_filename = str(THIS_DIR / "data/collections.py")
1504         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1505         sources = list(
1506             black.get_sources(
1507                 ctx=FakeContext(),
1508                 src=(src,),
1509                 quiet=True,
1510                 verbose=False,
1511                 include=re.compile(include),
1512                 exclude=re.compile(exclude),
1513                 extend_exclude=None,
1514                 force_exclude=None,
1515                 report=report,
1516                 stdin_filename=stdin_filename,
1517             )
1518         )
1519         self.assertEqual(sorted(expected), sorted(sources))
1520
1521     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1522     def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
1523         # Exclude shouldn't exclude stdin_filename since it is mimicing the
1524         # file being passed directly. This is the same as
1525         # test_exclude_for_issue_1572
1526         path = THIS_DIR / "data" / "include_exclude_tests"
1527         include = ""
1528         exclude = r"/exclude/|a\.py"
1529         src = "-"
1530         report = black.Report()
1531         stdin_filename = str(path / "b/exclude/a.py")
1532         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1533         sources = list(
1534             black.get_sources(
1535                 ctx=FakeContext(),
1536                 src=(src,),
1537                 quiet=True,
1538                 verbose=False,
1539                 include=re.compile(include),
1540                 exclude=re.compile(exclude),
1541                 extend_exclude=None,
1542                 force_exclude=None,
1543                 report=report,
1544                 stdin_filename=stdin_filename,
1545             )
1546         )
1547         self.assertEqual(sorted(expected), sorted(sources))
1548
1549     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1550     def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
1551         # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
1552         # file being passed directly. This is the same as
1553         # test_exclude_for_issue_1572
1554         path = THIS_DIR / "data" / "include_exclude_tests"
1555         include = ""
1556         extend_exclude = r"/exclude/|a\.py"
1557         src = "-"
1558         report = black.Report()
1559         stdin_filename = str(path / "b/exclude/a.py")
1560         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1561         sources = list(
1562             black.get_sources(
1563                 ctx=FakeContext(),
1564                 src=(src,),
1565                 quiet=True,
1566                 verbose=False,
1567                 include=re.compile(include),
1568                 exclude=re.compile(""),
1569                 extend_exclude=re.compile(extend_exclude),
1570                 force_exclude=None,
1571                 report=report,
1572                 stdin_filename=stdin_filename,
1573             )
1574         )
1575         self.assertEqual(sorted(expected), sorted(sources))
1576
1577     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1578     def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
1579         # Force exclude should exclude the file when passing it through
1580         # stdin_filename
1581         path = THIS_DIR / "data" / "include_exclude_tests"
1582         include = ""
1583         force_exclude = r"/exclude/|a\.py"
1584         src = "-"
1585         report = black.Report()
1586         stdin_filename = str(path / "b/exclude/a.py")
1587         sources = list(
1588             black.get_sources(
1589                 ctx=FakeContext(),
1590                 src=(src,),
1591                 quiet=True,
1592                 verbose=False,
1593                 include=re.compile(include),
1594                 exclude=re.compile(""),
1595                 extend_exclude=None,
1596                 force_exclude=re.compile(force_exclude),
1597                 report=report,
1598                 stdin_filename=stdin_filename,
1599             )
1600         )
1601         self.assertEqual([], sorted(sources))
1602
1603     def test_reformat_one_with_stdin(self) -> None:
1604         with patch(
1605             "black.format_stdin_to_stdout",
1606             return_value=lambda *args, **kwargs: black.Changed.YES,
1607         ) as fsts:
1608             report = MagicMock()
1609             path = Path("-")
1610             black.reformat_one(
1611                 path,
1612                 fast=True,
1613                 write_back=black.WriteBack.YES,
1614                 mode=DEFAULT_MODE,
1615                 report=report,
1616             )
1617             fsts.assert_called_once()
1618             report.done.assert_called_with(path, black.Changed.YES)
1619
1620     def test_reformat_one_with_stdin_filename(self) -> None:
1621         with patch(
1622             "black.format_stdin_to_stdout",
1623             return_value=lambda *args, **kwargs: black.Changed.YES,
1624         ) as fsts:
1625             report = MagicMock()
1626             p = "foo.py"
1627             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1628             expected = Path(p)
1629             black.reformat_one(
1630                 path,
1631                 fast=True,
1632                 write_back=black.WriteBack.YES,
1633                 mode=DEFAULT_MODE,
1634                 report=report,
1635             )
1636             fsts.assert_called_once_with(
1637                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1638             )
1639             # __BLACK_STDIN_FILENAME__ should have been stripped
1640             report.done.assert_called_with(expected, black.Changed.YES)
1641
1642     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1643         with patch(
1644             "black.format_stdin_to_stdout",
1645             return_value=lambda *args, **kwargs: black.Changed.YES,
1646         ) as fsts:
1647             report = MagicMock()
1648             p = "foo.pyi"
1649             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1650             expected = Path(p)
1651             black.reformat_one(
1652                 path,
1653                 fast=True,
1654                 write_back=black.WriteBack.YES,
1655                 mode=DEFAULT_MODE,
1656                 report=report,
1657             )
1658             fsts.assert_called_once_with(
1659                 fast=True,
1660                 write_back=black.WriteBack.YES,
1661                 mode=replace(DEFAULT_MODE, is_pyi=True),
1662             )
1663             # __BLACK_STDIN_FILENAME__ should have been stripped
1664             report.done.assert_called_with(expected, black.Changed.YES)
1665
1666     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1667         with patch(
1668             "black.format_stdin_to_stdout",
1669             return_value=lambda *args, **kwargs: black.Changed.YES,
1670         ) as fsts:
1671             report = MagicMock()
1672             # Even with an existing file, since we are forcing stdin, black
1673             # should output to stdout and not modify the file inplace
1674             p = Path(str(THIS_DIR / "data/collections.py"))
1675             # Make sure is_file actually returns True
1676             self.assertTrue(p.is_file())
1677             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1678             expected = Path(p)
1679             black.reformat_one(
1680                 path,
1681                 fast=True,
1682                 write_back=black.WriteBack.YES,
1683                 mode=DEFAULT_MODE,
1684                 report=report,
1685             )
1686             fsts.assert_called_once()
1687             # __BLACK_STDIN_FILENAME__ should have been stripped
1688             report.done.assert_called_with(expected, black.Changed.YES)
1689
1690     def test_gitignore_exclude(self) -> None:
1691         path = THIS_DIR / "data" / "include_exclude_tests"
1692         include = re.compile(r"\.pyi?$")
1693         exclude = re.compile(r"")
1694         report = black.Report()
1695         gitignore = PathSpec.from_lines(
1696             "gitwildmatch", ["exclude/", ".definitely_exclude"]
1697         )
1698         sources: List[Path] = []
1699         expected = [
1700             Path(path / "b/dont_exclude/a.py"),
1701             Path(path / "b/dont_exclude/a.pyi"),
1702         ]
1703         this_abs = THIS_DIR.resolve()
1704         sources.extend(
1705             black.gen_python_files(
1706                 path.iterdir(),
1707                 this_abs,
1708                 include,
1709                 exclude,
1710                 None,
1711                 None,
1712                 report,
1713                 gitignore,
1714             )
1715         )
1716         self.assertEqual(sorted(expected), sorted(sources))
1717
1718     def test_empty_include(self) -> None:
1719         path = THIS_DIR / "data" / "include_exclude_tests"
1720         report = black.Report()
1721         gitignore = PathSpec.from_lines("gitwildmatch", [])
1722         empty = re.compile(r"")
1723         sources: List[Path] = []
1724         expected = [
1725             Path(path / "b/exclude/a.pie"),
1726             Path(path / "b/exclude/a.py"),
1727             Path(path / "b/exclude/a.pyi"),
1728             Path(path / "b/dont_exclude/a.pie"),
1729             Path(path / "b/dont_exclude/a.py"),
1730             Path(path / "b/dont_exclude/a.pyi"),
1731             Path(path / "b/.definitely_exclude/a.pie"),
1732             Path(path / "b/.definitely_exclude/a.py"),
1733             Path(path / "b/.definitely_exclude/a.pyi"),
1734             Path(path / ".gitignore"),
1735             Path(path / "pyproject.toml"),
1736         ]
1737         this_abs = THIS_DIR.resolve()
1738         sources.extend(
1739             black.gen_python_files(
1740                 path.iterdir(),
1741                 this_abs,
1742                 empty,
1743                 re.compile(black.DEFAULT_EXCLUDES),
1744                 None,
1745                 None,
1746                 report,
1747                 gitignore,
1748             )
1749         )
1750         self.assertEqual(sorted(expected), sorted(sources))
1751
1752     def test_extend_exclude(self) -> None:
1753         path = THIS_DIR / "data" / "include_exclude_tests"
1754         report = black.Report()
1755         gitignore = PathSpec.from_lines("gitwildmatch", [])
1756         sources: List[Path] = []
1757         expected = [
1758             Path(path / "b/exclude/a.py"),
1759             Path(path / "b/dont_exclude/a.py"),
1760         ]
1761         this_abs = THIS_DIR.resolve()
1762         sources.extend(
1763             black.gen_python_files(
1764                 path.iterdir(),
1765                 this_abs,
1766                 re.compile(black.DEFAULT_INCLUDES),
1767                 re.compile(r"\.pyi$"),
1768                 re.compile(r"\.definitely_exclude"),
1769                 None,
1770                 report,
1771                 gitignore,
1772             )
1773         )
1774         self.assertEqual(sorted(expected), sorted(sources))
1775
1776     def test_invalid_cli_regex(self) -> None:
1777         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1778             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1779
1780     def test_preserves_line_endings(self) -> None:
1781         with TemporaryDirectory() as workspace:
1782             test_file = Path(workspace) / "test.py"
1783             for nl in ["\n", "\r\n"]:
1784                 contents = nl.join(["def f(  ):", "    pass"])
1785                 test_file.write_bytes(contents.encode())
1786                 ff(test_file, write_back=black.WriteBack.YES)
1787                 updated_contents: bytes = test_file.read_bytes()
1788                 self.assertIn(nl.encode(), updated_contents)
1789                 if nl == "\n":
1790                     self.assertNotIn(b"\r\n", updated_contents)
1791
1792     def test_preserves_line_endings_via_stdin(self) -> None:
1793         for nl in ["\n", "\r\n"]:
1794             contents = nl.join(["def f(  ):", "    pass"])
1795             runner = BlackRunner()
1796             result = runner.invoke(
1797                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1798             )
1799             self.assertEqual(result.exit_code, 0)
1800             output = runner.stdout_bytes
1801             self.assertIn(nl.encode("utf8"), output)
1802             if nl == "\n":
1803                 self.assertNotIn(b"\r\n", output)
1804
1805     def test_assert_equivalent_different_asts(self) -> None:
1806         with self.assertRaises(AssertionError):
1807             black.assert_equivalent("{}", "None")
1808
1809     def test_symlink_out_of_root_directory(self) -> None:
1810         path = MagicMock()
1811         root = THIS_DIR.resolve()
1812         child = MagicMock()
1813         include = re.compile(black.DEFAULT_INCLUDES)
1814         exclude = re.compile(black.DEFAULT_EXCLUDES)
1815         report = black.Report()
1816         gitignore = PathSpec.from_lines("gitwildmatch", [])
1817         # `child` should behave like a symlink which resolved path is clearly
1818         # outside of the `root` directory.
1819         path.iterdir.return_value = [child]
1820         child.resolve.return_value = Path("/a/b/c")
1821         child.as_posix.return_value = "/a/b/c"
1822         child.is_symlink.return_value = True
1823         try:
1824             list(
1825                 black.gen_python_files(
1826                     path.iterdir(),
1827                     root,
1828                     include,
1829                     exclude,
1830                     None,
1831                     None,
1832                     report,
1833                     gitignore,
1834                 )
1835             )
1836         except ValueError as ve:
1837             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1838         path.iterdir.assert_called_once()
1839         child.resolve.assert_called_once()
1840         child.is_symlink.assert_called_once()
1841         # `child` should behave like a strange file which resolved path is clearly
1842         # outside of the `root` directory.
1843         child.is_symlink.return_value = False
1844         with self.assertRaises(ValueError):
1845             list(
1846                 black.gen_python_files(
1847                     path.iterdir(),
1848                     root,
1849                     include,
1850                     exclude,
1851                     None,
1852                     None,
1853                     report,
1854                     gitignore,
1855                 )
1856             )
1857         path.iterdir.assert_called()
1858         self.assertEqual(path.iterdir.call_count, 2)
1859         child.resolve.assert_called()
1860         self.assertEqual(child.resolve.call_count, 2)
1861         child.is_symlink.assert_called()
1862         self.assertEqual(child.is_symlink.call_count, 2)
1863
1864     def test_shhh_click(self) -> None:
1865         try:
1866             from click import _unicodefun  # type: ignore
1867         except ModuleNotFoundError:
1868             self.skipTest("Incompatible Click version")
1869         if not hasattr(_unicodefun, "_verify_python3_env"):
1870             self.skipTest("Incompatible Click version")
1871         # First, let's see if Click is crashing with a preferred ASCII charset.
1872         with patch("locale.getpreferredencoding") as gpe:
1873             gpe.return_value = "ASCII"
1874             with self.assertRaises(RuntimeError):
1875                 _unicodefun._verify_python3_env()
1876         # Now, let's silence Click...
1877         black.patch_click()
1878         # ...and confirm it's silent.
1879         with patch("locale.getpreferredencoding") as gpe:
1880             gpe.return_value = "ASCII"
1881             try:
1882                 _unicodefun._verify_python3_env()
1883             except RuntimeError as re:
1884                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1885
1886     def test_root_logger_not_used_directly(self) -> None:
1887         def fail(*args: Any, **kwargs: Any) -> None:
1888             self.fail("Record created with root logger")
1889
1890         with patch.multiple(
1891             logging.root,
1892             debug=fail,
1893             info=fail,
1894             warning=fail,
1895             error=fail,
1896             critical=fail,
1897             log=fail,
1898         ):
1899             ff(THIS_FILE)
1900
1901     def test_invalid_config_return_code(self) -> None:
1902         tmp_file = Path(black.dump_to_file())
1903         try:
1904             tmp_config = Path(black.dump_to_file())
1905             tmp_config.unlink()
1906             args = ["--config", str(tmp_config), str(tmp_file)]
1907             self.invokeBlack(args, exit_code=2, ignore_config=False)
1908         finally:
1909             tmp_file.unlink()
1910
1911     def test_parse_pyproject_toml(self) -> None:
1912         test_toml_file = THIS_DIR / "test.toml"
1913         config = black.parse_pyproject_toml(str(test_toml_file))
1914         self.assertEqual(config["verbose"], 1)
1915         self.assertEqual(config["check"], "no")
1916         self.assertEqual(config["diff"], "y")
1917         self.assertEqual(config["color"], True)
1918         self.assertEqual(config["line_length"], 79)
1919         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1920         self.assertEqual(config["exclude"], r"\.pyi?$")
1921         self.assertEqual(config["include"], r"\.py?$")
1922
1923     def test_read_pyproject_toml(self) -> None:
1924         test_toml_file = THIS_DIR / "test.toml"
1925         fake_ctx = FakeContext()
1926         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1927         config = fake_ctx.default_map
1928         self.assertEqual(config["verbose"], "1")
1929         self.assertEqual(config["check"], "no")
1930         self.assertEqual(config["diff"], "y")
1931         self.assertEqual(config["color"], "True")
1932         self.assertEqual(config["line_length"], "79")
1933         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1934         self.assertEqual(config["exclude"], r"\.pyi?$")
1935         self.assertEqual(config["include"], r"\.py?$")
1936
1937     def test_find_project_root(self) -> None:
1938         with TemporaryDirectory() as workspace:
1939             root = Path(workspace)
1940             test_dir = root / "test"
1941             test_dir.mkdir()
1942
1943             src_dir = root / "src"
1944             src_dir.mkdir()
1945
1946             root_pyproject = root / "pyproject.toml"
1947             root_pyproject.touch()
1948             src_pyproject = src_dir / "pyproject.toml"
1949             src_pyproject.touch()
1950             src_python = src_dir / "foo.py"
1951             src_python.touch()
1952
1953             self.assertEqual(
1954                 black.find_project_root((src_dir, test_dir)), root.resolve()
1955             )
1956             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
1957             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
1958
1959     @patch("black.find_user_pyproject_toml", black.find_user_pyproject_toml.__wrapped__)
1960     def test_find_user_pyproject_toml_linux(self) -> None:
1961         if system() == "Windows":
1962             return
1963
1964         # Test if XDG_CONFIG_HOME is checked
1965         with TemporaryDirectory() as workspace:
1966             tmp_user_config = Path(workspace) / "black"
1967             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1968                 self.assertEqual(
1969                     black.find_user_pyproject_toml(), tmp_user_config.resolve()
1970                 )
1971
1972         # Test fallback for XDG_CONFIG_HOME
1973         with patch.dict("os.environ"):
1974             os.environ.pop("XDG_CONFIG_HOME", None)
1975             fallback_user_config = Path("~/.config").expanduser() / "black"
1976             self.assertEqual(
1977                 black.find_user_pyproject_toml(), fallback_user_config.resolve()
1978             )
1979
1980     def test_find_user_pyproject_toml_windows(self) -> None:
1981         if system() != "Windows":
1982             return
1983
1984         user_config_path = Path.home() / ".black"
1985         self.assertEqual(black.find_user_pyproject_toml(), user_config_path.resolve())
1986
1987     def test_bpo_33660_workaround(self) -> None:
1988         if system() == "Windows":
1989             return
1990
1991         # https://bugs.python.org/issue33660
1992
1993         old_cwd = Path.cwd()
1994         try:
1995             root = Path("/")
1996             os.chdir(str(root))
1997             path = Path("workspace") / "project"
1998             report = black.Report(verbose=True)
1999             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
2000             self.assertEqual(normalized_path, "workspace/project")
2001         finally:
2002             os.chdir(str(old_cwd))
2003
2004     def test_newline_comment_interaction(self) -> None:
2005         source = "class A:\\\r\n# type: ignore\n pass\n"
2006         output = black.format_str(source, mode=DEFAULT_MODE)
2007         black.assert_stable(source, output, mode=DEFAULT_MODE)
2008
2009     def test_bpo_2142_workaround(self) -> None:
2010
2011         # https://bugs.python.org/issue2142
2012
2013         source, _ = read_data("missing_final_newline.py")
2014         # read_data adds a trailing newline
2015         source = source.rstrip()
2016         expected, _ = read_data("missing_final_newline.diff")
2017         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
2018         diff_header = re.compile(
2019             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
2020             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
2021         )
2022         try:
2023             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
2024             self.assertEqual(result.exit_code, 0)
2025         finally:
2026             os.unlink(tmp_file)
2027         actual = result.output
2028         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
2029         self.assertEqual(actual, expected)
2030
2031     @pytest.mark.python2
2032     def test_docstring_reformat_for_py27(self) -> None:
2033         """
2034         Check that stripping trailing whitespace from Python 2 docstrings
2035         doesn't trigger a "not equivalent to source" error
2036         """
2037         source = (
2038             b'def foo():\r\n    """Testing\r\n    Testing """\r\n    print "Foo"\r\n'
2039         )
2040         expected = 'def foo():\n    """Testing\n    Testing"""\n    print "Foo"\n'
2041
2042         result = CliRunner().invoke(
2043             black.main,
2044             ["-", "-q", "--target-version=py27"],
2045             input=BytesIO(source),
2046         )
2047
2048         self.assertEqual(result.exit_code, 0)
2049         actual = result.output
2050         self.assertFormatEqual(actual, expected)
2051
2052
with open(black.__file__, "r", encoding="utf-8") as _black_src:
    # Keep black's own source lines so `tracefunc` can print the text of the
    # definition line for each traced call.
    black_source_lines = list(_black_src)
2055
2056
def tracefunc(frame: types.FrameType, event: str, arg: Any) -> Callable:
    """Show function calls `from black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    Only "call" events whose code lives in ``black/__init__.py`` are printed
    (indented by stack depth, as ``<lineno>:<definition line>``).  All other
    events and frames are ignored.  The function always returns itself so
    tracing continues into nested scopes.
    """
    if event != "call":
        return tracefunc

    filename = frame.f_code.co_filename
    # Filter on the file first: the lookups below index black's own source, so
    # frames from other (longer) modules would raise IndexError.  as_posix()
    # lets the match also work with Windows path separators.
    if "black/__init__.py" not in Path(filename).as_posix():
        return tracefunc

    # 19 is a hard-coded offset for the frames the test harness itself adds
    # (presumably tuned by hand; adjust if the indentation looks off).
    stack = (len(inspect.stack()) - 19) * 2
    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1
    funcname = black_source_lines[func_sig_lineno].strip()
    # Step past decorator lines to reach the ``def`` line, without running
    # off the end of the file.
    while funcname.startswith("@") and func_sig_lineno + 1 < len(black_source_lines):
        func_sig_lineno += 1
        funcname = black_source_lines[func_sig_lineno].strip()
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc
2077
2078
if __name__ == "__main__":
    # Running this file directly hands control to unittest's CLI runner.
    target_module = "test_black"
    unittest.main(module=target_module)