git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Accept empty stdin (close #2337) (#2346)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import multiprocessing
3 import asyncio
4 import logging
5 from concurrent.futures import ThreadPoolExecutor
6 from contextlib import contextmanager
7 from dataclasses import replace
8 import inspect
9 import io
10 from io import BytesIO
11 import os
12 from pathlib import Path
13 from platform import system
14 import regex as re
15 import sys
16 from tempfile import TemporaryDirectory
17 import types
18 from typing import (
19     Any,
20     Callable,
21     Dict,
22     List,
23     Iterator,
24     TypeVar,
25 )
26 import pytest
27 import unittest
28 from unittest.mock import patch, MagicMock
29
30 import click
31 from click import unstyle
32 from click.testing import CliRunner
33
34 import black
35 from black import Feature, TargetVersion
36 from black.cache import get_cache_file
37 from black.debug import DebugVisitor
38 from black.output import diff, color_diff
39 from black.report import Report
40 import black.files
41
42 from pathspec import PathSpec
43
44 # Import other test classes
45 from tests.util import (
46     THIS_DIR,
47     read_data,
48     DETERMINISTIC_HEADER,
49     BlackBaseTestCase,
50     DEFAULT_MODE,
51     fs,
52     ff,
53     dump_to_stderr,
54 )
55
56
THIS_FILE = Path(__file__)
PY36_VERSIONS = {
    TargetVersion.PY36,
    TargetVersion.PY37,
    TargetVersion.PY38,
    TargetVersion.PY39,
}
# Sorted so the generated CLI arguments come out in the same order on every run;
# iterating the set directly follows set iteration order, which is not
# guaranteed to be stable.
PY36_ARGS = sorted(
    f"--target-version={version.name.lower()}" for version in PY36_VERSIONS
)
T = TypeVar("T")
R = TypeVar("R")

# Match the time output in a diff, but nothing else. The "-" is escaped so the
# character class is unambiguous: a literal dash, not a range starting at \d
# (stdlib `re` rejects the unescaped form outright).
DIFF_TIME = re.compile(r"\t[\d\-:+\. ]+")
70
71
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Redirect black's cache to a throwaway directory for the enclosed block.

    With ``exists=False`` the yielded path is a ``new`` subdirectory of the
    temporary workspace that has not been created yet.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
80
81
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a brand-new asyncio event loop as the current loop for the block.

    The loop is created by the active event loop policy and is closed on exit,
    even when the body raises. NOTE(review): the closed loop stays registered
    as the current loop afterwards; it is not reset.
    """
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        fresh_loop.close()
92
93
class FakeContext(click.Context):
    """Minimal click Context for helpers that demand a context argument.

    ``click.Context.__init__`` is deliberately bypassed; the instance only
    carries an empty ``default_map``.
    """

    def __init__(self) -> None:
        self.default_map: Dict[str, Any] = dict()
99
100
class FakeParameter(click.Parameter):
    """Minimal click Parameter for helpers that demand a parameter argument."""

    def __init__(self) -> None:
        """Do nothing: ``click.Parameter.__init__`` is skipped on purpose."""
106
107
class BlackRunner(CliRunner):
    """CliRunner that captures Black's stdout and stderr as separate streams."""

    def __init__(self) -> None:
        # mix_stderr=False keeps result.stdout_bytes and result.stderr_bytes apart.
        super().__init__(mix_stderr=False)
113
114
115 class BlackTestCase(BlackBaseTestCase):
116     def invokeBlack(
117         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
118     ) -> None:
119         runner = BlackRunner()
120         if ignore_config:
121             args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
122         result = runner.invoke(black.main, args)
123         assert result.stdout_bytes is not None
124         assert result.stderr_bytes is not None
125         self.assertEqual(
126             result.exit_code,
127             exit_code,
128             msg=(
129                 f"Failed with args: {args}\n"
130                 f"stdout: {result.stdout_bytes.decode()!r}\n"
131                 f"stderr: {result.stderr_bytes.decode()!r}\n"
132                 f"exception: {result.exception}"
133             ),
134         )
135
136     @patch("black.dump_to_file", dump_to_stderr)
137     def test_empty(self) -> None:
138         source = expected = ""
139         actual = fs(source)
140         self.assertFormatEqual(expected, actual)
141         black.assert_equivalent(source, actual)
142         black.assert_stable(source, actual, DEFAULT_MODE)
143
144     def test_empty_ff(self) -> None:
145         expected = ""
146         tmp_file = Path(black.dump_to_file())
147         try:
148             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
149             with open(tmp_file, encoding="utf8") as f:
150                 actual = f.read()
151         finally:
152             os.unlink(tmp_file)
153         self.assertFormatEqual(expected, actual)
154
155     def test_piping(self) -> None:
156         source, expected = read_data("src/black/__init__", data=False)
157         result = BlackRunner().invoke(
158             black.main,
159             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
160             input=BytesIO(source.encode("utf8")),
161         )
162         self.assertEqual(result.exit_code, 0)
163         self.assertFormatEqual(expected, result.output)
164         if source != result.output:
165             black.assert_equivalent(source, result.output)
166             black.assert_stable(source, result.output, DEFAULT_MODE)
167
168     def test_piping_diff(self) -> None:
169         diff_header = re.compile(
170             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
171             r"\+\d\d\d\d"
172         )
173         source, _ = read_data("expression.py")
174         expected, _ = read_data("expression.diff")
175         config = THIS_DIR / "data" / "empty_pyproject.toml"
176         args = [
177             "-",
178             "--fast",
179             f"--line-length={black.DEFAULT_LINE_LENGTH}",
180             "--diff",
181             f"--config={config}",
182         ]
183         result = BlackRunner().invoke(
184             black.main, args, input=BytesIO(source.encode("utf8"))
185         )
186         self.assertEqual(result.exit_code, 0)
187         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
188         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
189         self.assertEqual(expected, actual)
190
191     def test_piping_diff_with_color(self) -> None:
192         source, _ = read_data("expression.py")
193         config = THIS_DIR / "data" / "empty_pyproject.toml"
194         args = [
195             "-",
196             "--fast",
197             f"--line-length={black.DEFAULT_LINE_LENGTH}",
198             "--diff",
199             "--color",
200             f"--config={config}",
201         ]
202         result = BlackRunner().invoke(
203             black.main, args, input=BytesIO(source.encode("utf8"))
204         )
205         actual = result.output
206         # Again, the contents are checked in a different test, so only look for colors.
207         self.assertIn("\033[1;37m", actual)
208         self.assertIn("\033[36m", actual)
209         self.assertIn("\033[32m", actual)
210         self.assertIn("\033[31m", actual)
211         self.assertIn("\033[0m", actual)
212
213     @patch("black.dump_to_file", dump_to_stderr)
214     def _test_wip(self) -> None:
215         source, expected = read_data("wip")
216         sys.settrace(tracefunc)
217         mode = replace(
218             DEFAULT_MODE,
219             experimental_string_processing=False,
220             target_versions={black.TargetVersion.PY38},
221         )
222         actual = fs(source, mode=mode)
223         sys.settrace(None)
224         self.assertFormatEqual(expected, actual)
225         black.assert_equivalent(source, actual)
226         black.assert_stable(source, actual, black.FileMode())
227
228     @unittest.expectedFailure
229     @patch("black.dump_to_file", dump_to_stderr)
230     def test_trailing_comma_optional_parens_stability1(self) -> None:
231         source, _expected = read_data("trailing_comma_optional_parens1")
232         actual = fs(source)
233         black.assert_stable(source, actual, DEFAULT_MODE)
234
235     @unittest.expectedFailure
236     @patch("black.dump_to_file", dump_to_stderr)
237     def test_trailing_comma_optional_parens_stability2(self) -> None:
238         source, _expected = read_data("trailing_comma_optional_parens2")
239         actual = fs(source)
240         black.assert_stable(source, actual, DEFAULT_MODE)
241
242     @unittest.expectedFailure
243     @patch("black.dump_to_file", dump_to_stderr)
244     def test_trailing_comma_optional_parens_stability3(self) -> None:
245         source, _expected = read_data("trailing_comma_optional_parens3")
246         actual = fs(source)
247         black.assert_stable(source, actual, DEFAULT_MODE)
248
249     @patch("black.dump_to_file", dump_to_stderr)
250     def test_trailing_comma_optional_parens_stability1_pass2(self) -> None:
251         source, _expected = read_data("trailing_comma_optional_parens1")
252         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
253         black.assert_stable(source, actual, DEFAULT_MODE)
254
255     @patch("black.dump_to_file", dump_to_stderr)
256     def test_trailing_comma_optional_parens_stability2_pass2(self) -> None:
257         source, _expected = read_data("trailing_comma_optional_parens2")
258         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
259         black.assert_stable(source, actual, DEFAULT_MODE)
260
261     @patch("black.dump_to_file", dump_to_stderr)
262     def test_trailing_comma_optional_parens_stability3_pass2(self) -> None:
263         source, _expected = read_data("trailing_comma_optional_parens3")
264         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
265         black.assert_stable(source, actual, DEFAULT_MODE)
266
267     @patch("black.dump_to_file", dump_to_stderr)
268     def test_pep_572(self) -> None:
269         source, expected = read_data("pep_572")
270         actual = fs(source)
271         self.assertFormatEqual(expected, actual)
272         black.assert_stable(source, actual, DEFAULT_MODE)
273         if sys.version_info >= (3, 8):
274             black.assert_equivalent(source, actual)
275
276     @patch("black.dump_to_file", dump_to_stderr)
277     def test_pep_572_remove_parens(self) -> None:
278         source, expected = read_data("pep_572_remove_parens")
279         actual = fs(source)
280         self.assertFormatEqual(expected, actual)
281         black.assert_stable(source, actual, DEFAULT_MODE)
282         if sys.version_info >= (3, 8):
283             black.assert_equivalent(source, actual)
284
285     @patch("black.dump_to_file", dump_to_stderr)
286     def test_pep_572_do_not_remove_parens(self) -> None:
287         source, expected = read_data("pep_572_do_not_remove_parens")
288         # the AST safety checks will fail, but that's expected, just make sure no
289         # parentheses are touched
290         actual = black.format_str(source, mode=DEFAULT_MODE)
291         self.assertFormatEqual(expected, actual)
292
293     def test_pep_572_version_detection(self) -> None:
294         source, _ = read_data("pep_572")
295         root = black.lib2to3_parse(source)
296         features = black.get_features_used(root)
297         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
298         versions = black.detect_target_versions(root)
299         self.assertIn(black.TargetVersion.PY38, versions)
300
301     def test_expression_ff(self) -> None:
302         source, expected = read_data("expression")
303         tmp_file = Path(black.dump_to_file(source))
304         try:
305             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
306             with open(tmp_file, encoding="utf8") as f:
307                 actual = f.read()
308         finally:
309             os.unlink(tmp_file)
310         self.assertFormatEqual(expected, actual)
311         with patch("black.dump_to_file", dump_to_stderr):
312             black.assert_equivalent(source, actual)
313             black.assert_stable(source, actual, DEFAULT_MODE)
314
315     def test_expression_diff(self) -> None:
316         source, _ = read_data("expression.py")
317         config = THIS_DIR / "data" / "empty_pyproject.toml"
318         expected, _ = read_data("expression.diff")
319         tmp_file = Path(black.dump_to_file(source))
320         diff_header = re.compile(
321             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
322             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
323         )
324         try:
325             result = BlackRunner().invoke(
326                 black.main, ["--diff", str(tmp_file), f"--config={config}"]
327             )
328             self.assertEqual(result.exit_code, 0)
329         finally:
330             os.unlink(tmp_file)
331         actual = result.output
332         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
333         if expected != actual:
334             dump = black.dump_to_file(actual)
335             msg = (
336                 "Expected diff isn't equal to the actual. If you made changes to"
337                 " expression.py and this is an anticipated difference, overwrite"
338                 f" tests/data/expression.diff with {dump}"
339             )
340             self.assertEqual(expected, actual, msg)
341
342     def test_expression_diff_with_color(self) -> None:
343         source, _ = read_data("expression.py")
344         config = THIS_DIR / "data" / "empty_pyproject.toml"
345         expected, _ = read_data("expression.diff")
346         tmp_file = Path(black.dump_to_file(source))
347         try:
348             result = BlackRunner().invoke(
349                 black.main, ["--diff", "--color", str(tmp_file), f"--config={config}"]
350             )
351         finally:
352             os.unlink(tmp_file)
353         actual = result.output
354         # We check the contents of the diff in `test_expression_diff`. All
355         # we need to check here is that color codes exist in the result.
356         self.assertIn("\033[1;37m", actual)
357         self.assertIn("\033[36m", actual)
358         self.assertIn("\033[32m", actual)
359         self.assertIn("\033[31m", actual)
360         self.assertIn("\033[0m", actual)
361
362     @patch("black.dump_to_file", dump_to_stderr)
363     def test_pep_570(self) -> None:
364         source, expected = read_data("pep_570")
365         actual = fs(source)
366         self.assertFormatEqual(expected, actual)
367         black.assert_stable(source, actual, DEFAULT_MODE)
368         if sys.version_info >= (3, 8):
369             black.assert_equivalent(source, actual)
370
371     def test_detect_pos_only_arguments(self) -> None:
372         source, _ = read_data("pep_570")
373         root = black.lib2to3_parse(source)
374         features = black.get_features_used(root)
375         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
376         versions = black.detect_target_versions(root)
377         self.assertIn(black.TargetVersion.PY38, versions)
378
379     @patch("black.dump_to_file", dump_to_stderr)
380     def test_string_quotes(self) -> None:
381         source, expected = read_data("string_quotes")
382         mode = black.Mode(experimental_string_processing=True)
383         actual = fs(source, mode=mode)
384         self.assertFormatEqual(expected, actual)
385         black.assert_equivalent(source, actual)
386         black.assert_stable(source, actual, mode)
387         mode = replace(mode, string_normalization=False)
388         not_normalized = fs(source, mode=mode)
389         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
390         black.assert_equivalent(source, not_normalized)
391         black.assert_stable(source, not_normalized, mode=mode)
392
393     @patch("black.dump_to_file", dump_to_stderr)
394     def test_docstring_no_string_normalization(self) -> None:
395         """Like test_docstring but with string normalization off."""
396         source, expected = read_data("docstring_no_string_normalization")
397         mode = replace(DEFAULT_MODE, string_normalization=False)
398         actual = fs(source, mode=mode)
399         self.assertFormatEqual(expected, actual)
400         black.assert_equivalent(source, actual)
401         black.assert_stable(source, actual, mode)
402
403     def test_long_strings_flag_disabled(self) -> None:
404         """Tests for turning off the string processing logic."""
405         source, expected = read_data("long_strings_flag_disabled")
406         mode = replace(DEFAULT_MODE, experimental_string_processing=False)
407         actual = fs(source, mode=mode)
408         self.assertFormatEqual(expected, actual)
409         black.assert_stable(expected, actual, mode)
410
411     @patch("black.dump_to_file", dump_to_stderr)
412     def test_numeric_literals(self) -> None:
413         source, expected = read_data("numeric_literals")
414         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
415         actual = fs(source, mode=mode)
416         self.assertFormatEqual(expected, actual)
417         black.assert_equivalent(source, actual)
418         black.assert_stable(source, actual, mode)
419
420     @patch("black.dump_to_file", dump_to_stderr)
421     def test_numeric_literals_ignoring_underscores(self) -> None:
422         source, expected = read_data("numeric_literals_skip_underscores")
423         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
424         actual = fs(source, mode=mode)
425         self.assertFormatEqual(expected, actual)
426         black.assert_equivalent(source, actual)
427         black.assert_stable(source, actual, mode)
428
429     def test_skip_magic_trailing_comma(self) -> None:
430         source, _ = read_data("expression.py")
431         expected, _ = read_data("expression_skip_magic_trailing_comma.diff")
432         tmp_file = Path(black.dump_to_file(source))
433         diff_header = re.compile(
434             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
435             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
436         )
437         try:
438             result = BlackRunner().invoke(black.main, ["-C", "--diff", str(tmp_file)])
439             self.assertEqual(result.exit_code, 0)
440         finally:
441             os.unlink(tmp_file)
442         actual = result.output
443         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
444         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
445         if expected != actual:
446             dump = black.dump_to_file(actual)
447             msg = (
448                 "Expected diff isn't equal to the actual. If you made changes to"
449                 " expression.py and this is an anticipated difference, overwrite"
450                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
451             )
452             self.assertEqual(expected, actual, msg)
453
454     @pytest.mark.no_python2
455     def test_python2_should_fail_without_optional_install(self) -> None:
456         if sys.version_info < (3, 8):
457             self.skipTest(
458                 "Python 3.6 and 3.7 will install typed-ast to work and as such will be"
459                 " able to parse Python 2 syntax without explicitly specifying the"
460                 " python2 extra"
461             )
462
463         source = "x = 1234l"
464         tmp_file = Path(black.dump_to_file(source))
465         try:
466             runner = BlackRunner()
467             result = runner.invoke(black.main, [str(tmp_file)])
468             self.assertEqual(result.exit_code, 123)
469         finally:
470             os.unlink(tmp_file)
471         assert result.stderr_bytes is not None
472         actual = (
473             result.stderr_bytes.decode()
474             .replace("\n", "")
475             .replace("\\n", "")
476             .replace("\\r", "")
477             .replace("\r", "")
478         )
479         msg = (
480             "The requested source code has invalid Python 3 syntax."
481             "If you are trying to format Python 2 files please reinstall Black"
482             " with the 'python2' extra: `python3 -m pip install black[python2]`."
483         )
484         self.assertIn(msg, actual)
485
486     @pytest.mark.python2
487     @patch("black.dump_to_file", dump_to_stderr)
488     def test_python2_print_function(self) -> None:
489         source, expected = read_data("python2_print_function")
490         mode = replace(DEFAULT_MODE, target_versions={TargetVersion.PY27})
491         actual = fs(source, mode=mode)
492         self.assertFormatEqual(expected, actual)
493         black.assert_equivalent(source, actual)
494         black.assert_stable(source, actual, mode)
495
496     @patch("black.dump_to_file", dump_to_stderr)
497     def test_stub(self) -> None:
498         mode = replace(DEFAULT_MODE, is_pyi=True)
499         source, expected = read_data("stub.pyi")
500         actual = fs(source, mode=mode)
501         self.assertFormatEqual(expected, actual)
502         black.assert_stable(source, actual, mode)
503
504     @patch("black.dump_to_file", dump_to_stderr)
505     def test_async_as_identifier(self) -> None:
506         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
507         source, expected = read_data("async_as_identifier")
508         actual = fs(source)
509         self.assertFormatEqual(expected, actual)
510         major, minor = sys.version_info[:2]
511         if major < 3 or (major <= 3 and minor < 7):
512             black.assert_equivalent(source, actual)
513         black.assert_stable(source, actual, DEFAULT_MODE)
514         # ensure black can parse this when the target is 3.6
515         self.invokeBlack([str(source_path), "--target-version", "py36"])
516         # but not on 3.7, because async/await is no longer an identifier
517         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
518
519     @patch("black.dump_to_file", dump_to_stderr)
520     def test_python37(self) -> None:
521         source_path = (THIS_DIR / "data" / "python37.py").resolve()
522         source, expected = read_data("python37")
523         actual = fs(source)
524         self.assertFormatEqual(expected, actual)
525         major, minor = sys.version_info[:2]
526         if major > 3 or (major == 3 and minor >= 7):
527             black.assert_equivalent(source, actual)
528         black.assert_stable(source, actual, DEFAULT_MODE)
529         # ensure black can parse this when the target is 3.7
530         self.invokeBlack([str(source_path), "--target-version", "py37"])
531         # but not on 3.6, because we use async as a reserved keyword
532         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
533
534     @patch("black.dump_to_file", dump_to_stderr)
535     def test_python38(self) -> None:
536         source, expected = read_data("python38")
537         actual = fs(source)
538         self.assertFormatEqual(expected, actual)
539         major, minor = sys.version_info[:2]
540         if major > 3 or (major == 3 and minor >= 8):
541             black.assert_equivalent(source, actual)
542         black.assert_stable(source, actual, DEFAULT_MODE)
543
544     @patch("black.dump_to_file", dump_to_stderr)
545     def test_python39(self) -> None:
546         source, expected = read_data("python39")
547         actual = fs(source)
548         self.assertFormatEqual(expected, actual)
549         major, minor = sys.version_info[:2]
550         if major > 3 or (major == 3 and minor >= 9):
551             black.assert_equivalent(source, actual)
552         black.assert_stable(source, actual, DEFAULT_MODE)
553
554     def test_tab_comment_indentation(self) -> None:
555         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
556         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
557         self.assertFormatEqual(contents_spc, fs(contents_spc))
558         self.assertFormatEqual(contents_spc, fs(contents_tab))
559
560         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
561         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
562         self.assertFormatEqual(contents_spc, fs(contents_spc))
563         self.assertFormatEqual(contents_spc, fs(contents_tab))
564
565         # mixed tabs and spaces (valid Python 2 code)
566         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
567         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
568         self.assertFormatEqual(contents_spc, fs(contents_spc))
569         self.assertFormatEqual(contents_spc, fs(contents_tab))
570
571         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
572         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
573         self.assertFormatEqual(contents_spc, fs(contents_spc))
574         self.assertFormatEqual(contents_spc, fs(contents_tab))
575
576     def test_report_verbose(self) -> None:
577         report = Report(verbose=True)
578         out_lines = []
579         err_lines = []
580
581         def out(msg: str, **kwargs: Any) -> None:
582             out_lines.append(msg)
583
584         def err(msg: str, **kwargs: Any) -> None:
585             err_lines.append(msg)
586
587         with patch("black.output._out", out), patch("black.output._err", err):
588             report.done(Path("f1"), black.Changed.NO)
589             self.assertEqual(len(out_lines), 1)
590             self.assertEqual(len(err_lines), 0)
591             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
592             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
593             self.assertEqual(report.return_code, 0)
594             report.done(Path("f2"), black.Changed.YES)
595             self.assertEqual(len(out_lines), 2)
596             self.assertEqual(len(err_lines), 0)
597             self.assertEqual(out_lines[-1], "reformatted f2")
598             self.assertEqual(
599                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
600             )
601             report.done(Path("f3"), black.Changed.CACHED)
602             self.assertEqual(len(out_lines), 3)
603             self.assertEqual(len(err_lines), 0)
604             self.assertEqual(
605                 out_lines[-1], "f3 wasn't modified on disk since last run."
606             )
607             self.assertEqual(
608                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
609             )
610             self.assertEqual(report.return_code, 0)
611             report.check = True
612             self.assertEqual(report.return_code, 1)
613             report.check = False
614             report.failed(Path("e1"), "boom")
615             self.assertEqual(len(out_lines), 3)
616             self.assertEqual(len(err_lines), 1)
617             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
618             self.assertEqual(
619                 unstyle(str(report)),
620                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
621                 " reformat.",
622             )
623             self.assertEqual(report.return_code, 123)
624             report.done(Path("f3"), black.Changed.YES)
625             self.assertEqual(len(out_lines), 4)
626             self.assertEqual(len(err_lines), 1)
627             self.assertEqual(out_lines[-1], "reformatted f3")
628             self.assertEqual(
629                 unstyle(str(report)),
630                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
631                 " reformat.",
632             )
633             self.assertEqual(report.return_code, 123)
634             report.failed(Path("e2"), "boom")
635             self.assertEqual(len(out_lines), 4)
636             self.assertEqual(len(err_lines), 2)
637             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
638             self.assertEqual(
639                 unstyle(str(report)),
640                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
641                 " reformat.",
642             )
643             self.assertEqual(report.return_code, 123)
644             report.path_ignored(Path("wat"), "no match")
645             self.assertEqual(len(out_lines), 5)
646             self.assertEqual(len(err_lines), 2)
647             self.assertEqual(out_lines[-1], "wat ignored: no match")
648             self.assertEqual(
649                 unstyle(str(report)),
650                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
651                 " reformat.",
652             )
653             self.assertEqual(report.return_code, 123)
654             report.done(Path("f4"), black.Changed.NO)
655             self.assertEqual(len(out_lines), 6)
656             self.assertEqual(len(err_lines), 2)
657             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
658             self.assertEqual(
659                 unstyle(str(report)),
660                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
661                 " reformat.",
662             )
663             self.assertEqual(report.return_code, 123)
664             report.check = True
665             self.assertEqual(
666                 unstyle(str(report)),
667                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
668                 " would fail to reformat.",
669             )
670             report.check = False
671             report.diff = True
672             self.assertEqual(
673                 unstyle(str(report)),
674                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
675                 " would fail to reformat.",
676             )
677
    def test_report_quiet(self) -> None:
        """Check Report's counters, summary text and return code in quiet mode.

        With ``quiet=True`` nothing is emitted through ``black.output._out``;
        only failures are emitted through ``black.output._err``.
        """
        report = Report(quiet=True)
        # Messages captured from the patched black.output._out / _err.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            # Unchanged files are counted silently.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # Reformatted files are counted but not announced in quiet mode.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # Cached files count as "left unchanged".
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled, a reformatted file makes the return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures are still written to the error stream and yield 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Report counts calls, not distinct paths: f3 is now also counted
            # as reformatted while the "left unchanged" count stays at 2.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths change neither the counters nor the output.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff modes switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
771
    def test_report_normal(self) -> None:
        """Check Report's counters, messages and return code in default mode.

        Reformatted files are announced through ``black.output._out`` as
        ``reformatted <path>``; unchanged and cached files produce no output.
        Failures are written through ``black.output._err``.
        """
        report = black.Report()
        # Messages captured from the patched black.output._out / _err.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            # Unchanged files are counted silently.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # Reformatted files are announced on the output stream.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # Cached files count as "left unchanged" and print nothing.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled, a reformatted file makes the return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures go to the error stream and yield return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Report counts calls, not distinct paths: f3 is now also counted
            # as reformatted while the "left unchanged" count stays at 2.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths change neither the counters nor the output.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff modes switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
868
869     def test_lib2to3_parse(self) -> None:
870         with self.assertRaises(black.InvalidInput):
871             black.lib2to3_parse("invalid syntax")
872
873         straddling = "x + y"
874         black.lib2to3_parse(straddling)
875         black.lib2to3_parse(straddling, {TargetVersion.PY27})
876         black.lib2to3_parse(straddling, {TargetVersion.PY36})
877         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
878
879         py2_only = "print x"
880         black.lib2to3_parse(py2_only)
881         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
882         with self.assertRaises(black.InvalidInput):
883             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
884         with self.assertRaises(black.InvalidInput):
885             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
886
887         py3_only = "exec(x, end=y)"
888         black.lib2to3_parse(py3_only)
889         with self.assertRaises(black.InvalidInput):
890             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
891         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
892         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
893
894     def test_get_features_used_decorator(self) -> None:
895         # Test the feature detection of new decorator syntax
896         # since this makes some test cases of test_get_features_used()
897         # fails if it fails, this is tested first so that a useful case
898         # is identified
899         simples, relaxed = read_data("decorators")
900         # skip explanation comments at the top of the file
901         for simple_test in simples.split("##")[1:]:
902             node = black.lib2to3_parse(simple_test)
903             decorator = str(node.children[0].children[0]).strip()
904             self.assertNotIn(
905                 Feature.RELAXED_DECORATORS,
906                 black.get_features_used(node),
907                 msg=(
908                     f"decorator '{decorator}' follows python<=3.8 syntax"
909                     "but is detected as 3.9+"
910                     # f"The full node is\n{node!r}"
911                 ),
912             )
913         # skip the '# output' comment at the top of the output part
914         for relaxed_test in relaxed.split("##")[1:]:
915             node = black.lib2to3_parse(relaxed_test)
916             decorator = str(node.children[0].children[0]).strip()
917             self.assertIn(
918                 Feature.RELAXED_DECORATORS,
919                 black.get_features_used(node),
920                 msg=(
921                     f"decorator '{decorator}' uses python3.9+ syntax"
922                     "but is detected as python<=3.8"
923                     # f"The full node is\n{node!r}"
924                 ),
925             )
926
927     def test_get_features_used(self) -> None:
928         node = black.lib2to3_parse("def f(*, arg): ...\n")
929         self.assertEqual(black.get_features_used(node), set())
930         node = black.lib2to3_parse("def f(*, arg,): ...\n")
931         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
932         node = black.lib2to3_parse("f(*arg,)\n")
933         self.assertEqual(
934             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
935         )
936         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
937         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
938         node = black.lib2to3_parse("123_456\n")
939         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
940         node = black.lib2to3_parse("123456\n")
941         self.assertEqual(black.get_features_used(node), set())
942         source, expected = read_data("function")
943         node = black.lib2to3_parse(source)
944         expected_features = {
945             Feature.TRAILING_COMMA_IN_CALL,
946             Feature.TRAILING_COMMA_IN_DEF,
947             Feature.F_STRINGS,
948         }
949         self.assertEqual(black.get_features_used(node), expected_features)
950         node = black.lib2to3_parse(expected)
951         self.assertEqual(black.get_features_used(node), expected_features)
952         source, expected = read_data("expression")
953         node = black.lib2to3_parse(source)
954         self.assertEqual(black.get_features_used(node), set())
955         node = black.lib2to3_parse(expected)
956         self.assertEqual(black.get_features_used(node), set())
957
958     def test_get_future_imports(self) -> None:
959         node = black.lib2to3_parse("\n")
960         self.assertEqual(set(), black.get_future_imports(node))
961         node = black.lib2to3_parse("from __future__ import black\n")
962         self.assertEqual({"black"}, black.get_future_imports(node))
963         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
964         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
965         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
966         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
967         node = black.lib2to3_parse(
968             "from __future__ import multiple\nfrom __future__ import imports\n"
969         )
970         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
971         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
972         self.assertEqual({"black"}, black.get_future_imports(node))
973         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
974         self.assertEqual({"black"}, black.get_future_imports(node))
975         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
976         self.assertEqual(set(), black.get_future_imports(node))
977         node = black.lib2to3_parse("from some.module import black\n")
978         self.assertEqual(set(), black.get_future_imports(node))
979         node = black.lib2to3_parse(
980             "from __future__ import unicode_literals as _unicode_literals"
981         )
982         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
983         node = black.lib2to3_parse(
984             "from __future__ import unicode_literals as _lol, print"
985         )
986         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
987
988     def test_debug_visitor(self) -> None:
989         source, _ = read_data("debug_visitor.py")
990         expected, _ = read_data("debug_visitor.out")
991         out_lines = []
992         err_lines = []
993
994         def out(msg: str, **kwargs: Any) -> None:
995             out_lines.append(msg)
996
997         def err(msg: str, **kwargs: Any) -> None:
998             err_lines.append(msg)
999
1000         with patch("black.debug.out", out):
1001             DebugVisitor.show(source)
1002         actual = "\n".join(out_lines) + "\n"
1003         log_name = ""
1004         if expected != actual:
1005             log_name = black.dump_to_file(*out_lines)
1006         self.assertEqual(
1007             expected,
1008             actual,
1009             f"AST print out is different. Actual version dumped to {log_name}",
1010         )
1011
1012     def test_format_file_contents(self) -> None:
1013         empty = ""
1014         mode = DEFAULT_MODE
1015         with self.assertRaises(black.NothingChanged):
1016             black.format_file_contents(empty, mode=mode, fast=False)
1017         just_nl = "\n"
1018         with self.assertRaises(black.NothingChanged):
1019             black.format_file_contents(just_nl, mode=mode, fast=False)
1020         same = "j = [1, 2, 3]\n"
1021         with self.assertRaises(black.NothingChanged):
1022             black.format_file_contents(same, mode=mode, fast=False)
1023         different = "j = [1,2,3]"
1024         expected = same
1025         actual = black.format_file_contents(different, mode=mode, fast=False)
1026         self.assertEqual(expected, actual)
1027         invalid = "return if you can"
1028         with self.assertRaises(black.InvalidInput) as e:
1029             black.format_file_contents(invalid, mode=mode, fast=False)
1030         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1031
1032     def test_endmarker(self) -> None:
1033         n = black.lib2to3_parse("\n")
1034         self.assertEqual(n.type, black.syms.file_input)
1035         self.assertEqual(len(n.children), 1)
1036         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1037
1038     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1039     def test_assertFormatEqual(self) -> None:
1040         out_lines = []
1041         err_lines = []
1042
1043         def out(msg: str, **kwargs: Any) -> None:
1044             out_lines.append(msg)
1045
1046         def err(msg: str, **kwargs: Any) -> None:
1047             err_lines.append(msg)
1048
1049         with patch("black.output._out", out), patch("black.output._err", err):
1050             with self.assertRaises(AssertionError):
1051                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1052
1053         out_str = "".join(out_lines)
1054         self.assertTrue("Expected tree:" in out_str)
1055         self.assertTrue("Actual tree:" in out_str)
1056         self.assertEqual("".join(err_lines), "")
1057
1058     def test_cache_broken_file(self) -> None:
1059         mode = DEFAULT_MODE
1060         with cache_dir() as workspace:
1061             cache_file = get_cache_file(mode)
1062             with cache_file.open("w") as fobj:
1063                 fobj.write("this is not a pickle")
1064             self.assertEqual(black.read_cache(mode), {})
1065             src = (workspace / "test.py").resolve()
1066             with src.open("w") as fobj:
1067                 fobj.write("print('hello')")
1068             self.invokeBlack([str(src)])
1069             cache = black.read_cache(mode)
1070             self.assertIn(str(src), cache)
1071
1072     def test_cache_single_file_already_cached(self) -> None:
1073         mode = DEFAULT_MODE
1074         with cache_dir() as workspace:
1075             src = (workspace / "test.py").resolve()
1076             with src.open("w") as fobj:
1077                 fobj.write("print('hello')")
1078             black.write_cache({}, [src], mode)
1079             self.invokeBlack([str(src)])
1080             with src.open("r") as fobj:
1081                 self.assertEqual(fobj.read(), "print('hello')")
1082
1083     @event_loop()
1084     def test_cache_multiple_files(self) -> None:
1085         mode = DEFAULT_MODE
1086         with cache_dir() as workspace, patch(
1087             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1088         ):
1089             one = (workspace / "one.py").resolve()
1090             with one.open("w") as fobj:
1091                 fobj.write("print('hello')")
1092             two = (workspace / "two.py").resolve()
1093             with two.open("w") as fobj:
1094                 fobj.write("print('hello')")
1095             black.write_cache({}, [one], mode)
1096             self.invokeBlack([str(workspace)])
1097             with one.open("r") as fobj:
1098                 self.assertEqual(fobj.read(), "print('hello')")
1099             with two.open("r") as fobj:
1100                 self.assertEqual(fobj.read(), 'print("hello")\n')
1101             cache = black.read_cache(mode)
1102             self.assertIn(str(one), cache)
1103             self.assertIn(str(two), cache)
1104
1105     def test_no_cache_when_writeback_diff(self) -> None:
1106         mode = DEFAULT_MODE
1107         with cache_dir() as workspace:
1108             src = (workspace / "test.py").resolve()
1109             with src.open("w") as fobj:
1110                 fobj.write("print('hello')")
1111             with patch("black.read_cache") as read_cache, patch(
1112                 "black.write_cache"
1113             ) as write_cache:
1114                 self.invokeBlack([str(src), "--diff"])
1115                 cache_file = get_cache_file(mode)
1116                 self.assertFalse(cache_file.exists())
1117                 write_cache.assert_not_called()
1118                 read_cache.assert_not_called()
1119
1120     def test_no_cache_when_writeback_color_diff(self) -> None:
1121         mode = DEFAULT_MODE
1122         with cache_dir() as workspace:
1123             src = (workspace / "test.py").resolve()
1124             with src.open("w") as fobj:
1125                 fobj.write("print('hello')")
1126             with patch("black.read_cache") as read_cache, patch(
1127                 "black.write_cache"
1128             ) as write_cache:
1129                 self.invokeBlack([str(src), "--diff", "--color"])
1130                 cache_file = get_cache_file(mode)
1131                 self.assertFalse(cache_file.exists())
1132                 write_cache.assert_not_called()
1133                 read_cache.assert_not_called()
1134
1135     @event_loop()
1136     def test_output_locking_when_writeback_diff(self) -> None:
1137         with cache_dir() as workspace:
1138             for tag in range(0, 4):
1139                 src = (workspace / f"test{tag}.py").resolve()
1140                 with src.open("w") as fobj:
1141                     fobj.write("print('hello')")
1142             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1143                 self.invokeBlack(["--diff", str(workspace)], exit_code=0)
1144                 # this isn't quite doing what we want, but if it _isn't_
1145                 # called then we cannot be using the lock it provides
1146                 mgr.assert_called()
1147
1148     @event_loop()
1149     def test_output_locking_when_writeback_color_diff(self) -> None:
1150         with cache_dir() as workspace:
1151             for tag in range(0, 4):
1152                 src = (workspace / f"test{tag}.py").resolve()
1153                 with src.open("w") as fobj:
1154                     fobj.write("print('hello')")
1155             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1156                 self.invokeBlack(["--diff", "--color", str(workspace)], exit_code=0)
1157                 # this isn't quite doing what we want, but if it _isn't_
1158                 # called then we cannot be using the lock it provides
1159                 mgr.assert_called()
1160
1161     def test_no_cache_when_stdin(self) -> None:
1162         mode = DEFAULT_MODE
1163         with cache_dir():
1164             result = CliRunner().invoke(
1165                 black.main, ["-"], input=BytesIO(b"print('hello')")
1166             )
1167             self.assertEqual(result.exit_code, 0)
1168             cache_file = get_cache_file(mode)
1169             self.assertFalse(cache_file.exists())
1170
1171     def test_read_cache_no_cachefile(self) -> None:
1172         mode = DEFAULT_MODE
1173         with cache_dir():
1174             self.assertEqual(black.read_cache(mode), {})
1175
1176     def test_write_cache_read_cache(self) -> None:
1177         mode = DEFAULT_MODE
1178         with cache_dir() as workspace:
1179             src = (workspace / "test.py").resolve()
1180             src.touch()
1181             black.write_cache({}, [src], mode)
1182             cache = black.read_cache(mode)
1183             self.assertIn(str(src), cache)
1184             self.assertEqual(cache[str(src)], black.get_cache_info(src))
1185
1186     def test_filter_cached(self) -> None:
1187         with TemporaryDirectory() as workspace:
1188             path = Path(workspace)
1189             uncached = (path / "uncached").resolve()
1190             cached = (path / "cached").resolve()
1191             cached_but_changed = (path / "changed").resolve()
1192             uncached.touch()
1193             cached.touch()
1194             cached_but_changed.touch()
1195             cache = {
1196                 str(cached): black.get_cache_info(cached),
1197                 str(cached_but_changed): (0.0, 0),
1198             }
1199             todo, done = black.filter_cached(
1200                 cache, {uncached, cached, cached_but_changed}
1201             )
1202             self.assertEqual(todo, {uncached, cached_but_changed})
1203             self.assertEqual(done, {cached})
1204
1205     def test_write_cache_creates_directory_if_needed(self) -> None:
1206         mode = DEFAULT_MODE
1207         with cache_dir(exists=False) as workspace:
1208             self.assertFalse(workspace.exists())
1209             black.write_cache({}, [], mode)
1210             self.assertTrue(workspace.exists())
1211
1212     @event_loop()
1213     def test_failed_formatting_does_not_get_cached(self) -> None:
1214         mode = DEFAULT_MODE
1215         with cache_dir() as workspace, patch(
1216             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1217         ):
1218             failing = (workspace / "failing.py").resolve()
1219             with failing.open("w") as fobj:
1220                 fobj.write("not actually python")
1221             clean = (workspace / "clean.py").resolve()
1222             with clean.open("w") as fobj:
1223                 fobj.write('print("hello")\n')
1224             self.invokeBlack([str(workspace)], exit_code=123)
1225             cache = black.read_cache(mode)
1226             self.assertNotIn(str(failing), cache)
1227             self.assertIn(str(clean), cache)
1228
1229     def test_write_cache_write_fail(self) -> None:
1230         mode = DEFAULT_MODE
1231         with cache_dir(), patch.object(Path, "open") as mock:
1232             mock.side_effect = OSError
1233             black.write_cache({}, [], mode)
1234
1235     @event_loop()
1236     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
1237     def test_works_in_mono_process_only_environment(self) -> None:
1238         with cache_dir() as workspace:
1239             for f in [
1240                 (workspace / "one.py").resolve(),
1241                 (workspace / "two.py").resolve(),
1242             ]:
1243                 f.write_text('print("hello")\n')
1244             self.invokeBlack([str(workspace)])
1245
1246     @event_loop()
1247     def test_check_diff_use_together(self) -> None:
1248         with cache_dir():
1249             # Files which will be reformatted.
1250             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1251             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1252             # Files which will not be reformatted.
1253             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1254             self.invokeBlack([str(src2), "--diff", "--check"])
1255             # Multi file command.
1256             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1257
1258     def test_no_files(self) -> None:
1259         with cache_dir():
1260             # Without an argument, black exits with error code 0.
1261             self.invokeBlack([])
1262
1263     def test_broken_symlink(self) -> None:
1264         with cache_dir() as workspace:
1265             symlink = workspace / "broken_link.py"
1266             try:
1267                 symlink.symlink_to("nonexistent.py")
1268             except OSError as e:
1269                 self.skipTest(f"Can't create symlinks: {e}")
1270             self.invokeBlack([str(workspace.resolve())])
1271
1272     def test_read_cache_line_lengths(self) -> None:
1273         mode = DEFAULT_MODE
1274         short_mode = replace(DEFAULT_MODE, line_length=1)
1275         with cache_dir() as workspace:
1276             path = (workspace / "file.py").resolve()
1277             path.touch()
1278             black.write_cache({}, [path], mode)
1279             one = black.read_cache(mode)
1280             self.assertIn(str(path), one)
1281             two = black.read_cache(short_mode)
1282             self.assertNotIn(str(path), two)
1283
1284     def test_single_file_force_pyi(self) -> None:
1285         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1286         contents, expected = read_data("force_pyi")
1287         with cache_dir() as workspace:
1288             path = (workspace / "file.py").resolve()
1289             with open(path, "w") as fh:
1290                 fh.write(contents)
1291             self.invokeBlack([str(path), "--pyi"])
1292             with open(path, "r") as fh:
1293                 actual = fh.read()
1294             # verify cache with --pyi is separate
1295             pyi_cache = black.read_cache(pyi_mode)
1296             self.assertIn(str(path), pyi_cache)
1297             normal_cache = black.read_cache(DEFAULT_MODE)
1298             self.assertNotIn(str(path), normal_cache)
1299         self.assertFormatEqual(expected, actual)
1300         black.assert_equivalent(contents, actual)
1301         black.assert_stable(contents, actual, pyi_mode)
1302
1303     @event_loop()
1304     def test_multi_file_force_pyi(self) -> None:
1305         reg_mode = DEFAULT_MODE
1306         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1307         contents, expected = read_data("force_pyi")
1308         with cache_dir() as workspace:
1309             paths = [
1310                 (workspace / "file1.py").resolve(),
1311                 (workspace / "file2.py").resolve(),
1312             ]
1313             for path in paths:
1314                 with open(path, "w") as fh:
1315                     fh.write(contents)
1316             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1317             for path in paths:
1318                 with open(path, "r") as fh:
1319                     actual = fh.read()
1320                 self.assertEqual(actual, expected)
1321             # verify cache with --pyi is separate
1322             pyi_cache = black.read_cache(pyi_mode)
1323             normal_cache = black.read_cache(reg_mode)
1324             for path in paths:
1325                 self.assertIn(str(path), pyi_cache)
1326                 self.assertNotIn(str(path), normal_cache)
1327
1328     def test_pipe_force_pyi(self) -> None:
1329         source, expected = read_data("force_pyi")
1330         result = CliRunner().invoke(
1331             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1332         )
1333         self.assertEqual(result.exit_code, 0)
1334         actual = result.output
1335         self.assertFormatEqual(actual, expected)
1336
1337     def test_single_file_force_py36(self) -> None:
1338         reg_mode = DEFAULT_MODE
1339         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1340         source, expected = read_data("force_py36")
1341         with cache_dir() as workspace:
1342             path = (workspace / "file.py").resolve()
1343             with open(path, "w") as fh:
1344                 fh.write(source)
1345             self.invokeBlack([str(path), *PY36_ARGS])
1346             with open(path, "r") as fh:
1347                 actual = fh.read()
1348             # verify cache with --target-version is separate
1349             py36_cache = black.read_cache(py36_mode)
1350             self.assertIn(str(path), py36_cache)
1351             normal_cache = black.read_cache(reg_mode)
1352             self.assertNotIn(str(path), normal_cache)
1353         self.assertEqual(actual, expected)
1354
1355     @event_loop()
1356     def test_multi_file_force_py36(self) -> None:
1357         reg_mode = DEFAULT_MODE
1358         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1359         source, expected = read_data("force_py36")
1360         with cache_dir() as workspace:
1361             paths = [
1362                 (workspace / "file1.py").resolve(),
1363                 (workspace / "file2.py").resolve(),
1364             ]
1365             for path in paths:
1366                 with open(path, "w") as fh:
1367                     fh.write(source)
1368             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1369             for path in paths:
1370                 with open(path, "r") as fh:
1371                     actual = fh.read()
1372                 self.assertEqual(actual, expected)
1373             # verify cache with --target-version is separate
1374             pyi_cache = black.read_cache(py36_mode)
1375             normal_cache = black.read_cache(reg_mode)
1376             for path in paths:
1377                 self.assertIn(str(path), pyi_cache)
1378                 self.assertNotIn(str(path), normal_cache)
1379
1380     def test_pipe_force_py36(self) -> None:
1381         source, expected = read_data("force_py36")
1382         result = CliRunner().invoke(
1383             black.main,
1384             ["-", "-q", "--target-version=py36"],
1385             input=BytesIO(source.encode("utf8")),
1386         )
1387         self.assertEqual(result.exit_code, 0)
1388         actual = result.output
1389         self.assertFormatEqual(actual, expected)
1390
1391     def test_include_exclude(self) -> None:
1392         path = THIS_DIR / "data" / "include_exclude_tests"
1393         include = re.compile(r"\.pyi?$")
1394         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1395         report = black.Report()
1396         gitignore = PathSpec.from_lines("gitwildmatch", [])
1397         sources: List[Path] = []
1398         expected = [
1399             Path(path / "b/dont_exclude/a.py"),
1400             Path(path / "b/dont_exclude/a.pyi"),
1401         ]
1402         this_abs = THIS_DIR.resolve()
1403         sources.extend(
1404             black.gen_python_files(
1405                 path.iterdir(),
1406                 this_abs,
1407                 include,
1408                 exclude,
1409                 None,
1410                 None,
1411                 report,
1412                 gitignore,
1413             )
1414         )
1415         self.assertEqual(sorted(expected), sorted(sources))
1416
1417     def test_gitignore_used_as_default(self) -> None:
1418         path = Path(THIS_DIR / "data" / "include_exclude_tests")
1419         include = re.compile(r"\.pyi?$")
1420         extend_exclude = re.compile(r"/exclude/")
1421         src = str(path / "b/")
1422         report = black.Report()
1423         expected: List[Path] = [
1424             path / "b/.definitely_exclude/a.py",
1425             path / "b/.definitely_exclude/a.pyi",
1426         ]
1427         sources = list(
1428             black.get_sources(
1429                 ctx=FakeContext(),
1430                 src=(src,),
1431                 quiet=True,
1432                 verbose=False,
1433                 include=include,
1434                 exclude=None,
1435                 extend_exclude=extend_exclude,
1436                 force_exclude=None,
1437                 report=report,
1438                 stdin_filename=None,
1439             )
1440         )
1441         self.assertEqual(sorted(expected), sorted(sources))
1442
1443     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1444     def test_exclude_for_issue_1572(self) -> None:
1445         # Exclude shouldn't touch files that were explicitly given to Black through the
1446         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
1447         # https://github.com/psf/black/issues/1572
1448         path = THIS_DIR / "data" / "include_exclude_tests"
1449         include = ""
1450         exclude = r"/exclude/|a\.py"
1451         src = str(path / "b/exclude/a.py")
1452         report = black.Report()
1453         expected = [Path(path / "b/exclude/a.py")]
1454         sources = list(
1455             black.get_sources(
1456                 ctx=FakeContext(),
1457                 src=(src,),
1458                 quiet=True,
1459                 verbose=False,
1460                 include=re.compile(include),
1461                 exclude=re.compile(exclude),
1462                 extend_exclude=None,
1463                 force_exclude=None,
1464                 report=report,
1465                 stdin_filename=None,
1466             )
1467         )
1468         self.assertEqual(sorted(expected), sorted(sources))
1469
1470     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1471     def test_get_sources_with_stdin(self) -> None:
1472         include = ""
1473         exclude = r"/exclude/|a\.py"
1474         src = "-"
1475         report = black.Report()
1476         expected = [Path("-")]
1477         sources = list(
1478             black.get_sources(
1479                 ctx=FakeContext(),
1480                 src=(src,),
1481                 quiet=True,
1482                 verbose=False,
1483                 include=re.compile(include),
1484                 exclude=re.compile(exclude),
1485                 extend_exclude=None,
1486                 force_exclude=None,
1487                 report=report,
1488                 stdin_filename=None,
1489             )
1490         )
1491         self.assertEqual(sorted(expected), sorted(sources))
1492
1493     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1494     def test_get_sources_with_stdin_filename(self) -> None:
1495         include = ""
1496         exclude = r"/exclude/|a\.py"
1497         src = "-"
1498         report = black.Report()
1499         stdin_filename = str(THIS_DIR / "data/collections.py")
1500         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1501         sources = list(
1502             black.get_sources(
1503                 ctx=FakeContext(),
1504                 src=(src,),
1505                 quiet=True,
1506                 verbose=False,
1507                 include=re.compile(include),
1508                 exclude=re.compile(exclude),
1509                 extend_exclude=None,
1510                 force_exclude=None,
1511                 report=report,
1512                 stdin_filename=stdin_filename,
1513             )
1514         )
1515         self.assertEqual(sorted(expected), sorted(sources))
1516
1517     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1518     def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
1519         # Exclude shouldn't exclude stdin_filename since it is mimicking the
1520         # file being passed directly. This is the same as
1521         # test_exclude_for_issue_1572
1522         path = THIS_DIR / "data" / "include_exclude_tests"
1523         include = ""
1524         exclude = r"/exclude/|a\.py"
1525         src = "-"
1526         report = black.Report()
1527         stdin_filename = str(path / "b/exclude/a.py")
1528         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1529         sources = list(
1530             black.get_sources(
1531                 ctx=FakeContext(),
1532                 src=(src,),
1533                 quiet=True,
1534                 verbose=False,
1535                 include=re.compile(include),
1536                 exclude=re.compile(exclude),
1537                 extend_exclude=None,
1538                 force_exclude=None,
1539                 report=report,
1540                 stdin_filename=stdin_filename,
1541             )
1542         )
1543         self.assertEqual(sorted(expected), sorted(sources))
1544
1545     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1546     def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
1547         # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
1548         # file being passed directly. This is the same as
1549         # test_exclude_for_issue_1572
1550         path = THIS_DIR / "data" / "include_exclude_tests"
1551         include = ""
1552         extend_exclude = r"/exclude/|a\.py"
1553         src = "-"
1554         report = black.Report()
1555         stdin_filename = str(path / "b/exclude/a.py")
1556         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1557         sources = list(
1558             black.get_sources(
1559                 ctx=FakeContext(),
1560                 src=(src,),
1561                 quiet=True,
1562                 verbose=False,
1563                 include=re.compile(include),
1564                 exclude=re.compile(""),
1565                 extend_exclude=re.compile(extend_exclude),
1566                 force_exclude=None,
1567                 report=report,
1568                 stdin_filename=stdin_filename,
1569             )
1570         )
1571         self.assertEqual(sorted(expected), sorted(sources))
1572
1573     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1574     def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
1575         # Force exclude should exclude the file when passing it through
1576         # stdin_filename
1577         path = THIS_DIR / "data" / "include_exclude_tests"
1578         include = ""
1579         force_exclude = r"/exclude/|a\.py"
1580         src = "-"
1581         report = black.Report()
1582         stdin_filename = str(path / "b/exclude/a.py")
1583         sources = list(
1584             black.get_sources(
1585                 ctx=FakeContext(),
1586                 src=(src,),
1587                 quiet=True,
1588                 verbose=False,
1589                 include=re.compile(include),
1590                 exclude=re.compile(""),
1591                 extend_exclude=None,
1592                 force_exclude=re.compile(force_exclude),
1593                 report=report,
1594                 stdin_filename=stdin_filename,
1595             )
1596         )
1597         self.assertEqual([], sorted(sources))
1598
1599     def test_reformat_one_with_stdin(self) -> None:
1600         with patch(
1601             "black.format_stdin_to_stdout",
1602             return_value=lambda *args, **kwargs: black.Changed.YES,
1603         ) as fsts:
1604             report = MagicMock()
1605             path = Path("-")
1606             black.reformat_one(
1607                 path,
1608                 fast=True,
1609                 write_back=black.WriteBack.YES,
1610                 mode=DEFAULT_MODE,
1611                 report=report,
1612             )
1613             fsts.assert_called_once()
1614             report.done.assert_called_with(path, black.Changed.YES)
1615
1616     def test_reformat_one_with_stdin_filename(self) -> None:
1617         with patch(
1618             "black.format_stdin_to_stdout",
1619             return_value=lambda *args, **kwargs: black.Changed.YES,
1620         ) as fsts:
1621             report = MagicMock()
1622             p = "foo.py"
1623             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1624             expected = Path(p)
1625             black.reformat_one(
1626                 path,
1627                 fast=True,
1628                 write_back=black.WriteBack.YES,
1629                 mode=DEFAULT_MODE,
1630                 report=report,
1631             )
1632             fsts.assert_called_once_with(
1633                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1634             )
1635             # __BLACK_STDIN_FILENAME__ should have been stripped
1636             report.done.assert_called_with(expected, black.Changed.YES)
1637
1638     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1639         with patch(
1640             "black.format_stdin_to_stdout",
1641             return_value=lambda *args, **kwargs: black.Changed.YES,
1642         ) as fsts:
1643             report = MagicMock()
1644             p = "foo.pyi"
1645             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1646             expected = Path(p)
1647             black.reformat_one(
1648                 path,
1649                 fast=True,
1650                 write_back=black.WriteBack.YES,
1651                 mode=DEFAULT_MODE,
1652                 report=report,
1653             )
1654             fsts.assert_called_once_with(
1655                 fast=True,
1656                 write_back=black.WriteBack.YES,
1657                 mode=replace(DEFAULT_MODE, is_pyi=True),
1658             )
1659             # __BLACK_STDIN_FILENAME__ should have been stripped
1660             report.done.assert_called_with(expected, black.Changed.YES)
1661
1662     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1663         with patch(
1664             "black.format_stdin_to_stdout",
1665             return_value=lambda *args, **kwargs: black.Changed.YES,
1666         ) as fsts:
1667             report = MagicMock()
1668             # Even with an existing file, since we are forcing stdin, black
1669             # should output to stdout and not modify the file inplace
1670             p = Path(str(THIS_DIR / "data/collections.py"))
1671             # Make sure is_file actually returns True
1672             self.assertTrue(p.is_file())
1673             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1674             expected = Path(p)
1675             black.reformat_one(
1676                 path,
1677                 fast=True,
1678                 write_back=black.WriteBack.YES,
1679                 mode=DEFAULT_MODE,
1680                 report=report,
1681             )
1682             fsts.assert_called_once()
1683             # __BLACK_STDIN_FILENAME__ should have been stripped
1684             report.done.assert_called_with(expected, black.Changed.YES)
1685
1686     def test_reformat_one_with_stdin_empty(self) -> None:
1687         output = io.StringIO()
1688         with patch("io.TextIOWrapper", lambda *args, **kwargs: output):
1689             try:
1690                 black.format_stdin_to_stdout(
1691                     fast=True,
1692                     content="",
1693                     write_back=black.WriteBack.YES,
1694                     mode=DEFAULT_MODE,
1695                 )
1696             except io.UnsupportedOperation:
1697                 pass  # StringIO does not support detach
1698             assert output.getvalue() == ""
1699
1700     def test_gitignore_exclude(self) -> None:
1701         path = THIS_DIR / "data" / "include_exclude_tests"
1702         include = re.compile(r"\.pyi?$")
1703         exclude = re.compile(r"")
1704         report = black.Report()
1705         gitignore = PathSpec.from_lines(
1706             "gitwildmatch", ["exclude/", ".definitely_exclude"]
1707         )
1708         sources: List[Path] = []
1709         expected = [
1710             Path(path / "b/dont_exclude/a.py"),
1711             Path(path / "b/dont_exclude/a.pyi"),
1712         ]
1713         this_abs = THIS_DIR.resolve()
1714         sources.extend(
1715             black.gen_python_files(
1716                 path.iterdir(),
1717                 this_abs,
1718                 include,
1719                 exclude,
1720                 None,
1721                 None,
1722                 report,
1723                 gitignore,
1724             )
1725         )
1726         self.assertEqual(sorted(expected), sorted(sources))
1727
1728     def test_nested_gitignore(self) -> None:
1729         path = Path(THIS_DIR / "data" / "nested_gitignore_tests")
1730         include = re.compile(r"\.pyi?$")
1731         exclude = re.compile(r"")
1732         root_gitignore = black.files.get_gitignore(path)
1733         report = black.Report()
1734         expected: List[Path] = [
1735             Path(path / "x.py"),
1736             Path(path / "root/b.py"),
1737             Path(path / "root/c.py"),
1738             Path(path / "root/child/c.py"),
1739         ]
1740         this_abs = THIS_DIR.resolve()
1741         sources = list(
1742             black.gen_python_files(
1743                 path.iterdir(),
1744                 this_abs,
1745                 include,
1746                 exclude,
1747                 None,
1748                 None,
1749                 report,
1750                 root_gitignore,
1751             )
1752         )
1753         self.assertEqual(sorted(expected), sorted(sources))
1754
1755     def test_empty_include(self) -> None:
1756         path = THIS_DIR / "data" / "include_exclude_tests"
1757         report = black.Report()
1758         gitignore = PathSpec.from_lines("gitwildmatch", [])
1759         empty = re.compile(r"")
1760         sources: List[Path] = []
1761         expected = [
1762             Path(path / "b/exclude/a.pie"),
1763             Path(path / "b/exclude/a.py"),
1764             Path(path / "b/exclude/a.pyi"),
1765             Path(path / "b/dont_exclude/a.pie"),
1766             Path(path / "b/dont_exclude/a.py"),
1767             Path(path / "b/dont_exclude/a.pyi"),
1768             Path(path / "b/.definitely_exclude/a.pie"),
1769             Path(path / "b/.definitely_exclude/a.py"),
1770             Path(path / "b/.definitely_exclude/a.pyi"),
1771             Path(path / ".gitignore"),
1772             Path(path / "pyproject.toml"),
1773         ]
1774         this_abs = THIS_DIR.resolve()
1775         sources.extend(
1776             black.gen_python_files(
1777                 path.iterdir(),
1778                 this_abs,
1779                 empty,
1780                 re.compile(black.DEFAULT_EXCLUDES),
1781                 None,
1782                 None,
1783                 report,
1784                 gitignore,
1785             )
1786         )
1787         self.assertEqual(sorted(expected), sorted(sources))
1788
1789     def test_extend_exclude(self) -> None:
1790         path = THIS_DIR / "data" / "include_exclude_tests"
1791         report = black.Report()
1792         gitignore = PathSpec.from_lines("gitwildmatch", [])
1793         sources: List[Path] = []
1794         expected = [
1795             Path(path / "b/exclude/a.py"),
1796             Path(path / "b/dont_exclude/a.py"),
1797         ]
1798         this_abs = THIS_DIR.resolve()
1799         sources.extend(
1800             black.gen_python_files(
1801                 path.iterdir(),
1802                 this_abs,
1803                 re.compile(black.DEFAULT_INCLUDES),
1804                 re.compile(r"\.pyi$"),
1805                 re.compile(r"\.definitely_exclude"),
1806                 None,
1807                 report,
1808                 gitignore,
1809             )
1810         )
1811         self.assertEqual(sorted(expected), sorted(sources))
1812
1813     def test_invalid_cli_regex(self) -> None:
1814         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1815             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1816
1817     def test_required_version_matches_version(self) -> None:
1818         self.invokeBlack(
1819             ["--required-version", black.__version__], exit_code=0, ignore_config=True
1820         )
1821
1822     def test_required_version_does_not_match_version(self) -> None:
1823         self.invokeBlack(
1824             ["--required-version", "20.99b"], exit_code=1, ignore_config=True
1825         )
1826
1827     def test_preserves_line_endings(self) -> None:
1828         with TemporaryDirectory() as workspace:
1829             test_file = Path(workspace) / "test.py"
1830             for nl in ["\n", "\r\n"]:
1831                 contents = nl.join(["def f(  ):", "    pass"])
1832                 test_file.write_bytes(contents.encode())
1833                 ff(test_file, write_back=black.WriteBack.YES)
1834                 updated_contents: bytes = test_file.read_bytes()
1835                 self.assertIn(nl.encode(), updated_contents)
1836                 if nl == "\n":
1837                     self.assertNotIn(b"\r\n", updated_contents)
1838
1839     def test_preserves_line_endings_via_stdin(self) -> None:
1840         for nl in ["\n", "\r\n"]:
1841             contents = nl.join(["def f(  ):", "    pass"])
1842             runner = BlackRunner()
1843             result = runner.invoke(
1844                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1845             )
1846             self.assertEqual(result.exit_code, 0)
1847             output = result.stdout_bytes
1848             self.assertIn(nl.encode("utf8"), output)
1849             if nl == "\n":
1850                 self.assertNotIn(b"\r\n", output)
1851
1852     def test_assert_equivalent_different_asts(self) -> None:
1853         with self.assertRaises(AssertionError):
1854             black.assert_equivalent("{}", "None")
1855
1856     def test_symlink_out_of_root_directory(self) -> None:
1857         path = MagicMock()
1858         root = THIS_DIR.resolve()
1859         child = MagicMock()
1860         include = re.compile(black.DEFAULT_INCLUDES)
1861         exclude = re.compile(black.DEFAULT_EXCLUDES)
1862         report = black.Report()
1863         gitignore = PathSpec.from_lines("gitwildmatch", [])
1864         # `child` should behave like a symlink which resolved path is clearly
1865         # outside of the `root` directory.
1866         path.iterdir.return_value = [child]
1867         child.resolve.return_value = Path("/a/b/c")
1868         child.as_posix.return_value = "/a/b/c"
1869         child.is_symlink.return_value = True
1870         try:
1871             list(
1872                 black.gen_python_files(
1873                     path.iterdir(),
1874                     root,
1875                     include,
1876                     exclude,
1877                     None,
1878                     None,
1879                     report,
1880                     gitignore,
1881                 )
1882             )
1883         except ValueError as ve:
1884             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1885         path.iterdir.assert_called_once()
1886         child.resolve.assert_called_once()
1887         child.is_symlink.assert_called_once()
1888         # `child` should behave like a strange file which resolved path is clearly
1889         # outside of the `root` directory.
1890         child.is_symlink.return_value = False
1891         with self.assertRaises(ValueError):
1892             list(
1893                 black.gen_python_files(
1894                     path.iterdir(),
1895                     root,
1896                     include,
1897                     exclude,
1898                     None,
1899                     None,
1900                     report,
1901                     gitignore,
1902                 )
1903             )
1904         path.iterdir.assert_called()
1905         self.assertEqual(path.iterdir.call_count, 2)
1906         child.resolve.assert_called()
1907         self.assertEqual(child.resolve.call_count, 2)
1908         child.is_symlink.assert_called()
1909         self.assertEqual(child.is_symlink.call_count, 2)
1910
1911     def test_shhh_click(self) -> None:
1912         try:
1913             from click import _unicodefun
1914         except ModuleNotFoundError:
1915             self.skipTest("Incompatible Click version")
1916         if not hasattr(_unicodefun, "_verify_python3_env"):
1917             self.skipTest("Incompatible Click version")
1918         # First, let's see if Click is crashing with a preferred ASCII charset.
1919         with patch("locale.getpreferredencoding") as gpe:
1920             gpe.return_value = "ASCII"
1921             with self.assertRaises(RuntimeError):
1922                 _unicodefun._verify_python3_env()  # type: ignore
1923         # Now, let's silence Click...
1924         black.patch_click()
1925         # ...and confirm it's silent.
1926         with patch("locale.getpreferredencoding") as gpe:
1927             gpe.return_value = "ASCII"
1928             try:
1929                 _unicodefun._verify_python3_env()  # type: ignore
1930             except RuntimeError as re:
1931                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1932
1933     def test_root_logger_not_used_directly(self) -> None:
1934         def fail(*args: Any, **kwargs: Any) -> None:
1935             self.fail("Record created with root logger")
1936
1937         with patch.multiple(
1938             logging.root,
1939             debug=fail,
1940             info=fail,
1941             warning=fail,
1942             error=fail,
1943             critical=fail,
1944             log=fail,
1945         ):
1946             ff(THIS_DIR / "util.py")
1947
1948     def test_invalid_config_return_code(self) -> None:
1949         tmp_file = Path(black.dump_to_file())
1950         try:
1951             tmp_config = Path(black.dump_to_file())
1952             tmp_config.unlink()
1953             args = ["--config", str(tmp_config), str(tmp_file)]
1954             self.invokeBlack(args, exit_code=2, ignore_config=False)
1955         finally:
1956             tmp_file.unlink()
1957
1958     def test_parse_pyproject_toml(self) -> None:
1959         test_toml_file = THIS_DIR / "test.toml"
1960         config = black.parse_pyproject_toml(str(test_toml_file))
1961         self.assertEqual(config["verbose"], 1)
1962         self.assertEqual(config["check"], "no")
1963         self.assertEqual(config["diff"], "y")
1964         self.assertEqual(config["color"], True)
1965         self.assertEqual(config["line_length"], 79)
1966         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1967         self.assertEqual(config["exclude"], r"\.pyi?$")
1968         self.assertEqual(config["include"], r"\.py?$")
1969
1970     def test_read_pyproject_toml(self) -> None:
1971         test_toml_file = THIS_DIR / "test.toml"
1972         fake_ctx = FakeContext()
1973         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1974         config = fake_ctx.default_map
1975         self.assertEqual(config["verbose"], "1")
1976         self.assertEqual(config["check"], "no")
1977         self.assertEqual(config["diff"], "y")
1978         self.assertEqual(config["color"], "True")
1979         self.assertEqual(config["line_length"], "79")
1980         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1981         self.assertEqual(config["exclude"], r"\.pyi?$")
1982         self.assertEqual(config["include"], r"\.py?$")
1983
1984     def test_find_project_root(self) -> None:
1985         with TemporaryDirectory() as workspace:
1986             root = Path(workspace)
1987             test_dir = root / "test"
1988             test_dir.mkdir()
1989
1990             src_dir = root / "src"
1991             src_dir.mkdir()
1992
1993             root_pyproject = root / "pyproject.toml"
1994             root_pyproject.touch()
1995             src_pyproject = src_dir / "pyproject.toml"
1996             src_pyproject.touch()
1997             src_python = src_dir / "foo.py"
1998             src_python.touch()
1999
2000             self.assertEqual(
2001                 black.find_project_root((src_dir, test_dir)), root.resolve()
2002             )
2003             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
2004             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
2005
2006     @patch(
2007         "black.files.find_user_pyproject_toml",
2008         black.files.find_user_pyproject_toml.__wrapped__,
2009     )
2010     def test_find_user_pyproject_toml_linux(self) -> None:
2011         if system() == "Windows":
2012             return
2013
2014         # Test if XDG_CONFIG_HOME is checked
2015         with TemporaryDirectory() as workspace:
2016             tmp_user_config = Path(workspace) / "black"
2017             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
2018                 self.assertEqual(
2019                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
2020                 )
2021
2022         # Test fallback for XDG_CONFIG_HOME
2023         with patch.dict("os.environ"):
2024             os.environ.pop("XDG_CONFIG_HOME", None)
2025             fallback_user_config = Path("~/.config").expanduser() / "black"
2026             self.assertEqual(
2027                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
2028             )
2029
2030     def test_find_user_pyproject_toml_windows(self) -> None:
2031         if system() != "Windows":
2032             return
2033
2034         user_config_path = Path.home() / ".black"
2035         self.assertEqual(
2036             black.files.find_user_pyproject_toml(), user_config_path.resolve()
2037         )
2038
2039     def test_bpo_33660_workaround(self) -> None:
2040         if system() == "Windows":
2041             return
2042
2043         # https://bugs.python.org/issue33660
2044
2045         old_cwd = Path.cwd()
2046         try:
2047             root = Path("/")
2048             os.chdir(str(root))
2049             path = Path("workspace") / "project"
2050             report = black.Report(verbose=True)
2051             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
2052             self.assertEqual(normalized_path, "workspace/project")
2053         finally:
2054             os.chdir(str(old_cwd))
2055
2056     def test_newline_comment_interaction(self) -> None:
2057         source = "class A:\\\r\n# type: ignore\n pass\n"
2058         output = black.format_str(source, mode=DEFAULT_MODE)
2059         black.assert_stable(source, output, mode=DEFAULT_MODE)
2060
2061     def test_bpo_2142_workaround(self) -> None:
2062
2063         # https://bugs.python.org/issue2142
2064
2065         source, _ = read_data("missing_final_newline.py")
2066         # read_data adds a trailing newline
2067         source = source.rstrip()
2068         expected, _ = read_data("missing_final_newline.diff")
2069         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
2070         diff_header = re.compile(
2071             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
2072             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
2073         )
2074         try:
2075             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
2076             self.assertEqual(result.exit_code, 0)
2077         finally:
2078             os.unlink(tmp_file)
2079         actual = result.output
2080         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
2081         self.assertEqual(actual, expected)
2082
2083     @pytest.mark.python2
2084     def test_docstring_reformat_for_py27(self) -> None:
2085         """
2086         Check that stripping trailing whitespace from Python 2 docstrings
2087         doesn't trigger a "not equivalent to source" error
2088         """
2089         source = (
2090             b'def foo():\r\n    """Testing\r\n    Testing """\r\n    print "Foo"\r\n'
2091         )
2092         expected = 'def foo():\n    """Testing\n    Testing"""\n    print "Foo"\n'
2093
2094         result = CliRunner().invoke(
2095             black.main,
2096             ["-", "-q", "--target-version=py27"],
2097             input=BytesIO(source),
2098         )
2099
2100         self.assertEqual(result.exit_code, 0)
2101         actual = result.output
2102         self.assertFormatEqual(actual, expected)
2103
2104     @staticmethod
2105     def compare_results(
2106         result: click.testing.Result, expected_value: str, expected_exit_code: int
2107     ) -> None:
2108         """Helper method to test the value and exit code of a click Result."""
2109         assert (
2110             result.output == expected_value
2111         ), "The output did not match the expected value."
2112         assert result.exit_code == expected_exit_code, "The exit code is incorrect."
2113
2114     def test_code_option(self) -> None:
2115         """Test the code option with no changes."""
2116         code = 'print("Hello world")\n'
2117         args = ["--code", code]
2118         result = CliRunner().invoke(black.main, args)
2119
2120         self.compare_results(result, code, 0)
2121
2122     def test_code_option_changed(self) -> None:
2123         """Test the code option when changes are required."""
2124         code = "print('hello world')"
2125         formatted = black.format_str(code, mode=DEFAULT_MODE)
2126
2127         args = ["--code", code]
2128         result = CliRunner().invoke(black.main, args)
2129
2130         self.compare_results(result, formatted, 0)
2131
2132     def test_code_option_check(self) -> None:
2133         """Test the code option when check is passed."""
2134         args = ["--check", "--code", 'print("Hello world")\n']
2135         result = CliRunner().invoke(black.main, args)
2136         self.compare_results(result, "", 0)
2137
2138     def test_code_option_check_changed(self) -> None:
2139         """Test the code option when changes are required, and check is passed."""
2140         args = ["--check", "--code", "print('hello world')"]
2141         result = CliRunner().invoke(black.main, args)
2142         self.compare_results(result, "", 1)
2143
2144     def test_code_option_diff(self) -> None:
2145         """Test the code option when diff is passed."""
2146         code = "print('hello world')"
2147         formatted = black.format_str(code, mode=DEFAULT_MODE)
2148         result_diff = diff(code, formatted, "STDIN", "STDOUT")
2149
2150         args = ["--diff", "--code", code]
2151         result = CliRunner().invoke(black.main, args)
2152
2153         # Remove time from diff
2154         output = DIFF_TIME.sub("", result.output)
2155
2156         assert output == result_diff, "The output did not match the expected value."
2157         assert result.exit_code == 0, "The exit code is incorrect."
2158
2159     def test_code_option_color_diff(self) -> None:
2160         """Test the code option when color and diff are passed."""
2161         code = "print('hello world')"
2162         formatted = black.format_str(code, mode=DEFAULT_MODE)
2163
2164         result_diff = diff(code, formatted, "STDIN", "STDOUT")
2165         result_diff = color_diff(result_diff)
2166
2167         args = ["--diff", "--color", "--code", code]
2168         result = CliRunner().invoke(black.main, args)
2169
2170         # Remove time from diff
2171         output = DIFF_TIME.sub("", result.output)
2172
2173         assert output == result_diff, "The output did not match the expected value."
2174         assert result.exit_code == 0, "The exit code is incorrect."
2175
2176     def test_code_option_safe(self) -> None:
2177         """Test that the code option throws an error when the sanity checks fail."""
2178         # Patch black.assert_equivalent to ensure the sanity checks fail
2179         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
2180             code = 'print("Hello world")'
2181             error_msg = f"{code}\nerror: cannot format <string>: \n"
2182
2183             args = ["--safe", "--code", code]
2184             result = CliRunner().invoke(black.main, args)
2185
2186             self.compare_results(result, error_msg, 123)
2187
2188     def test_code_option_fast(self) -> None:
2189         """Test that the code option ignores errors when the sanity checks fail."""
2190         # Patch black.assert_equivalent to ensure the sanity checks fail
2191         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
2192             code = 'print("Hello world")'
2193             formatted = black.format_str(code, mode=DEFAULT_MODE)
2194
2195             args = ["--fast", "--code", code]
2196             result = CliRunner().invoke(black.main, args)
2197
2198             self.compare_results(result, formatted, 0)
2199
2200     def test_code_option_config(self) -> None:
2201         """
2202         Test that the code option finds the pyproject.toml in the current directory.
2203         """
2204         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
2205             # Make sure we are in the project root with the pyproject file
2206             if not Path("tests").exists():
2207                 os.chdir("..")
2208
2209             args = ["--code", "print"]
2210             CliRunner().invoke(black.main, args)
2211
2212             pyproject_path = Path(Path().cwd(), "pyproject.toml").resolve()
2213             assert (
2214                 len(parse.mock_calls) >= 1
2215             ), "Expected config parse to be called with the current directory."
2216
2217             _, call_args, _ = parse.mock_calls[0]
2218             assert (
2219                 call_args[0].lower() == str(pyproject_path).lower()
2220             ), "Incorrect config loaded."
2221
2222     def test_code_option_parent_config(self) -> None:
2223         """
2224         Test that the code option finds the pyproject.toml in the parent directory.
2225         """
2226         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
2227             # Make sure we are in the tests directory
2228             if Path("tests").exists():
2229                 os.chdir("tests")
2230
2231             args = ["--code", "print"]
2232             CliRunner().invoke(black.main, args)
2233
2234             pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
2235             assert (
2236                 len(parse.mock_calls) >= 1
2237             ), "Expected config parse to be called with the current directory."
2238
2239             _, call_args, _ = parse.mock_calls[0]
2240             assert (
2241                 call_args[0].lower() == str(pyproject_path).lower()
2242             ), "Incorrect config loaded."
2243
2244
# Every line of black's package source file (each keeps its trailing newline),
# read once at import time so `tracefunc` can look up the line of a traced call.
with open(black.__file__, "r", encoding="utf-8") as _black_src:
    black_source_lines = list(_black_src)
2247
2248
def tracefunc(frame: types.FrameType, event: str, arg: Any) -> Callable:
    """Show function calls from `black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    Only "call" events whose code lives in ``black/__init__.py`` are printed,
    as ``<lineno>:<stripped source line>``, indented by call depth. Events from
    any other file are ignored. The function always returns itself so
    tracing keeps going.
    """
    if event != "call":
        return tracefunc

    # Check the file before doing anything else. `black_source_lines` holds
    # only black/__init__.py, so indexing it with a line number from another
    # module is meaningless and can raise IndexError. Skipping early also
    # avoids the costly `inspect.stack()` call for unrelated frames.
    # Separators are normalised so the check also matches Windows paths.
    filename = frame.f_code.co_filename
    if "black/__init__.py" not in filename.replace(os.sep, "/"):
        return tracefunc

    # Indentation reflects stack depth. The 19 is a fixed offset, presumably
    # the number of test-harness frames below the code of interest
    # (TODO confirm).
    stack = len(inspect.stack()) - 19
    stack *= 2
    lineno = frame.f_lineno
    # f_lineno is 1-based; the list is 0-based.
    func_sig_lineno = lineno - 1
    funcname = black_source_lines[func_sig_lineno].strip()
    # Step past decorator lines to reach the line that follows them.
    while funcname.startswith("@"):
        func_sig_lineno += 1
        funcname = black_source_lines[func_sig_lineno].strip()
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc
2269
2270
if __name__ == "__main__":
    # Allow running this file directly as a script. unittest.main imports the
    # module named "test_black" and runs the tests it finds there.
    unittest.main(module="test_black")