git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Simplify GitHub Action entrypoint (#2119)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import multiprocessing
3 import asyncio
4 import logging
5 from concurrent.futures import ThreadPoolExecutor
6 from contextlib import contextmanager
7 from dataclasses import replace
8 import inspect
9 from io import BytesIO, TextIOWrapper
10 import os
11 from pathlib import Path
12 from platform import system
13 import regex as re
14 import sys
15 from tempfile import TemporaryDirectory
16 import types
17 from typing import (
18     Any,
19     BinaryIO,
20     Callable,
21     Dict,
22     Generator,
23     List,
24     Iterator,
25     TypeVar,
26 )
27 import pytest
28 import unittest
29 from unittest.mock import patch, MagicMock
30
31 import click
32 from click import unstyle
33 from click.testing import CliRunner
34
35 import black
36 from black import Feature, TargetVersion
37
38 from pathspec import PathSpec
39
40 # Import other test classes
41 from tests.util import (
42     THIS_DIR,
43     read_data,
44     DETERMINISTIC_HEADER,
45     BlackBaseTestCase,
46     DEFAULT_MODE,
47     fs,
48     ff,
49     dump_to_stderr,
50 )
51 from .test_primer import PrimerCLITests  # noqa: F401
52
53
# Path of this test module.
THIS_FILE = Path(__file__)
# The target versions PY36 through PY39; tests below pass this set as
# `target_versions` when building a mode.
PY36_VERSIONS = {
    TargetVersion.PY36,
    TargetVersion.PY37,
    TargetVersion.PY38,
    TargetVersion.PY39,
}
# One `--target-version=<lowercased name>` command-line flag per entry of
# PY36_VERSIONS (set iteration order, so the flag order is unspecified).
PY36_ARGS = [f"--target-version={version.name.lower()}" for version in PY36_VERSIONS]
# Generic type variables.
T = TypeVar("T")
R = TypeVar("R")
64
65
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch `black.CACHE_DIR` to a location inside a fresh temporary directory.

    When `exists` is true the yielded path is the temporary directory itself.
    Otherwise it is a "new" child path of it, which this helper does not create.
    The patch and the temporary directory are undone when the block exits.
    """
    with TemporaryDirectory() as tmp_root:
        root = Path(tmp_root)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
74
75
@contextmanager
def event_loop() -> Iterator[None]:
    """Run the enclosed block with a brand-new asyncio event loop installed.

    The loop is created from the current event loop policy, set as the current
    loop, and closed when the block exits (normally or via an exception). The
    previously installed loop is not restored.
    """
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        fresh_loop.close()
86
87
class FakeContext(click.Context):
    """Stand-in click Context for calling functions that require one.

    The parent initializer is not invoked; only an empty `default_map` is set.
    """

    def __init__(self) -> None:
        self.default_map: Dict[str, Any] = dict()
93
94
class FakeParameter(click.Parameter):
    """Stand-in click Parameter for calling functions that require one."""

    def __init__(self) -> None:
        """Do nothing: the parent initializer is not invoked."""
100
101
class BlackRunner(CliRunner):
    """CliRunner variant that keeps stderr apart from stdout.

    While isolated, `sys.stderr` is pointed at a private buffer; when isolation
    ends the raw bytes of both streams are stored on `stdout_bytes` and
    `stderr_bytes`.

    This is a hack that can be removed once we depend on Click 7.x"""

    def __init__(self) -> None:
        self.stdoutbuf = BytesIO()
        self.stderrbuf = BytesIO()
        self.stderr_bytes = b""
        self.stdout_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Wrap the parent's isolation, redirecting stderr to `self.stderrbuf`."""
        with super().isolation(*args, **kwargs) as output:
            saved_stderr = sys.stderr
            try:
                sys.stderr = TextIOWrapper(self.stderrbuf, encoding=self.charset)
                yield output
            finally:
                # Snapshot both streams before the original stderr is put back.
                self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                self.stderr_bytes = sys.stderr.buffer.getvalue()  # type: ignore
                sys.stderr = saved_stderr
125
126
127 class BlackTestCase(BlackBaseTestCase):
128     def invokeBlack(
129         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
130     ) -> None:
131         runner = BlackRunner()
132         if ignore_config:
133             args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
134         result = runner.invoke(black.main, args)
135         self.assertEqual(
136             result.exit_code,
137             exit_code,
138             msg=(
139                 f"Failed with args: {args}\n"
140                 f"stdout: {runner.stdout_bytes.decode()!r}\n"
141                 f"stderr: {runner.stderr_bytes.decode()!r}\n"
142                 f"exception: {result.exception}"
143             ),
144         )
145
146     @patch("black.dump_to_file", dump_to_stderr)
147     def test_empty(self) -> None:
148         source = expected = ""
149         actual = fs(source)
150         self.assertFormatEqual(expected, actual)
151         black.assert_equivalent(source, actual)
152         black.assert_stable(source, actual, DEFAULT_MODE)
153
154     def test_empty_ff(self) -> None:
155         expected = ""
156         tmp_file = Path(black.dump_to_file())
157         try:
158             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
159             with open(tmp_file, encoding="utf8") as f:
160                 actual = f.read()
161         finally:
162             os.unlink(tmp_file)
163         self.assertFormatEqual(expected, actual)
164
165     def test_piping(self) -> None:
166         source, expected = read_data("src/black/__init__", data=False)
167         result = BlackRunner().invoke(
168             black.main,
169             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
170             input=BytesIO(source.encode("utf8")),
171         )
172         self.assertEqual(result.exit_code, 0)
173         self.assertFormatEqual(expected, result.output)
174         black.assert_equivalent(source, result.output)
175         black.assert_stable(source, result.output, DEFAULT_MODE)
176
177     def test_piping_diff(self) -> None:
178         diff_header = re.compile(
179             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
180             r"\+\d\d\d\d"
181         )
182         source, _ = read_data("expression.py")
183         expected, _ = read_data("expression.diff")
184         config = THIS_DIR / "data" / "empty_pyproject.toml"
185         args = [
186             "-",
187             "--fast",
188             f"--line-length={black.DEFAULT_LINE_LENGTH}",
189             "--diff",
190             f"--config={config}",
191         ]
192         result = BlackRunner().invoke(
193             black.main, args, input=BytesIO(source.encode("utf8"))
194         )
195         self.assertEqual(result.exit_code, 0)
196         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
197         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
198         self.assertEqual(expected, actual)
199
200     def test_piping_diff_with_color(self) -> None:
201         source, _ = read_data("expression.py")
202         config = THIS_DIR / "data" / "empty_pyproject.toml"
203         args = [
204             "-",
205             "--fast",
206             f"--line-length={black.DEFAULT_LINE_LENGTH}",
207             "--diff",
208             "--color",
209             f"--config={config}",
210         ]
211         result = BlackRunner().invoke(
212             black.main, args, input=BytesIO(source.encode("utf8"))
213         )
214         actual = result.output
215         # Again, the contents are checked in a different test, so only look for colors.
216         self.assertIn("\033[1;37m", actual)
217         self.assertIn("\033[36m", actual)
218         self.assertIn("\033[32m", actual)
219         self.assertIn("\033[31m", actual)
220         self.assertIn("\033[0m", actual)
221
222     @patch("black.dump_to_file", dump_to_stderr)
223     def _test_wip(self) -> None:
224         source, expected = read_data("wip")
225         sys.settrace(tracefunc)
226         mode = replace(
227             DEFAULT_MODE,
228             experimental_string_processing=False,
229             target_versions={black.TargetVersion.PY38},
230         )
231         actual = fs(source, mode=mode)
232         sys.settrace(None)
233         self.assertFormatEqual(expected, actual)
234         black.assert_equivalent(source, actual)
235         black.assert_stable(source, actual, black.FileMode())
236
237     @unittest.expectedFailure
238     @patch("black.dump_to_file", dump_to_stderr)
239     def test_trailing_comma_optional_parens_stability1(self) -> None:
240         source, _expected = read_data("trailing_comma_optional_parens1")
241         actual = fs(source)
242         black.assert_stable(source, actual, DEFAULT_MODE)
243
244     @unittest.expectedFailure
245     @patch("black.dump_to_file", dump_to_stderr)
246     def test_trailing_comma_optional_parens_stability2(self) -> None:
247         source, _expected = read_data("trailing_comma_optional_parens2")
248         actual = fs(source)
249         black.assert_stable(source, actual, DEFAULT_MODE)
250
251     @unittest.expectedFailure
252     @patch("black.dump_to_file", dump_to_stderr)
253     def test_trailing_comma_optional_parens_stability3(self) -> None:
254         source, _expected = read_data("trailing_comma_optional_parens3")
255         actual = fs(source)
256         black.assert_stable(source, actual, DEFAULT_MODE)
257
258     @patch("black.dump_to_file", dump_to_stderr)
259     def test_trailing_comma_optional_parens_stability1_pass2(self) -> None:
260         source, _expected = read_data("trailing_comma_optional_parens1")
261         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
262         black.assert_stable(source, actual, DEFAULT_MODE)
263
264     @patch("black.dump_to_file", dump_to_stderr)
265     def test_trailing_comma_optional_parens_stability2_pass2(self) -> None:
266         source, _expected = read_data("trailing_comma_optional_parens2")
267         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
268         black.assert_stable(source, actual, DEFAULT_MODE)
269
270     @patch("black.dump_to_file", dump_to_stderr)
271     def test_trailing_comma_optional_parens_stability3_pass2(self) -> None:
272         source, _expected = read_data("trailing_comma_optional_parens3")
273         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
274         black.assert_stable(source, actual, DEFAULT_MODE)
275
276     @patch("black.dump_to_file", dump_to_stderr)
277     def test_pep_572(self) -> None:
278         source, expected = read_data("pep_572")
279         actual = fs(source)
280         self.assertFormatEqual(expected, actual)
281         black.assert_stable(source, actual, DEFAULT_MODE)
282         if sys.version_info >= (3, 8):
283             black.assert_equivalent(source, actual)
284
285     @patch("black.dump_to_file", dump_to_stderr)
286     def test_pep_572_remove_parens(self) -> None:
287         source, expected = read_data("pep_572_remove_parens")
288         actual = fs(source)
289         self.assertFormatEqual(expected, actual)
290         black.assert_stable(source, actual, DEFAULT_MODE)
291         if sys.version_info >= (3, 8):
292             black.assert_equivalent(source, actual)
293
294     @patch("black.dump_to_file", dump_to_stderr)
295     def test_pep_572_do_not_remove_parens(self) -> None:
296         source, expected = read_data("pep_572_do_not_remove_parens")
297         # the AST safety checks will fail, but that's expected, just make sure no
298         # parentheses are touched
299         actual = black.format_str(source, mode=DEFAULT_MODE)
300         self.assertFormatEqual(expected, actual)
301
302     def test_pep_572_version_detection(self) -> None:
303         source, _ = read_data("pep_572")
304         root = black.lib2to3_parse(source)
305         features = black.get_features_used(root)
306         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
307         versions = black.detect_target_versions(root)
308         self.assertIn(black.TargetVersion.PY38, versions)
309
310     def test_expression_ff(self) -> None:
311         source, expected = read_data("expression")
312         tmp_file = Path(black.dump_to_file(source))
313         try:
314             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
315             with open(tmp_file, encoding="utf8") as f:
316                 actual = f.read()
317         finally:
318             os.unlink(tmp_file)
319         self.assertFormatEqual(expected, actual)
320         with patch("black.dump_to_file", dump_to_stderr):
321             black.assert_equivalent(source, actual)
322             black.assert_stable(source, actual, DEFAULT_MODE)
323
324     def test_expression_diff(self) -> None:
325         source, _ = read_data("expression.py")
326         config = THIS_DIR / "data" / "empty_pyproject.toml"
327         expected, _ = read_data("expression.diff")
328         tmp_file = Path(black.dump_to_file(source))
329         diff_header = re.compile(
330             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
331             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
332         )
333         try:
334             result = BlackRunner().invoke(
335                 black.main, ["--diff", str(tmp_file), f"--config={config}"]
336             )
337             self.assertEqual(result.exit_code, 0)
338         finally:
339             os.unlink(tmp_file)
340         actual = result.output
341         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
342         if expected != actual:
343             dump = black.dump_to_file(actual)
344             msg = (
345                 "Expected diff isn't equal to the actual. If you made changes to"
346                 " expression.py and this is an anticipated difference, overwrite"
347                 f" tests/data/expression.diff with {dump}"
348             )
349             self.assertEqual(expected, actual, msg)
350
351     def test_expression_diff_with_color(self) -> None:
352         source, _ = read_data("expression.py")
353         config = THIS_DIR / "data" / "empty_pyproject.toml"
354         expected, _ = read_data("expression.diff")
355         tmp_file = Path(black.dump_to_file(source))
356         try:
357             result = BlackRunner().invoke(
358                 black.main, ["--diff", "--color", str(tmp_file), f"--config={config}"]
359             )
360         finally:
361             os.unlink(tmp_file)
362         actual = result.output
363         # We check the contents of the diff in `test_expression_diff`. All
364         # we need to check here is that color codes exist in the result.
365         self.assertIn("\033[1;37m", actual)
366         self.assertIn("\033[36m", actual)
367         self.assertIn("\033[32m", actual)
368         self.assertIn("\033[31m", actual)
369         self.assertIn("\033[0m", actual)
370
371     @patch("black.dump_to_file", dump_to_stderr)
372     def test_pep_570(self) -> None:
373         source, expected = read_data("pep_570")
374         actual = fs(source)
375         self.assertFormatEqual(expected, actual)
376         black.assert_stable(source, actual, DEFAULT_MODE)
377         if sys.version_info >= (3, 8):
378             black.assert_equivalent(source, actual)
379
380     def test_detect_pos_only_arguments(self) -> None:
381         source, _ = read_data("pep_570")
382         root = black.lib2to3_parse(source)
383         features = black.get_features_used(root)
384         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
385         versions = black.detect_target_versions(root)
386         self.assertIn(black.TargetVersion.PY38, versions)
387
388     @patch("black.dump_to_file", dump_to_stderr)
389     def test_string_quotes(self) -> None:
390         source, expected = read_data("string_quotes")
391         mode = black.Mode(experimental_string_processing=True)
392         actual = fs(source, mode=mode)
393         self.assertFormatEqual(expected, actual)
394         black.assert_equivalent(source, actual)
395         black.assert_stable(source, actual, mode)
396         mode = replace(mode, string_normalization=False)
397         not_normalized = fs(source, mode=mode)
398         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
399         black.assert_equivalent(source, not_normalized)
400         black.assert_stable(source, not_normalized, mode=mode)
401
402     @patch("black.dump_to_file", dump_to_stderr)
403     def test_docstring_no_string_normalization(self) -> None:
404         """Like test_docstring but with string normalization off."""
405         source, expected = read_data("docstring_no_string_normalization")
406         mode = replace(DEFAULT_MODE, string_normalization=False)
407         actual = fs(source, mode=mode)
408         self.assertFormatEqual(expected, actual)
409         black.assert_equivalent(source, actual)
410         black.assert_stable(source, actual, mode)
411
412     def test_long_strings_flag_disabled(self) -> None:
413         """Tests for turning off the string processing logic."""
414         source, expected = read_data("long_strings_flag_disabled")
415         mode = replace(DEFAULT_MODE, experimental_string_processing=False)
416         actual = fs(source, mode=mode)
417         self.assertFormatEqual(expected, actual)
418         black.assert_stable(expected, actual, mode)
419
420     @patch("black.dump_to_file", dump_to_stderr)
421     def test_numeric_literals(self) -> None:
422         source, expected = read_data("numeric_literals")
423         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
424         actual = fs(source, mode=mode)
425         self.assertFormatEqual(expected, actual)
426         black.assert_equivalent(source, actual)
427         black.assert_stable(source, actual, mode)
428
429     @patch("black.dump_to_file", dump_to_stderr)
430     def test_numeric_literals_ignoring_underscores(self) -> None:
431         source, expected = read_data("numeric_literals_skip_underscores")
432         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
433         actual = fs(source, mode=mode)
434         self.assertFormatEqual(expected, actual)
435         black.assert_equivalent(source, actual)
436         black.assert_stable(source, actual, mode)
437
438     def test_skip_magic_trailing_comma(self) -> None:
439         source, _ = read_data("expression.py")
440         expected, _ = read_data("expression_skip_magic_trailing_comma.diff")
441         tmp_file = Path(black.dump_to_file(source))
442         diff_header = re.compile(
443             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
444             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
445         )
446         try:
447             result = BlackRunner().invoke(black.main, ["-C", "--diff", str(tmp_file)])
448             self.assertEqual(result.exit_code, 0)
449         finally:
450             os.unlink(tmp_file)
451         actual = result.output
452         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
453         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
454         if expected != actual:
455             dump = black.dump_to_file(actual)
456             msg = (
457                 "Expected diff isn't equal to the actual. If you made changes to"
458                 " expression.py and this is an anticipated difference, overwrite"
459                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
460             )
461             self.assertEqual(expected, actual, msg)
462
463     @pytest.mark.without_python2
464     def test_python2_should_fail_without_optional_install(self) -> None:
465         # python 3.7 and below will install typed-ast and will be able to parse Python 2
466         if sys.version_info < (3, 8):
467             return
468         source = "x = 1234l"
469         tmp_file = Path(black.dump_to_file(source))
470         try:
471             runner = BlackRunner()
472             result = runner.invoke(black.main, [str(tmp_file)])
473             self.assertEqual(result.exit_code, 123)
474         finally:
475             os.unlink(tmp_file)
476         actual = (
477             runner.stderr_bytes.decode()
478             .replace("\n", "")
479             .replace("\\n", "")
480             .replace("\\r", "")
481             .replace("\r", "")
482         )
483         msg = (
484             "The requested source code has invalid Python 3 syntax."
485             "If you are trying to format Python 2 files please reinstall Black"
486             " with the 'python2' extra: `python3 -m pip install black[python2]`."
487         )
488         self.assertIn(msg, actual)
489
490     @pytest.mark.python2
491     @patch("black.dump_to_file", dump_to_stderr)
492     def test_python2_print_function(self) -> None:
493         source, expected = read_data("python2_print_function")
494         mode = replace(DEFAULT_MODE, target_versions={TargetVersion.PY27})
495         actual = fs(source, mode=mode)
496         self.assertFormatEqual(expected, actual)
497         black.assert_equivalent(source, actual)
498         black.assert_stable(source, actual, mode)
499
500     @patch("black.dump_to_file", dump_to_stderr)
501     def test_stub(self) -> None:
502         mode = replace(DEFAULT_MODE, is_pyi=True)
503         source, expected = read_data("stub.pyi")
504         actual = fs(source, mode=mode)
505         self.assertFormatEqual(expected, actual)
506         black.assert_stable(source, actual, mode)
507
508     @patch("black.dump_to_file", dump_to_stderr)
509     def test_async_as_identifier(self) -> None:
510         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
511         source, expected = read_data("async_as_identifier")
512         actual = fs(source)
513         self.assertFormatEqual(expected, actual)
514         major, minor = sys.version_info[:2]
515         if major < 3 or (major <= 3 and minor < 7):
516             black.assert_equivalent(source, actual)
517         black.assert_stable(source, actual, DEFAULT_MODE)
518         # ensure black can parse this when the target is 3.6
519         self.invokeBlack([str(source_path), "--target-version", "py36"])
520         # but not on 3.7, because async/await is no longer an identifier
521         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
522
523     @patch("black.dump_to_file", dump_to_stderr)
524     def test_python37(self) -> None:
525         source_path = (THIS_DIR / "data" / "python37.py").resolve()
526         source, expected = read_data("python37")
527         actual = fs(source)
528         self.assertFormatEqual(expected, actual)
529         major, minor = sys.version_info[:2]
530         if major > 3 or (major == 3 and minor >= 7):
531             black.assert_equivalent(source, actual)
532         black.assert_stable(source, actual, DEFAULT_MODE)
533         # ensure black can parse this when the target is 3.7
534         self.invokeBlack([str(source_path), "--target-version", "py37"])
535         # but not on 3.6, because we use async as a reserved keyword
536         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
537
538     @patch("black.dump_to_file", dump_to_stderr)
539     def test_python38(self) -> None:
540         source, expected = read_data("python38")
541         actual = fs(source)
542         self.assertFormatEqual(expected, actual)
543         major, minor = sys.version_info[:2]
544         if major > 3 or (major == 3 and minor >= 8):
545             black.assert_equivalent(source, actual)
546         black.assert_stable(source, actual, DEFAULT_MODE)
547
548     @patch("black.dump_to_file", dump_to_stderr)
549     def test_python39(self) -> None:
550         source, expected = read_data("python39")
551         actual = fs(source)
552         self.assertFormatEqual(expected, actual)
553         major, minor = sys.version_info[:2]
554         if major > 3 or (major == 3 and minor >= 9):
555             black.assert_equivalent(source, actual)
556         black.assert_stable(source, actual, DEFAULT_MODE)
557
558     def test_tab_comment_indentation(self) -> None:
559         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
560         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
561         self.assertFormatEqual(contents_spc, fs(contents_spc))
562         self.assertFormatEqual(contents_spc, fs(contents_tab))
563
564         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
565         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
566         self.assertFormatEqual(contents_spc, fs(contents_spc))
567         self.assertFormatEqual(contents_spc, fs(contents_tab))
568
569         # mixed tabs and spaces (valid Python 2 code)
570         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
571         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
572         self.assertFormatEqual(contents_spc, fs(contents_spc))
573         self.assertFormatEqual(contents_spc, fs(contents_tab))
574
575         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
576         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
577         self.assertFormatEqual(contents_spc, fs(contents_spc))
578         self.assertFormatEqual(contents_spc, fs(contents_tab))
579
    def test_report_verbose(self) -> None:
        """Walk a verbose `black.Report` through a sequence of events.

        `black.out` and `black.err` are patched with collectors, and after each
        event the test checks the number of collected messages, the latest
        message, the summary string and the return code.
        """
        report = black.Report(verbose=True)
        out_lines = []
        err_lines = []

        # Collectors standing in for `black.out` / `black.err`; extra keyword
        # arguments are accepted and ignored.
        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            # An unchanged file produces one `out` message and return code 0.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # A reformatted file.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file is counted among the unchanged ones.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With `check` set, the already-reformatted file yields return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure goes to `err` and yields return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Another reformatted file; the return code stays 123.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # A second failure.
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # An ignored path is reported through `out` but leaves the summary
            # unchanged.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # One more unchanged file.
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # With `check` set the summary switches to the "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            # `diff` alone produces the same "would be" wording.
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
681
682     def test_report_quiet(self) -> None:
683         report = black.Report(quiet=True)
684         out_lines = []
685         err_lines = []
686
687         def out(msg: str, **kwargs: Any) -> None:
688             out_lines.append(msg)
689
690         def err(msg: str, **kwargs: Any) -> None:
691             err_lines.append(msg)
692
693         with patch("black.out", out), patch("black.err", err):
694             report.done(Path("f1"), black.Changed.NO)
695             self.assertEqual(len(out_lines), 0)
696             self.assertEqual(len(err_lines), 0)
697             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
698             self.assertEqual(report.return_code, 0)
699             report.done(Path("f2"), black.Changed.YES)
700             self.assertEqual(len(out_lines), 0)
701             self.assertEqual(len(err_lines), 0)
702             self.assertEqual(
703                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
704             )
705             report.done(Path("f3"), black.Changed.CACHED)
706             self.assertEqual(len(out_lines), 0)
707             self.assertEqual(len(err_lines), 0)
708             self.assertEqual(
709                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
710             )
711             self.assertEqual(report.return_code, 0)
712             report.check = True
713             self.assertEqual(report.return_code, 1)
714             report.check = False
715             report.failed(Path("e1"), "boom")
716             self.assertEqual(len(out_lines), 0)
717             self.assertEqual(len(err_lines), 1)
718             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
719             self.assertEqual(
720                 unstyle(str(report)),
721                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
722                 " reformat.",
723             )
724             self.assertEqual(report.return_code, 123)
725             report.done(Path("f3"), black.Changed.YES)
726             self.assertEqual(len(out_lines), 0)
727             self.assertEqual(len(err_lines), 1)
728             self.assertEqual(
729                 unstyle(str(report)),
730                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
731                 " reformat.",
732             )
733             self.assertEqual(report.return_code, 123)
734             report.failed(Path("e2"), "boom")
735             self.assertEqual(len(out_lines), 0)
736             self.assertEqual(len(err_lines), 2)
737             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
738             self.assertEqual(
739                 unstyle(str(report)),
740                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
741                 " reformat.",
742             )
743             self.assertEqual(report.return_code, 123)
744             report.path_ignored(Path("wat"), "no match")
745             self.assertEqual(len(out_lines), 0)
746             self.assertEqual(len(err_lines), 2)
747             self.assertEqual(
748                 unstyle(str(report)),
749                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
750                 " reformat.",
751             )
752             self.assertEqual(report.return_code, 123)
753             report.done(Path("f4"), black.Changed.NO)
754             self.assertEqual(len(out_lines), 0)
755             self.assertEqual(len(err_lines), 2)
756             self.assertEqual(
757                 unstyle(str(report)),
758                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
759                 " reformat.",
760             )
761             self.assertEqual(report.return_code, 123)
762             report.check = True
763             self.assertEqual(
764                 unstyle(str(report)),
765                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
766                 " would fail to reformat.",
767             )
768             report.check = False
769             report.diff = True
770             self.assertEqual(
771                 unstyle(str(report)),
772                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
773                 " would fail to reformat.",
774             )
775
776     def test_report_normal(self) -> None:
777         report = black.Report()
778         out_lines = []
779         err_lines = []
780
781         def out(msg: str, **kwargs: Any) -> None:
782             out_lines.append(msg)
783
784         def err(msg: str, **kwargs: Any) -> None:
785             err_lines.append(msg)
786
787         with patch("black.out", out), patch("black.err", err):
788             report.done(Path("f1"), black.Changed.NO)
789             self.assertEqual(len(out_lines), 0)
790             self.assertEqual(len(err_lines), 0)
791             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
792             self.assertEqual(report.return_code, 0)
793             report.done(Path("f2"), black.Changed.YES)
794             self.assertEqual(len(out_lines), 1)
795             self.assertEqual(len(err_lines), 0)
796             self.assertEqual(out_lines[-1], "reformatted f2")
797             self.assertEqual(
798                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
799             )
800             report.done(Path("f3"), black.Changed.CACHED)
801             self.assertEqual(len(out_lines), 1)
802             self.assertEqual(len(err_lines), 0)
803             self.assertEqual(out_lines[-1], "reformatted f2")
804             self.assertEqual(
805                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
806             )
807             self.assertEqual(report.return_code, 0)
808             report.check = True
809             self.assertEqual(report.return_code, 1)
810             report.check = False
811             report.failed(Path("e1"), "boom")
812             self.assertEqual(len(out_lines), 1)
813             self.assertEqual(len(err_lines), 1)
814             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
815             self.assertEqual(
816                 unstyle(str(report)),
817                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
818                 " reformat.",
819             )
820             self.assertEqual(report.return_code, 123)
821             report.done(Path("f3"), black.Changed.YES)
822             self.assertEqual(len(out_lines), 2)
823             self.assertEqual(len(err_lines), 1)
824             self.assertEqual(out_lines[-1], "reformatted f3")
825             self.assertEqual(
826                 unstyle(str(report)),
827                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
828                 " reformat.",
829             )
830             self.assertEqual(report.return_code, 123)
831             report.failed(Path("e2"), "boom")
832             self.assertEqual(len(out_lines), 2)
833             self.assertEqual(len(err_lines), 2)
834             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
835             self.assertEqual(
836                 unstyle(str(report)),
837                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
838                 " reformat.",
839             )
840             self.assertEqual(report.return_code, 123)
841             report.path_ignored(Path("wat"), "no match")
842             self.assertEqual(len(out_lines), 2)
843             self.assertEqual(len(err_lines), 2)
844             self.assertEqual(
845                 unstyle(str(report)),
846                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
847                 " reformat.",
848             )
849             self.assertEqual(report.return_code, 123)
850             report.done(Path("f4"), black.Changed.NO)
851             self.assertEqual(len(out_lines), 2)
852             self.assertEqual(len(err_lines), 2)
853             self.assertEqual(
854                 unstyle(str(report)),
855                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
856                 " reformat.",
857             )
858             self.assertEqual(report.return_code, 123)
859             report.check = True
860             self.assertEqual(
861                 unstyle(str(report)),
862                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
863                 " would fail to reformat.",
864             )
865             report.check = False
866             report.diff = True
867             self.assertEqual(
868                 unstyle(str(report)),
869                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
870                 " would fail to reformat.",
871             )
872
873     def test_lib2to3_parse(self) -> None:
874         with self.assertRaises(black.InvalidInput):
875             black.lib2to3_parse("invalid syntax")
876
877         straddling = "x + y"
878         black.lib2to3_parse(straddling)
879         black.lib2to3_parse(straddling, {TargetVersion.PY27})
880         black.lib2to3_parse(straddling, {TargetVersion.PY36})
881         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
882
883         py2_only = "print x"
884         black.lib2to3_parse(py2_only)
885         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
886         with self.assertRaises(black.InvalidInput):
887             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
888         with self.assertRaises(black.InvalidInput):
889             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
890
891         py3_only = "exec(x, end=y)"
892         black.lib2to3_parse(py3_only)
893         with self.assertRaises(black.InvalidInput):
894             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
895         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
896         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
897
898     def test_get_features_used_decorator(self) -> None:
899         # Test the feature detection of new decorator syntax
900         # since this makes some test cases of test_get_features_used()
901         # fails if it fails, this is tested first so that a useful case
902         # is identified
903         simples, relaxed = read_data("decorators")
904         # skip explanation comments at the top of the file
905         for simple_test in simples.split("##")[1:]:
906             node = black.lib2to3_parse(simple_test)
907             decorator = str(node.children[0].children[0]).strip()
908             self.assertNotIn(
909                 Feature.RELAXED_DECORATORS,
910                 black.get_features_used(node),
911                 msg=(
912                     f"decorator '{decorator}' follows python<=3.8 syntax"
913                     "but is detected as 3.9+"
914                     # f"The full node is\n{node!r}"
915                 ),
916             )
917         # skip the '# output' comment at the top of the output part
918         for relaxed_test in relaxed.split("##")[1:]:
919             node = black.lib2to3_parse(relaxed_test)
920             decorator = str(node.children[0].children[0]).strip()
921             self.assertIn(
922                 Feature.RELAXED_DECORATORS,
923                 black.get_features_used(node),
924                 msg=(
925                     f"decorator '{decorator}' uses python3.9+ syntax"
926                     "but is detected as python<=3.8"
927                     # f"The full node is\n{node!r}"
928                 ),
929             )
930
931     def test_get_features_used(self) -> None:
932         node = black.lib2to3_parse("def f(*, arg): ...\n")
933         self.assertEqual(black.get_features_used(node), set())
934         node = black.lib2to3_parse("def f(*, arg,): ...\n")
935         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
936         node = black.lib2to3_parse("f(*arg,)\n")
937         self.assertEqual(
938             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
939         )
940         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
941         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
942         node = black.lib2to3_parse("123_456\n")
943         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
944         node = black.lib2to3_parse("123456\n")
945         self.assertEqual(black.get_features_used(node), set())
946         source, expected = read_data("function")
947         node = black.lib2to3_parse(source)
948         expected_features = {
949             Feature.TRAILING_COMMA_IN_CALL,
950             Feature.TRAILING_COMMA_IN_DEF,
951             Feature.F_STRINGS,
952         }
953         self.assertEqual(black.get_features_used(node), expected_features)
954         node = black.lib2to3_parse(expected)
955         self.assertEqual(black.get_features_used(node), expected_features)
956         source, expected = read_data("expression")
957         node = black.lib2to3_parse(source)
958         self.assertEqual(black.get_features_used(node), set())
959         node = black.lib2to3_parse(expected)
960         self.assertEqual(black.get_features_used(node), set())
961
962     def test_get_future_imports(self) -> None:
963         node = black.lib2to3_parse("\n")
964         self.assertEqual(set(), black.get_future_imports(node))
965         node = black.lib2to3_parse("from __future__ import black\n")
966         self.assertEqual({"black"}, black.get_future_imports(node))
967         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
968         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
969         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
970         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
971         node = black.lib2to3_parse(
972             "from __future__ import multiple\nfrom __future__ import imports\n"
973         )
974         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
975         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
976         self.assertEqual({"black"}, black.get_future_imports(node))
977         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
978         self.assertEqual({"black"}, black.get_future_imports(node))
979         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
980         self.assertEqual(set(), black.get_future_imports(node))
981         node = black.lib2to3_parse("from some.module import black\n")
982         self.assertEqual(set(), black.get_future_imports(node))
983         node = black.lib2to3_parse(
984             "from __future__ import unicode_literals as _unicode_literals"
985         )
986         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
987         node = black.lib2to3_parse(
988             "from __future__ import unicode_literals as _lol, print"
989         )
990         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
991
992     def test_debug_visitor(self) -> None:
993         source, _ = read_data("debug_visitor.py")
994         expected, _ = read_data("debug_visitor.out")
995         out_lines = []
996         err_lines = []
997
998         def out(msg: str, **kwargs: Any) -> None:
999             out_lines.append(msg)
1000
1001         def err(msg: str, **kwargs: Any) -> None:
1002             err_lines.append(msg)
1003
1004         with patch("black.out", out), patch("black.err", err):
1005             black.DebugVisitor.show(source)
1006         actual = "\n".join(out_lines) + "\n"
1007         log_name = ""
1008         if expected != actual:
1009             log_name = black.dump_to_file(*out_lines)
1010         self.assertEqual(
1011             expected,
1012             actual,
1013             f"AST print out is different. Actual version dumped to {log_name}",
1014         )
1015
1016     def test_format_file_contents(self) -> None:
1017         empty = ""
1018         mode = DEFAULT_MODE
1019         with self.assertRaises(black.NothingChanged):
1020             black.format_file_contents(empty, mode=mode, fast=False)
1021         just_nl = "\n"
1022         with self.assertRaises(black.NothingChanged):
1023             black.format_file_contents(just_nl, mode=mode, fast=False)
1024         same = "j = [1, 2, 3]\n"
1025         with self.assertRaises(black.NothingChanged):
1026             black.format_file_contents(same, mode=mode, fast=False)
1027         different = "j = [1,2,3]"
1028         expected = same
1029         actual = black.format_file_contents(different, mode=mode, fast=False)
1030         self.assertEqual(expected, actual)
1031         invalid = "return if you can"
1032         with self.assertRaises(black.InvalidInput) as e:
1033             black.format_file_contents(invalid, mode=mode, fast=False)
1034         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1035
1036     def test_endmarker(self) -> None:
1037         n = black.lib2to3_parse("\n")
1038         self.assertEqual(n.type, black.syms.file_input)
1039         self.assertEqual(len(n.children), 1)
1040         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1041
1042     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1043     def test_assertFormatEqual(self) -> None:
1044         out_lines = []
1045         err_lines = []
1046
1047         def out(msg: str, **kwargs: Any) -> None:
1048             out_lines.append(msg)
1049
1050         def err(msg: str, **kwargs: Any) -> None:
1051             err_lines.append(msg)
1052
1053         with patch("black.out", out), patch("black.err", err):
1054             with self.assertRaises(AssertionError):
1055                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1056
1057         out_str = "".join(out_lines)
1058         self.assertTrue("Expected tree:" in out_str)
1059         self.assertTrue("Actual tree:" in out_str)
1060         self.assertEqual("".join(err_lines), "")
1061
1062     def test_cache_broken_file(self) -> None:
1063         mode = DEFAULT_MODE
1064         with cache_dir() as workspace:
1065             cache_file = black.get_cache_file(mode)
1066             with cache_file.open("w") as fobj:
1067                 fobj.write("this is not a pickle")
1068             self.assertEqual(black.read_cache(mode), {})
1069             src = (workspace / "test.py").resolve()
1070             with src.open("w") as fobj:
1071                 fobj.write("print('hello')")
1072             self.invokeBlack([str(src)])
1073             cache = black.read_cache(mode)
1074             self.assertIn(str(src), cache)
1075
1076     def test_cache_single_file_already_cached(self) -> None:
1077         mode = DEFAULT_MODE
1078         with cache_dir() as workspace:
1079             src = (workspace / "test.py").resolve()
1080             with src.open("w") as fobj:
1081                 fobj.write("print('hello')")
1082             black.write_cache({}, [src], mode)
1083             self.invokeBlack([str(src)])
1084             with src.open("r") as fobj:
1085                 self.assertEqual(fobj.read(), "print('hello')")
1086
1087     @event_loop()
1088     def test_cache_multiple_files(self) -> None:
1089         mode = DEFAULT_MODE
1090         with cache_dir() as workspace, patch(
1091             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1092         ):
1093             one = (workspace / "one.py").resolve()
1094             with one.open("w") as fobj:
1095                 fobj.write("print('hello')")
1096             two = (workspace / "two.py").resolve()
1097             with two.open("w") as fobj:
1098                 fobj.write("print('hello')")
1099             black.write_cache({}, [one], mode)
1100             self.invokeBlack([str(workspace)])
1101             with one.open("r") as fobj:
1102                 self.assertEqual(fobj.read(), "print('hello')")
1103             with two.open("r") as fobj:
1104                 self.assertEqual(fobj.read(), 'print("hello")\n')
1105             cache = black.read_cache(mode)
1106             self.assertIn(str(one), cache)
1107             self.assertIn(str(two), cache)
1108
1109     def test_no_cache_when_writeback_diff(self) -> None:
1110         mode = DEFAULT_MODE
1111         with cache_dir() as workspace:
1112             src = (workspace / "test.py").resolve()
1113             with src.open("w") as fobj:
1114                 fobj.write("print('hello')")
1115             with patch("black.read_cache") as read_cache, patch(
1116                 "black.write_cache"
1117             ) as write_cache:
1118                 self.invokeBlack([str(src), "--diff"])
1119                 cache_file = black.get_cache_file(mode)
1120                 self.assertFalse(cache_file.exists())
1121                 write_cache.assert_not_called()
1122                 read_cache.assert_not_called()
1123
1124     def test_no_cache_when_writeback_color_diff(self) -> None:
1125         mode = DEFAULT_MODE
1126         with cache_dir() as workspace:
1127             src = (workspace / "test.py").resolve()
1128             with src.open("w") as fobj:
1129                 fobj.write("print('hello')")
1130             with patch("black.read_cache") as read_cache, patch(
1131                 "black.write_cache"
1132             ) as write_cache:
1133                 self.invokeBlack([str(src), "--diff", "--color"])
1134                 cache_file = black.get_cache_file(mode)
1135                 self.assertFalse(cache_file.exists())
1136                 write_cache.assert_not_called()
1137                 read_cache.assert_not_called()
1138
1139     @event_loop()
1140     def test_output_locking_when_writeback_diff(self) -> None:
1141         with cache_dir() as workspace:
1142             for tag in range(0, 4):
1143                 src = (workspace / f"test{tag}.py").resolve()
1144                 with src.open("w") as fobj:
1145                     fobj.write("print('hello')")
1146             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1147                 self.invokeBlack(["--diff", str(workspace)], exit_code=0)
1148                 # this isn't quite doing what we want, but if it _isn't_
1149                 # called then we cannot be using the lock it provides
1150                 mgr.assert_called()
1151
1152     @event_loop()
1153     def test_output_locking_when_writeback_color_diff(self) -> None:
1154         with cache_dir() as workspace:
1155             for tag in range(0, 4):
1156                 src = (workspace / f"test{tag}.py").resolve()
1157                 with src.open("w") as fobj:
1158                     fobj.write("print('hello')")
1159             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1160                 self.invokeBlack(["--diff", "--color", str(workspace)], exit_code=0)
1161                 # this isn't quite doing what we want, but if it _isn't_
1162                 # called then we cannot be using the lock it provides
1163                 mgr.assert_called()
1164
1165     def test_no_cache_when_stdin(self) -> None:
1166         mode = DEFAULT_MODE
1167         with cache_dir():
1168             result = CliRunner().invoke(
1169                 black.main, ["-"], input=BytesIO(b"print('hello')")
1170             )
1171             self.assertEqual(result.exit_code, 0)
1172             cache_file = black.get_cache_file(mode)
1173             self.assertFalse(cache_file.exists())
1174
1175     def test_read_cache_no_cachefile(self) -> None:
1176         mode = DEFAULT_MODE
1177         with cache_dir():
1178             self.assertEqual(black.read_cache(mode), {})
1179
1180     def test_write_cache_read_cache(self) -> None:
1181         mode = DEFAULT_MODE
1182         with cache_dir() as workspace:
1183             src = (workspace / "test.py").resolve()
1184             src.touch()
1185             black.write_cache({}, [src], mode)
1186             cache = black.read_cache(mode)
1187             self.assertIn(str(src), cache)
1188             self.assertEqual(cache[str(src)], black.get_cache_info(src))
1189
1190     def test_filter_cached(self) -> None:
1191         with TemporaryDirectory() as workspace:
1192             path = Path(workspace)
1193             uncached = (path / "uncached").resolve()
1194             cached = (path / "cached").resolve()
1195             cached_but_changed = (path / "changed").resolve()
1196             uncached.touch()
1197             cached.touch()
1198             cached_but_changed.touch()
1199             cache = {
1200                 str(cached): black.get_cache_info(cached),
1201                 str(cached_but_changed): (0.0, 0),
1202             }
1203             todo, done = black.filter_cached(
1204                 cache, {uncached, cached, cached_but_changed}
1205             )
1206             self.assertEqual(todo, {uncached, cached_but_changed})
1207             self.assertEqual(done, {cached})
1208
1209     def test_write_cache_creates_directory_if_needed(self) -> None:
1210         mode = DEFAULT_MODE
1211         with cache_dir(exists=False) as workspace:
1212             self.assertFalse(workspace.exists())
1213             black.write_cache({}, [], mode)
1214             self.assertTrue(workspace.exists())
1215
1216     @event_loop()
1217     def test_failed_formatting_does_not_get_cached(self) -> None:
1218         mode = DEFAULT_MODE
1219         with cache_dir() as workspace, patch(
1220             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1221         ):
1222             failing = (workspace / "failing.py").resolve()
1223             with failing.open("w") as fobj:
1224                 fobj.write("not actually python")
1225             clean = (workspace / "clean.py").resolve()
1226             with clean.open("w") as fobj:
1227                 fobj.write('print("hello")\n')
1228             self.invokeBlack([str(workspace)], exit_code=123)
1229             cache = black.read_cache(mode)
1230             self.assertNotIn(str(failing), cache)
1231             self.assertIn(str(clean), cache)
1232
1233     def test_write_cache_write_fail(self) -> None:
1234         mode = DEFAULT_MODE
1235         with cache_dir(), patch.object(Path, "open") as mock:
1236             mock.side_effect = OSError
1237             black.write_cache({}, [], mode)
1238
1239     @event_loop()
1240     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
1241     def test_works_in_mono_process_only_environment(self) -> None:
1242         with cache_dir() as workspace:
1243             for f in [
1244                 (workspace / "one.py").resolve(),
1245                 (workspace / "two.py").resolve(),
1246             ]:
1247                 f.write_text('print("hello")\n')
1248             self.invokeBlack([str(workspace)])
1249
1250     @event_loop()
1251     def test_check_diff_use_together(self) -> None:
1252         with cache_dir():
1253             # Files which will be reformatted.
1254             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1255             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1256             # Files which will not be reformatted.
1257             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1258             self.invokeBlack([str(src2), "--diff", "--check"])
1259             # Multi file command.
1260             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1261
1262     def test_no_files(self) -> None:
1263         with cache_dir():
1264             # Without an argument, black exits with error code 0.
1265             self.invokeBlack([])
1266
1267     def test_broken_symlink(self) -> None:
1268         with cache_dir() as workspace:
1269             symlink = workspace / "broken_link.py"
1270             try:
1271                 symlink.symlink_to("nonexistent.py")
1272             except OSError as e:
1273                 self.skipTest(f"Can't create symlinks: {e}")
1274             self.invokeBlack([str(workspace.resolve())])
1275
1276     def test_read_cache_line_lengths(self) -> None:
1277         mode = DEFAULT_MODE
1278         short_mode = replace(DEFAULT_MODE, line_length=1)
1279         with cache_dir() as workspace:
1280             path = (workspace / "file.py").resolve()
1281             path.touch()
1282             black.write_cache({}, [path], mode)
1283             one = black.read_cache(mode)
1284             self.assertIn(str(path), one)
1285             two = black.read_cache(short_mode)
1286             self.assertNotIn(str(path), two)
1287
1288     def test_single_file_force_pyi(self) -> None:
1289         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1290         contents, expected = read_data("force_pyi")
1291         with cache_dir() as workspace:
1292             path = (workspace / "file.py").resolve()
1293             with open(path, "w") as fh:
1294                 fh.write(contents)
1295             self.invokeBlack([str(path), "--pyi"])
1296             with open(path, "r") as fh:
1297                 actual = fh.read()
1298             # verify cache with --pyi is separate
1299             pyi_cache = black.read_cache(pyi_mode)
1300             self.assertIn(str(path), pyi_cache)
1301             normal_cache = black.read_cache(DEFAULT_MODE)
1302             self.assertNotIn(str(path), normal_cache)
1303         self.assertFormatEqual(expected, actual)
1304         black.assert_equivalent(contents, actual)
1305         black.assert_stable(contents, actual, pyi_mode)
1306
1307     @event_loop()
1308     def test_multi_file_force_pyi(self) -> None:
1309         reg_mode = DEFAULT_MODE
1310         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1311         contents, expected = read_data("force_pyi")
1312         with cache_dir() as workspace:
1313             paths = [
1314                 (workspace / "file1.py").resolve(),
1315                 (workspace / "file2.py").resolve(),
1316             ]
1317             for path in paths:
1318                 with open(path, "w") as fh:
1319                     fh.write(contents)
1320             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1321             for path in paths:
1322                 with open(path, "r") as fh:
1323                     actual = fh.read()
1324                 self.assertEqual(actual, expected)
1325             # verify cache with --pyi is separate
1326             pyi_cache = black.read_cache(pyi_mode)
1327             normal_cache = black.read_cache(reg_mode)
1328             for path in paths:
1329                 self.assertIn(str(path), pyi_cache)
1330                 self.assertNotIn(str(path), normal_cache)
1331
1332     def test_pipe_force_pyi(self) -> None:
1333         source, expected = read_data("force_pyi")
1334         result = CliRunner().invoke(
1335             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1336         )
1337         self.assertEqual(result.exit_code, 0)
1338         actual = result.output
1339         self.assertFormatEqual(actual, expected)
1340
1341     def test_single_file_force_py36(self) -> None:
1342         reg_mode = DEFAULT_MODE
1343         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1344         source, expected = read_data("force_py36")
1345         with cache_dir() as workspace:
1346             path = (workspace / "file.py").resolve()
1347             with open(path, "w") as fh:
1348                 fh.write(source)
1349             self.invokeBlack([str(path), *PY36_ARGS])
1350             with open(path, "r") as fh:
1351                 actual = fh.read()
1352             # verify cache with --target-version is separate
1353             py36_cache = black.read_cache(py36_mode)
1354             self.assertIn(str(path), py36_cache)
1355             normal_cache = black.read_cache(reg_mode)
1356             self.assertNotIn(str(path), normal_cache)
1357         self.assertEqual(actual, expected)
1358
1359     @event_loop()
1360     def test_multi_file_force_py36(self) -> None:
1361         reg_mode = DEFAULT_MODE
1362         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1363         source, expected = read_data("force_py36")
1364         with cache_dir() as workspace:
1365             paths = [
1366                 (workspace / "file1.py").resolve(),
1367                 (workspace / "file2.py").resolve(),
1368             ]
1369             for path in paths:
1370                 with open(path, "w") as fh:
1371                     fh.write(source)
1372             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1373             for path in paths:
1374                 with open(path, "r") as fh:
1375                     actual = fh.read()
1376                 self.assertEqual(actual, expected)
1377             # verify cache with --target-version is separate
1378             pyi_cache = black.read_cache(py36_mode)
1379             normal_cache = black.read_cache(reg_mode)
1380             for path in paths:
1381                 self.assertIn(str(path), pyi_cache)
1382                 self.assertNotIn(str(path), normal_cache)
1383
1384     def test_pipe_force_py36(self) -> None:
1385         source, expected = read_data("force_py36")
1386         result = CliRunner().invoke(
1387             black.main,
1388             ["-", "-q", "--target-version=py36"],
1389             input=BytesIO(source.encode("utf8")),
1390         )
1391         self.assertEqual(result.exit_code, 0)
1392         actual = result.output
1393         self.assertFormatEqual(actual, expected)
1394
1395     def test_include_exclude(self) -> None:
1396         path = THIS_DIR / "data" / "include_exclude_tests"
1397         include = re.compile(r"\.pyi?$")
1398         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1399         report = black.Report()
1400         gitignore = PathSpec.from_lines("gitwildmatch", [])
1401         sources: List[Path] = []
1402         expected = [
1403             Path(path / "b/dont_exclude/a.py"),
1404             Path(path / "b/dont_exclude/a.pyi"),
1405         ]
1406         this_abs = THIS_DIR.resolve()
1407         sources.extend(
1408             black.gen_python_files(
1409                 path.iterdir(),
1410                 this_abs,
1411                 include,
1412                 exclude,
1413                 None,
1414                 None,
1415                 report,
1416                 gitignore,
1417             )
1418         )
1419         self.assertEqual(sorted(expected), sorted(sources))
1420
1421     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1422     def test_exclude_for_issue_1572(self) -> None:
1423         # Exclude shouldn't touch files that were explicitly given to Black through the
1424         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
1425         # https://github.com/psf/black/issues/1572
1426         path = THIS_DIR / "data" / "include_exclude_tests"
1427         include = ""
1428         exclude = r"/exclude/|a\.py"
1429         src = str(path / "b/exclude/a.py")
1430         report = black.Report()
1431         expected = [Path(path / "b/exclude/a.py")]
1432         sources = list(
1433             black.get_sources(
1434                 ctx=FakeContext(),
1435                 src=(src,),
1436                 quiet=True,
1437                 verbose=False,
1438                 include=re.compile(include),
1439                 exclude=re.compile(exclude),
1440                 extend_exclude=None,
1441                 force_exclude=None,
1442                 report=report,
1443                 stdin_filename=None,
1444             )
1445         )
1446         self.assertEqual(sorted(expected), sorted(sources))
1447
1448     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1449     def test_get_sources_with_stdin(self) -> None:
1450         include = ""
1451         exclude = r"/exclude/|a\.py"
1452         src = "-"
1453         report = black.Report()
1454         expected = [Path("-")]
1455         sources = list(
1456             black.get_sources(
1457                 ctx=FakeContext(),
1458                 src=(src,),
1459                 quiet=True,
1460                 verbose=False,
1461                 include=re.compile(include),
1462                 exclude=re.compile(exclude),
1463                 extend_exclude=None,
1464                 force_exclude=None,
1465                 report=report,
1466                 stdin_filename=None,
1467             )
1468         )
1469         self.assertEqual(sorted(expected), sorted(sources))
1470
1471     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1472     def test_get_sources_with_stdin_filename(self) -> None:
1473         include = ""
1474         exclude = r"/exclude/|a\.py"
1475         src = "-"
1476         report = black.Report()
1477         stdin_filename = str(THIS_DIR / "data/collections.py")
1478         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1479         sources = list(
1480             black.get_sources(
1481                 ctx=FakeContext(),
1482                 src=(src,),
1483                 quiet=True,
1484                 verbose=False,
1485                 include=re.compile(include),
1486                 exclude=re.compile(exclude),
1487                 extend_exclude=None,
1488                 force_exclude=None,
1489                 report=report,
1490                 stdin_filename=stdin_filename,
1491             )
1492         )
1493         self.assertEqual(sorted(expected), sorted(sources))
1494
1495     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1496     def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
1497         # Exclude shouldn't exclude stdin_filename since it is mimicing the
1498         # file being passed directly. This is the same as
1499         # test_exclude_for_issue_1572
1500         path = THIS_DIR / "data" / "include_exclude_tests"
1501         include = ""
1502         exclude = r"/exclude/|a\.py"
1503         src = "-"
1504         report = black.Report()
1505         stdin_filename = str(path / "b/exclude/a.py")
1506         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1507         sources = list(
1508             black.get_sources(
1509                 ctx=FakeContext(),
1510                 src=(src,),
1511                 quiet=True,
1512                 verbose=False,
1513                 include=re.compile(include),
1514                 exclude=re.compile(exclude),
1515                 extend_exclude=None,
1516                 force_exclude=None,
1517                 report=report,
1518                 stdin_filename=stdin_filename,
1519             )
1520         )
1521         self.assertEqual(sorted(expected), sorted(sources))
1522
1523     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1524     def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
1525         # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
1526         # file being passed directly. This is the same as
1527         # test_exclude_for_issue_1572
1528         path = THIS_DIR / "data" / "include_exclude_tests"
1529         include = ""
1530         extend_exclude = r"/exclude/|a\.py"
1531         src = "-"
1532         report = black.Report()
1533         stdin_filename = str(path / "b/exclude/a.py")
1534         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1535         sources = list(
1536             black.get_sources(
1537                 ctx=FakeContext(),
1538                 src=(src,),
1539                 quiet=True,
1540                 verbose=False,
1541                 include=re.compile(include),
1542                 exclude=re.compile(""),
1543                 extend_exclude=re.compile(extend_exclude),
1544                 force_exclude=None,
1545                 report=report,
1546                 stdin_filename=stdin_filename,
1547             )
1548         )
1549         self.assertEqual(sorted(expected), sorted(sources))
1550
1551     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1552     def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
1553         # Force exclude should exclude the file when passing it through
1554         # stdin_filename
1555         path = THIS_DIR / "data" / "include_exclude_tests"
1556         include = ""
1557         force_exclude = r"/exclude/|a\.py"
1558         src = "-"
1559         report = black.Report()
1560         stdin_filename = str(path / "b/exclude/a.py")
1561         sources = list(
1562             black.get_sources(
1563                 ctx=FakeContext(),
1564                 src=(src,),
1565                 quiet=True,
1566                 verbose=False,
1567                 include=re.compile(include),
1568                 exclude=re.compile(""),
1569                 extend_exclude=None,
1570                 force_exclude=re.compile(force_exclude),
1571                 report=report,
1572                 stdin_filename=stdin_filename,
1573             )
1574         )
1575         self.assertEqual([], sorted(sources))
1576
1577     def test_reformat_one_with_stdin(self) -> None:
1578         with patch(
1579             "black.format_stdin_to_stdout",
1580             return_value=lambda *args, **kwargs: black.Changed.YES,
1581         ) as fsts:
1582             report = MagicMock()
1583             path = Path("-")
1584             black.reformat_one(
1585                 path,
1586                 fast=True,
1587                 write_back=black.WriteBack.YES,
1588                 mode=DEFAULT_MODE,
1589                 report=report,
1590             )
1591             fsts.assert_called_once()
1592             report.done.assert_called_with(path, black.Changed.YES)
1593
1594     def test_reformat_one_with_stdin_filename(self) -> None:
1595         with patch(
1596             "black.format_stdin_to_stdout",
1597             return_value=lambda *args, **kwargs: black.Changed.YES,
1598         ) as fsts:
1599             report = MagicMock()
1600             p = "foo.py"
1601             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1602             expected = Path(p)
1603             black.reformat_one(
1604                 path,
1605                 fast=True,
1606                 write_back=black.WriteBack.YES,
1607                 mode=DEFAULT_MODE,
1608                 report=report,
1609             )
1610             fsts.assert_called_once_with(
1611                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1612             )
1613             # __BLACK_STDIN_FILENAME__ should have been stripped
1614             report.done.assert_called_with(expected, black.Changed.YES)
1615
1616     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1617         with patch(
1618             "black.format_stdin_to_stdout",
1619             return_value=lambda *args, **kwargs: black.Changed.YES,
1620         ) as fsts:
1621             report = MagicMock()
1622             p = "foo.pyi"
1623             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1624             expected = Path(p)
1625             black.reformat_one(
1626                 path,
1627                 fast=True,
1628                 write_back=black.WriteBack.YES,
1629                 mode=DEFAULT_MODE,
1630                 report=report,
1631             )
1632             fsts.assert_called_once_with(
1633                 fast=True,
1634                 write_back=black.WriteBack.YES,
1635                 mode=replace(DEFAULT_MODE, is_pyi=True),
1636             )
1637             # __BLACK_STDIN_FILENAME__ should have been stripped
1638             report.done.assert_called_with(expected, black.Changed.YES)
1639
1640     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1641         with patch(
1642             "black.format_stdin_to_stdout",
1643             return_value=lambda *args, **kwargs: black.Changed.YES,
1644         ) as fsts:
1645             report = MagicMock()
1646             # Even with an existing file, since we are forcing stdin, black
1647             # should output to stdout and not modify the file inplace
1648             p = Path(str(THIS_DIR / "data/collections.py"))
1649             # Make sure is_file actually returns True
1650             self.assertTrue(p.is_file())
1651             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1652             expected = Path(p)
1653             black.reformat_one(
1654                 path,
1655                 fast=True,
1656                 write_back=black.WriteBack.YES,
1657                 mode=DEFAULT_MODE,
1658                 report=report,
1659             )
1660             fsts.assert_called_once()
1661             # __BLACK_STDIN_FILENAME__ should have been stripped
1662             report.done.assert_called_with(expected, black.Changed.YES)
1663
1664     def test_gitignore_exclude(self) -> None:
1665         path = THIS_DIR / "data" / "include_exclude_tests"
1666         include = re.compile(r"\.pyi?$")
1667         exclude = re.compile(r"")
1668         report = black.Report()
1669         gitignore = PathSpec.from_lines(
1670             "gitwildmatch", ["exclude/", ".definitely_exclude"]
1671         )
1672         sources: List[Path] = []
1673         expected = [
1674             Path(path / "b/dont_exclude/a.py"),
1675             Path(path / "b/dont_exclude/a.pyi"),
1676         ]
1677         this_abs = THIS_DIR.resolve()
1678         sources.extend(
1679             black.gen_python_files(
1680                 path.iterdir(),
1681                 this_abs,
1682                 include,
1683                 exclude,
1684                 None,
1685                 None,
1686                 report,
1687                 gitignore,
1688             )
1689         )
1690         self.assertEqual(sorted(expected), sorted(sources))
1691
1692     def test_empty_include(self) -> None:
1693         path = THIS_DIR / "data" / "include_exclude_tests"
1694         report = black.Report()
1695         gitignore = PathSpec.from_lines("gitwildmatch", [])
1696         empty = re.compile(r"")
1697         sources: List[Path] = []
1698         expected = [
1699             Path(path / "b/exclude/a.pie"),
1700             Path(path / "b/exclude/a.py"),
1701             Path(path / "b/exclude/a.pyi"),
1702             Path(path / "b/dont_exclude/a.pie"),
1703             Path(path / "b/dont_exclude/a.py"),
1704             Path(path / "b/dont_exclude/a.pyi"),
1705             Path(path / "b/.definitely_exclude/a.pie"),
1706             Path(path / "b/.definitely_exclude/a.py"),
1707             Path(path / "b/.definitely_exclude/a.pyi"),
1708         ]
1709         this_abs = THIS_DIR.resolve()
1710         sources.extend(
1711             black.gen_python_files(
1712                 path.iterdir(),
1713                 this_abs,
1714                 empty,
1715                 re.compile(black.DEFAULT_EXCLUDES),
1716                 None,
1717                 None,
1718                 report,
1719                 gitignore,
1720             )
1721         )
1722         self.assertEqual(sorted(expected), sorted(sources))
1723
1724     def test_extend_exclude(self) -> None:
1725         path = THIS_DIR / "data" / "include_exclude_tests"
1726         report = black.Report()
1727         gitignore = PathSpec.from_lines("gitwildmatch", [])
1728         sources: List[Path] = []
1729         expected = [
1730             Path(path / "b/exclude/a.py"),
1731             Path(path / "b/dont_exclude/a.py"),
1732         ]
1733         this_abs = THIS_DIR.resolve()
1734         sources.extend(
1735             black.gen_python_files(
1736                 path.iterdir(),
1737                 this_abs,
1738                 re.compile(black.DEFAULT_INCLUDES),
1739                 re.compile(r"\.pyi$"),
1740                 re.compile(r"\.definitely_exclude"),
1741                 None,
1742                 report,
1743                 gitignore,
1744             )
1745         )
1746         self.assertEqual(sorted(expected), sorted(sources))
1747
1748     def test_invalid_cli_regex(self) -> None:
1749         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1750             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1751
1752     def test_preserves_line_endings(self) -> None:
1753         with TemporaryDirectory() as workspace:
1754             test_file = Path(workspace) / "test.py"
1755             for nl in ["\n", "\r\n"]:
1756                 contents = nl.join(["def f(  ):", "    pass"])
1757                 test_file.write_bytes(contents.encode())
1758                 ff(test_file, write_back=black.WriteBack.YES)
1759                 updated_contents: bytes = test_file.read_bytes()
1760                 self.assertIn(nl.encode(), updated_contents)
1761                 if nl == "\n":
1762                     self.assertNotIn(b"\r\n", updated_contents)
1763
1764     def test_preserves_line_endings_via_stdin(self) -> None:
1765         for nl in ["\n", "\r\n"]:
1766             contents = nl.join(["def f(  ):", "    pass"])
1767             runner = BlackRunner()
1768             result = runner.invoke(
1769                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1770             )
1771             self.assertEqual(result.exit_code, 0)
1772             output = runner.stdout_bytes
1773             self.assertIn(nl.encode("utf8"), output)
1774             if nl == "\n":
1775                 self.assertNotIn(b"\r\n", output)
1776
1777     def test_assert_equivalent_different_asts(self) -> None:
1778         with self.assertRaises(AssertionError):
1779             black.assert_equivalent("{}", "None")
1780
1781     def test_symlink_out_of_root_directory(self) -> None:
1782         path = MagicMock()
1783         root = THIS_DIR.resolve()
1784         child = MagicMock()
1785         include = re.compile(black.DEFAULT_INCLUDES)
1786         exclude = re.compile(black.DEFAULT_EXCLUDES)
1787         report = black.Report()
1788         gitignore = PathSpec.from_lines("gitwildmatch", [])
1789         # `child` should behave like a symlink which resolved path is clearly
1790         # outside of the `root` directory.
1791         path.iterdir.return_value = [child]
1792         child.resolve.return_value = Path("/a/b/c")
1793         child.as_posix.return_value = "/a/b/c"
1794         child.is_symlink.return_value = True
1795         try:
1796             list(
1797                 black.gen_python_files(
1798                     path.iterdir(),
1799                     root,
1800                     include,
1801                     exclude,
1802                     None,
1803                     None,
1804                     report,
1805                     gitignore,
1806                 )
1807             )
1808         except ValueError as ve:
1809             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1810         path.iterdir.assert_called_once()
1811         child.resolve.assert_called_once()
1812         child.is_symlink.assert_called_once()
1813         # `child` should behave like a strange file which resolved path is clearly
1814         # outside of the `root` directory.
1815         child.is_symlink.return_value = False
1816         with self.assertRaises(ValueError):
1817             list(
1818                 black.gen_python_files(
1819                     path.iterdir(),
1820                     root,
1821                     include,
1822                     exclude,
1823                     None,
1824                     None,
1825                     report,
1826                     gitignore,
1827                 )
1828             )
1829         path.iterdir.assert_called()
1830         self.assertEqual(path.iterdir.call_count, 2)
1831         child.resolve.assert_called()
1832         self.assertEqual(child.resolve.call_count, 2)
1833         child.is_symlink.assert_called()
1834         self.assertEqual(child.is_symlink.call_count, 2)
1835
1836     def test_shhh_click(self) -> None:
1837         try:
1838             from click import _unicodefun  # type: ignore
1839         except ModuleNotFoundError:
1840             self.skipTest("Incompatible Click version")
1841         if not hasattr(_unicodefun, "_verify_python3_env"):
1842             self.skipTest("Incompatible Click version")
1843         # First, let's see if Click is crashing with a preferred ASCII charset.
1844         with patch("locale.getpreferredencoding") as gpe:
1845             gpe.return_value = "ASCII"
1846             with self.assertRaises(RuntimeError):
1847                 _unicodefun._verify_python3_env()
1848         # Now, let's silence Click...
1849         black.patch_click()
1850         # ...and confirm it's silent.
1851         with patch("locale.getpreferredencoding") as gpe:
1852             gpe.return_value = "ASCII"
1853             try:
1854                 _unicodefun._verify_python3_env()
1855             except RuntimeError as re:
1856                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1857
1858     def test_root_logger_not_used_directly(self) -> None:
1859         def fail(*args: Any, **kwargs: Any) -> None:
1860             self.fail("Record created with root logger")
1861
1862         with patch.multiple(
1863             logging.root,
1864             debug=fail,
1865             info=fail,
1866             warning=fail,
1867             error=fail,
1868             critical=fail,
1869             log=fail,
1870         ):
1871             ff(THIS_FILE)
1872
1873     def test_invalid_config_return_code(self) -> None:
1874         tmp_file = Path(black.dump_to_file())
1875         try:
1876             tmp_config = Path(black.dump_to_file())
1877             tmp_config.unlink()
1878             args = ["--config", str(tmp_config), str(tmp_file)]
1879             self.invokeBlack(args, exit_code=2, ignore_config=False)
1880         finally:
1881             tmp_file.unlink()
1882
1883     def test_parse_pyproject_toml(self) -> None:
1884         test_toml_file = THIS_DIR / "test.toml"
1885         config = black.parse_pyproject_toml(str(test_toml_file))
1886         self.assertEqual(config["verbose"], 1)
1887         self.assertEqual(config["check"], "no")
1888         self.assertEqual(config["diff"], "y")
1889         self.assertEqual(config["color"], True)
1890         self.assertEqual(config["line_length"], 79)
1891         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1892         self.assertEqual(config["exclude"], r"\.pyi?$")
1893         self.assertEqual(config["include"], r"\.py?$")
1894
1895     def test_read_pyproject_toml(self) -> None:
1896         test_toml_file = THIS_DIR / "test.toml"
1897         fake_ctx = FakeContext()
1898         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1899         config = fake_ctx.default_map
1900         self.assertEqual(config["verbose"], "1")
1901         self.assertEqual(config["check"], "no")
1902         self.assertEqual(config["diff"], "y")
1903         self.assertEqual(config["color"], "True")
1904         self.assertEqual(config["line_length"], "79")
1905         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1906         self.assertEqual(config["exclude"], r"\.pyi?$")
1907         self.assertEqual(config["include"], r"\.py?$")
1908
1909     def test_find_project_root(self) -> None:
1910         with TemporaryDirectory() as workspace:
1911             root = Path(workspace)
1912             test_dir = root / "test"
1913             test_dir.mkdir()
1914
1915             src_dir = root / "src"
1916             src_dir.mkdir()
1917
1918             root_pyproject = root / "pyproject.toml"
1919             root_pyproject.touch()
1920             src_pyproject = src_dir / "pyproject.toml"
1921             src_pyproject.touch()
1922             src_python = src_dir / "foo.py"
1923             src_python.touch()
1924
1925             self.assertEqual(
1926                 black.find_project_root((src_dir, test_dir)), root.resolve()
1927             )
1928             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
1929             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
1930
1931     @patch("black.find_user_pyproject_toml", black.find_user_pyproject_toml.__wrapped__)
1932     def test_find_user_pyproject_toml_linux(self) -> None:
1933         if system() == "Windows":
1934             return
1935
1936         # Test if XDG_CONFIG_HOME is checked
1937         with TemporaryDirectory() as workspace:
1938             tmp_user_config = Path(workspace) / "black"
1939             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1940                 self.assertEqual(
1941                     black.find_user_pyproject_toml(), tmp_user_config.resolve()
1942                 )
1943
1944         # Test fallback for XDG_CONFIG_HOME
1945         with patch.dict("os.environ"):
1946             os.environ.pop("XDG_CONFIG_HOME", None)
1947             fallback_user_config = Path("~/.config").expanduser() / "black"
1948             self.assertEqual(
1949                 black.find_user_pyproject_toml(), fallback_user_config.resolve()
1950             )
1951
1952     def test_find_user_pyproject_toml_windows(self) -> None:
1953         if system() != "Windows":
1954             return
1955
1956         user_config_path = Path.home() / ".black"
1957         self.assertEqual(black.find_user_pyproject_toml(), user_config_path.resolve())
1958
1959     def test_bpo_33660_workaround(self) -> None:
1960         if system() == "Windows":
1961             return
1962
1963         # https://bugs.python.org/issue33660
1964
1965         old_cwd = Path.cwd()
1966         try:
1967             root = Path("/")
1968             os.chdir(str(root))
1969             path = Path("workspace") / "project"
1970             report = black.Report(verbose=True)
1971             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
1972             self.assertEqual(normalized_path, "workspace/project")
1973         finally:
1974             os.chdir(str(old_cwd))
1975
1976     def test_newline_comment_interaction(self) -> None:
1977         source = "class A:\\\r\n# type: ignore\n pass\n"
1978         output = black.format_str(source, mode=DEFAULT_MODE)
1979         black.assert_stable(source, output, mode=DEFAULT_MODE)
1980
1981     def test_bpo_2142_workaround(self) -> None:
1982
1983         # https://bugs.python.org/issue2142
1984
1985         source, _ = read_data("missing_final_newline.py")
1986         # read_data adds a trailing newline
1987         source = source.rstrip()
1988         expected, _ = read_data("missing_final_newline.diff")
1989         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
1990         diff_header = re.compile(
1991             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
1992             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
1993         )
1994         try:
1995             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
1996             self.assertEqual(result.exit_code, 0)
1997         finally:
1998             os.unlink(tmp_file)
1999         actual = result.output
2000         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
2001         self.assertEqual(actual, expected)
2002
2003     @pytest.mark.python2
2004     def test_docstring_reformat_for_py27(self) -> None:
2005         """
2006         Check that stripping trailing whitespace from Python 2 docstrings
2007         doesn't trigger a "not equivalent to source" error
2008         """
2009         source = (
2010             b'def foo():\r\n    """Testing\r\n    Testing """\r\n    print "Foo"\r\n'
2011         )
2012         expected = 'def foo():\n    """Testing\n    Testing"""\n    print "Foo"\n'
2013
2014         result = CliRunner().invoke(
2015             black.main,
2016             ["-", "-q", "--target-version=py27"],
2017             input=BytesIO(source),
2018         )
2019
2020         self.assertEqual(result.exit_code, 0)
2021         actual = result.output
2022         self.assertFormatEqual(actual, expected)
2023
2024
# Lines of black's own module file, read once at import; `tracefunc` indexes them.
with open(black.__file__, encoding="utf-8") as _bf:
    black_source_lines = list(_bf)
2027
2028
def tracefunc(frame: types.FrameType, event: str, arg: Any) -> Callable:
    """Show function calls `from black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    For each "call" event whose code lives in `black/__init__.py`, prints the
    line number and the stripped source line of the function signature,
    indented by the current stack depth. Always returns itself so tracing
    continues.
    """
    if event != "call":
        return tracefunc

    filename = frame.f_code.co_filename
    if "black/__init__.py" not in filename:
        # Line numbers from other files do not index `black_source_lines`
        # (and may run past its end), so skip them before any lookup.
        return tracefunc

    # 19 is the baseline stack depth subtracted before indenting two columns
    # per level (NOTE(review): empirically chosen offset, confirm if needed).
    stack = len(inspect.stack()) - 19
    stack *= 2
    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1
    funcname = black_source_lines[func_sig_lineno].strip()
    # If the reported line is a decorator, walk down to the next
    # non-decorator line.
    while funcname.startswith("@"):
        func_sig_lineno += 1
        funcname = black_source_lines[func_sig_lineno].strip()
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc
2049
2050
if __name__ == "__main__":
    # Executed as a script: have unittest load and run module "test_black".
    unittest.main(module="test_black")