git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Fix internal error when FORCE_OPTIONAL_PARENTHESES feature is enabled (#2332)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import multiprocessing
3 import asyncio
4 import logging
5 from concurrent.futures import ThreadPoolExecutor
6 from contextlib import contextmanager
7 from dataclasses import replace
8 import inspect
9 from io import BytesIO
10 import os
11 from pathlib import Path
12 from platform import system
13 import regex as re
14 import sys
15 from tempfile import TemporaryDirectory
16 import types
17 from typing import (
18     Any,
19     Callable,
20     Dict,
21     List,
22     Iterator,
23     TypeVar,
24 )
25 import pytest
26 import unittest
27 from unittest.mock import patch, MagicMock
28
29 import click
30 from click import unstyle
31 from click.testing import CliRunner
32
33 import black
34 from black import Feature, TargetVersion
35 from black.cache import get_cache_file
36 from black.debug import DebugVisitor
37 from black.output import diff, color_diff
38 from black.report import Report
39 import black.files
40
41 from pathspec import PathSpec
42
43 # Import other test classes
44 from tests.util import (
45     THIS_DIR,
46     read_data,
47     DETERMINISTIC_HEADER,
48     BlackBaseTestCase,
49     DEFAULT_MODE,
50     fs,
51     ff,
52     dump_to_stderr,
53 )
54
55
# Path of this test module.
THIS_FILE = Path(__file__)
# The Python 3.6 through 3.9 target versions.
PY36_VERSIONS = {
    TargetVersion.PY36,
    TargetVersion.PY37,
    TargetVersion.PY38,
    TargetVersion.PY39,
}
# One `--target-version=<lowercased name>` CLI flag per entry of PY36_VERSIONS.
# The source is a set, so the order of the flags is not fixed.
PY36_ARGS = [f"--target-version={version.name.lower()}" for version in PY36_VERSIONS]
# Generic type variables.
T = TypeVar("T")
R = TypeVar("R")

# Match the time output in a diff, but nothing else
# NOTE(review): `re` is the third-party `regex` package in this module.
DIFF_TIME = re.compile(r"\t[\d-:+\. ]+")
69
70
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch `black.cache.CACHE_DIR` to a temporary location and yield it.

    With ``exists=True`` the yielded path is a freshly created temporary
    directory; otherwise it is a "new" child of that directory which has not
    been created. The temporary directory is removed when the block exits.
    """
    with TemporaryDirectory() as workspace:
        location = Path(workspace) if exists else Path(workspace) / "new"
        with patch("black.cache.CACHE_DIR", location):
            yield location
79
80
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a brand-new asyncio event loop as current; close it on exit."""
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        # The loop is closed even when the managed block raises.
        fresh_loop.close()
91
92
class FakeContext(click.Context):
    """Minimal click Context stand-in for functions that require one."""

    def __init__(self) -> None:
        # click.Context.__init__ is not called; only `default_map` is set up.
        self.default_map: Dict[str, Any] = dict()
98
99
class FakeParameter(click.Parameter):
    """Minimal click Parameter stand-in for functions that require one."""

    def __init__(self) -> None:
        """Do nothing; click.Parameter.__init__ is not called."""
105
106
class BlackRunner(CliRunner):
    """CliRunner that keeps STDOUT and STDERR separate when testing Black's CLI."""

    def __init__(self) -> None:
        super(BlackRunner, self).__init__(mix_stderr=False)
112
113
114 class BlackTestCase(BlackBaseTestCase):
115     def invokeBlack(
116         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
117     ) -> None:
118         runner = BlackRunner()
119         if ignore_config:
120             args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
121         result = runner.invoke(black.main, args)
122         self.assertEqual(
123             result.exit_code,
124             exit_code,
125             msg=(
126                 f"Failed with args: {args}\n"
127                 f"stdout: {result.stdout_bytes.decode()!r}\n"
128                 f"stderr: {result.stderr_bytes.decode()!r}\n"
129                 f"exception: {result.exception}"
130             ),
131         )
132
133     @patch("black.dump_to_file", dump_to_stderr)
134     def test_empty(self) -> None:
135         source = expected = ""
136         actual = fs(source)
137         self.assertFormatEqual(expected, actual)
138         black.assert_equivalent(source, actual)
139         black.assert_stable(source, actual, DEFAULT_MODE)
140
141     def test_empty_ff(self) -> None:
142         expected = ""
143         tmp_file = Path(black.dump_to_file())
144         try:
145             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
146             with open(tmp_file, encoding="utf8") as f:
147                 actual = f.read()
148         finally:
149             os.unlink(tmp_file)
150         self.assertFormatEqual(expected, actual)
151
152     def test_piping(self) -> None:
153         source, expected = read_data("src/black/__init__", data=False)
154         result = BlackRunner().invoke(
155             black.main,
156             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
157             input=BytesIO(source.encode("utf8")),
158         )
159         self.assertEqual(result.exit_code, 0)
160         self.assertFormatEqual(expected, result.output)
161         if source != result.output:
162             black.assert_equivalent(source, result.output)
163             black.assert_stable(source, result.output, DEFAULT_MODE)
164
165     def test_piping_diff(self) -> None:
166         diff_header = re.compile(
167             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
168             r"\+\d\d\d\d"
169         )
170         source, _ = read_data("expression.py")
171         expected, _ = read_data("expression.diff")
172         config = THIS_DIR / "data" / "empty_pyproject.toml"
173         args = [
174             "-",
175             "--fast",
176             f"--line-length={black.DEFAULT_LINE_LENGTH}",
177             "--diff",
178             f"--config={config}",
179         ]
180         result = BlackRunner().invoke(
181             black.main, args, input=BytesIO(source.encode("utf8"))
182         )
183         self.assertEqual(result.exit_code, 0)
184         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
185         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
186         self.assertEqual(expected, actual)
187
188     def test_piping_diff_with_color(self) -> None:
189         source, _ = read_data("expression.py")
190         config = THIS_DIR / "data" / "empty_pyproject.toml"
191         args = [
192             "-",
193             "--fast",
194             f"--line-length={black.DEFAULT_LINE_LENGTH}",
195             "--diff",
196             "--color",
197             f"--config={config}",
198         ]
199         result = BlackRunner().invoke(
200             black.main, args, input=BytesIO(source.encode("utf8"))
201         )
202         actual = result.output
203         # Again, the contents are checked in a different test, so only look for colors.
204         self.assertIn("\033[1;37m", actual)
205         self.assertIn("\033[36m", actual)
206         self.assertIn("\033[32m", actual)
207         self.assertIn("\033[31m", actual)
208         self.assertIn("\033[0m", actual)
209
210     @patch("black.dump_to_file", dump_to_stderr)
211     def _test_wip(self) -> None:
212         source, expected = read_data("wip")
213         sys.settrace(tracefunc)
214         mode = replace(
215             DEFAULT_MODE,
216             experimental_string_processing=False,
217             target_versions={black.TargetVersion.PY38},
218         )
219         actual = fs(source, mode=mode)
220         sys.settrace(None)
221         self.assertFormatEqual(expected, actual)
222         black.assert_equivalent(source, actual)
223         black.assert_stable(source, actual, black.FileMode())
224
225     @unittest.expectedFailure
226     @patch("black.dump_to_file", dump_to_stderr)
227     def test_trailing_comma_optional_parens_stability1(self) -> None:
228         source, _expected = read_data("trailing_comma_optional_parens1")
229         actual = fs(source)
230         black.assert_stable(source, actual, DEFAULT_MODE)
231
232     @unittest.expectedFailure
233     @patch("black.dump_to_file", dump_to_stderr)
234     def test_trailing_comma_optional_parens_stability2(self) -> None:
235         source, _expected = read_data("trailing_comma_optional_parens2")
236         actual = fs(source)
237         black.assert_stable(source, actual, DEFAULT_MODE)
238
239     @unittest.expectedFailure
240     @patch("black.dump_to_file", dump_to_stderr)
241     def test_trailing_comma_optional_parens_stability3(self) -> None:
242         source, _expected = read_data("trailing_comma_optional_parens3")
243         actual = fs(source)
244         black.assert_stable(source, actual, DEFAULT_MODE)
245
246     @patch("black.dump_to_file", dump_to_stderr)
247     def test_trailing_comma_optional_parens_stability1_pass2(self) -> None:
248         source, _expected = read_data("trailing_comma_optional_parens1")
249         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
250         black.assert_stable(source, actual, DEFAULT_MODE)
251
252     @patch("black.dump_to_file", dump_to_stderr)
253     def test_trailing_comma_optional_parens_stability2_pass2(self) -> None:
254         source, _expected = read_data("trailing_comma_optional_parens2")
255         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
256         black.assert_stable(source, actual, DEFAULT_MODE)
257
258     @patch("black.dump_to_file", dump_to_stderr)
259     def test_trailing_comma_optional_parens_stability3_pass2(self) -> None:
260         source, _expected = read_data("trailing_comma_optional_parens3")
261         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
262         black.assert_stable(source, actual, DEFAULT_MODE)
263
264     @patch("black.dump_to_file", dump_to_stderr)
265     def test_pep_572(self) -> None:
266         source, expected = read_data("pep_572")
267         actual = fs(source)
268         self.assertFormatEqual(expected, actual)
269         black.assert_stable(source, actual, DEFAULT_MODE)
270         if sys.version_info >= (3, 8):
271             black.assert_equivalent(source, actual)
272
273     @patch("black.dump_to_file", dump_to_stderr)
274     def test_pep_572_remove_parens(self) -> None:
275         source, expected = read_data("pep_572_remove_parens")
276         actual = fs(source)
277         self.assertFormatEqual(expected, actual)
278         black.assert_stable(source, actual, DEFAULT_MODE)
279         if sys.version_info >= (3, 8):
280             black.assert_equivalent(source, actual)
281
282     @patch("black.dump_to_file", dump_to_stderr)
283     def test_pep_572_do_not_remove_parens(self) -> None:
284         source, expected = read_data("pep_572_do_not_remove_parens")
285         # the AST safety checks will fail, but that's expected, just make sure no
286         # parentheses are touched
287         actual = black.format_str(source, mode=DEFAULT_MODE)
288         self.assertFormatEqual(expected, actual)
289
290     def test_pep_572_version_detection(self) -> None:
291         source, _ = read_data("pep_572")
292         root = black.lib2to3_parse(source)
293         features = black.get_features_used(root)
294         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
295         versions = black.detect_target_versions(root)
296         self.assertIn(black.TargetVersion.PY38, versions)
297
298     def test_expression_ff(self) -> None:
299         source, expected = read_data("expression")
300         tmp_file = Path(black.dump_to_file(source))
301         try:
302             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
303             with open(tmp_file, encoding="utf8") as f:
304                 actual = f.read()
305         finally:
306             os.unlink(tmp_file)
307         self.assertFormatEqual(expected, actual)
308         with patch("black.dump_to_file", dump_to_stderr):
309             black.assert_equivalent(source, actual)
310             black.assert_stable(source, actual, DEFAULT_MODE)
311
312     def test_expression_diff(self) -> None:
313         source, _ = read_data("expression.py")
314         config = THIS_DIR / "data" / "empty_pyproject.toml"
315         expected, _ = read_data("expression.diff")
316         tmp_file = Path(black.dump_to_file(source))
317         diff_header = re.compile(
318             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
319             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
320         )
321         try:
322             result = BlackRunner().invoke(
323                 black.main, ["--diff", str(tmp_file), f"--config={config}"]
324             )
325             self.assertEqual(result.exit_code, 0)
326         finally:
327             os.unlink(tmp_file)
328         actual = result.output
329         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
330         if expected != actual:
331             dump = black.dump_to_file(actual)
332             msg = (
333                 "Expected diff isn't equal to the actual. If you made changes to"
334                 " expression.py and this is an anticipated difference, overwrite"
335                 f" tests/data/expression.diff with {dump}"
336             )
337             self.assertEqual(expected, actual, msg)
338
339     def test_expression_diff_with_color(self) -> None:
340         source, _ = read_data("expression.py")
341         config = THIS_DIR / "data" / "empty_pyproject.toml"
342         expected, _ = read_data("expression.diff")
343         tmp_file = Path(black.dump_to_file(source))
344         try:
345             result = BlackRunner().invoke(
346                 black.main, ["--diff", "--color", str(tmp_file), f"--config={config}"]
347             )
348         finally:
349             os.unlink(tmp_file)
350         actual = result.output
351         # We check the contents of the diff in `test_expression_diff`. All
352         # we need to check here is that color codes exist in the result.
353         self.assertIn("\033[1;37m", actual)
354         self.assertIn("\033[36m", actual)
355         self.assertIn("\033[32m", actual)
356         self.assertIn("\033[31m", actual)
357         self.assertIn("\033[0m", actual)
358
359     @patch("black.dump_to_file", dump_to_stderr)
360     def test_pep_570(self) -> None:
361         source, expected = read_data("pep_570")
362         actual = fs(source)
363         self.assertFormatEqual(expected, actual)
364         black.assert_stable(source, actual, DEFAULT_MODE)
365         if sys.version_info >= (3, 8):
366             black.assert_equivalent(source, actual)
367
368     def test_detect_pos_only_arguments(self) -> None:
369         source, _ = read_data("pep_570")
370         root = black.lib2to3_parse(source)
371         features = black.get_features_used(root)
372         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
373         versions = black.detect_target_versions(root)
374         self.assertIn(black.TargetVersion.PY38, versions)
375
376     @patch("black.dump_to_file", dump_to_stderr)
377     def test_string_quotes(self) -> None:
378         source, expected = read_data("string_quotes")
379         mode = black.Mode(experimental_string_processing=True)
380         actual = fs(source, mode=mode)
381         self.assertFormatEqual(expected, actual)
382         black.assert_equivalent(source, actual)
383         black.assert_stable(source, actual, mode)
384         mode = replace(mode, string_normalization=False)
385         not_normalized = fs(source, mode=mode)
386         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
387         black.assert_equivalent(source, not_normalized)
388         black.assert_stable(source, not_normalized, mode=mode)
389
390     @patch("black.dump_to_file", dump_to_stderr)
391     def test_docstring_no_string_normalization(self) -> None:
392         """Like test_docstring but with string normalization off."""
393         source, expected = read_data("docstring_no_string_normalization")
394         mode = replace(DEFAULT_MODE, string_normalization=False)
395         actual = fs(source, mode=mode)
396         self.assertFormatEqual(expected, actual)
397         black.assert_equivalent(source, actual)
398         black.assert_stable(source, actual, mode)
399
400     def test_long_strings_flag_disabled(self) -> None:
401         """Tests for turning off the string processing logic."""
402         source, expected = read_data("long_strings_flag_disabled")
403         mode = replace(DEFAULT_MODE, experimental_string_processing=False)
404         actual = fs(source, mode=mode)
405         self.assertFormatEqual(expected, actual)
406         black.assert_stable(expected, actual, mode)
407
408     @patch("black.dump_to_file", dump_to_stderr)
409     def test_numeric_literals(self) -> None:
410         source, expected = read_data("numeric_literals")
411         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
412         actual = fs(source, mode=mode)
413         self.assertFormatEqual(expected, actual)
414         black.assert_equivalent(source, actual)
415         black.assert_stable(source, actual, mode)
416
417     @patch("black.dump_to_file", dump_to_stderr)
418     def test_numeric_literals_ignoring_underscores(self) -> None:
419         source, expected = read_data("numeric_literals_skip_underscores")
420         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
421         actual = fs(source, mode=mode)
422         self.assertFormatEqual(expected, actual)
423         black.assert_equivalent(source, actual)
424         black.assert_stable(source, actual, mode)
425
426     def test_skip_magic_trailing_comma(self) -> None:
427         source, _ = read_data("expression.py")
428         expected, _ = read_data("expression_skip_magic_trailing_comma.diff")
429         tmp_file = Path(black.dump_to_file(source))
430         diff_header = re.compile(
431             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
432             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
433         )
434         try:
435             result = BlackRunner().invoke(black.main, ["-C", "--diff", str(tmp_file)])
436             self.assertEqual(result.exit_code, 0)
437         finally:
438             os.unlink(tmp_file)
439         actual = result.output
440         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
441         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
442         if expected != actual:
443             dump = black.dump_to_file(actual)
444             msg = (
445                 "Expected diff isn't equal to the actual. If you made changes to"
446                 " expression.py and this is an anticipated difference, overwrite"
447                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
448             )
449             self.assertEqual(expected, actual, msg)
450
451     @pytest.mark.no_python2
452     def test_python2_should_fail_without_optional_install(self) -> None:
453         if sys.version_info < (3, 8):
454             self.skipTest(
455                 "Python 3.6 and 3.7 will install typed-ast to work and as such will be"
456                 " able to parse Python 2 syntax without explicitly specifying the"
457                 " python2 extra"
458             )
459
460         source = "x = 1234l"
461         tmp_file = Path(black.dump_to_file(source))
462         try:
463             runner = BlackRunner()
464             result = runner.invoke(black.main, [str(tmp_file)])
465             self.assertEqual(result.exit_code, 123)
466         finally:
467             os.unlink(tmp_file)
468         actual = (
469             result.stderr_bytes.decode()
470             .replace("\n", "")
471             .replace("\\n", "")
472             .replace("\\r", "")
473             .replace("\r", "")
474         )
475         msg = (
476             "The requested source code has invalid Python 3 syntax."
477             "If you are trying to format Python 2 files please reinstall Black"
478             " with the 'python2' extra: `python3 -m pip install black[python2]`."
479         )
480         self.assertIn(msg, actual)
481
482     @pytest.mark.python2
483     @patch("black.dump_to_file", dump_to_stderr)
484     def test_python2_print_function(self) -> None:
485         source, expected = read_data("python2_print_function")
486         mode = replace(DEFAULT_MODE, target_versions={TargetVersion.PY27})
487         actual = fs(source, mode=mode)
488         self.assertFormatEqual(expected, actual)
489         black.assert_equivalent(source, actual)
490         black.assert_stable(source, actual, mode)
491
492     @patch("black.dump_to_file", dump_to_stderr)
493     def test_stub(self) -> None:
494         mode = replace(DEFAULT_MODE, is_pyi=True)
495         source, expected = read_data("stub.pyi")
496         actual = fs(source, mode=mode)
497         self.assertFormatEqual(expected, actual)
498         black.assert_stable(source, actual, mode)
499
500     @patch("black.dump_to_file", dump_to_stderr)
501     def test_async_as_identifier(self) -> None:
502         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
503         source, expected = read_data("async_as_identifier")
504         actual = fs(source)
505         self.assertFormatEqual(expected, actual)
506         major, minor = sys.version_info[:2]
507         if major < 3 or (major <= 3 and minor < 7):
508             black.assert_equivalent(source, actual)
509         black.assert_stable(source, actual, DEFAULT_MODE)
510         # ensure black can parse this when the target is 3.6
511         self.invokeBlack([str(source_path), "--target-version", "py36"])
512         # but not on 3.7, because async/await is no longer an identifier
513         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
514
515     @patch("black.dump_to_file", dump_to_stderr)
516     def test_python37(self) -> None:
517         source_path = (THIS_DIR / "data" / "python37.py").resolve()
518         source, expected = read_data("python37")
519         actual = fs(source)
520         self.assertFormatEqual(expected, actual)
521         major, minor = sys.version_info[:2]
522         if major > 3 or (major == 3 and minor >= 7):
523             black.assert_equivalent(source, actual)
524         black.assert_stable(source, actual, DEFAULT_MODE)
525         # ensure black can parse this when the target is 3.7
526         self.invokeBlack([str(source_path), "--target-version", "py37"])
527         # but not on 3.6, because we use async as a reserved keyword
528         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
529
530     @patch("black.dump_to_file", dump_to_stderr)
531     def test_python38(self) -> None:
532         source, expected = read_data("python38")
533         actual = fs(source)
534         self.assertFormatEqual(expected, actual)
535         major, minor = sys.version_info[:2]
536         if major > 3 or (major == 3 and minor >= 8):
537             black.assert_equivalent(source, actual)
538         black.assert_stable(source, actual, DEFAULT_MODE)
539
540     @patch("black.dump_to_file", dump_to_stderr)
541     def test_python39(self) -> None:
542         source, expected = read_data("python39")
543         actual = fs(source)
544         self.assertFormatEqual(expected, actual)
545         major, minor = sys.version_info[:2]
546         if major > 3 or (major == 3 and minor >= 9):
547             black.assert_equivalent(source, actual)
548         black.assert_stable(source, actual, DEFAULT_MODE)
549
550     def test_tab_comment_indentation(self) -> None:
551         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
552         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
553         self.assertFormatEqual(contents_spc, fs(contents_spc))
554         self.assertFormatEqual(contents_spc, fs(contents_tab))
555
556         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
557         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
558         self.assertFormatEqual(contents_spc, fs(contents_spc))
559         self.assertFormatEqual(contents_spc, fs(contents_tab))
560
561         # mixed tabs and spaces (valid Python 2 code)
562         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
563         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
564         self.assertFormatEqual(contents_spc, fs(contents_spc))
565         self.assertFormatEqual(contents_spc, fs(contents_tab))
566
567         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
568         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
569         self.assertFormatEqual(contents_spc, fs(contents_spc))
570         self.assertFormatEqual(contents_spc, fs(contents_tab))
571
    def test_report_verbose(self) -> None:
        # Drive a verbose Report through a sequence of done/failed/path_ignored
        # events. After each event, check the messages captured from the patched
        # output functions, the summary text and the return code.
        report = Report(verbose=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        # Replacements for black.output._out / _err that record each message.
        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            # An unchanged file is reported on the "out" channel.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # A reformatted file.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file is counted as left unchanged.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With `check` on, a reformatted file makes the return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure goes to the "err" channel and makes the return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # An ignored path is printed but leaves the summary counts unchanged.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # With `check` or `diff` on, the summary uses "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
673
    def test_report_quiet(self) -> None:
        # Drive a Report created with quiet=True through a fixed sequence of
        # done() / failed() / path_ignored() calls and, after each step, check
        # the captured output, the summary string and the return code.
        report = Report(quiet=True)
        out_lines = []
        err_lines = []

        # Stand-ins for black.output._out / _err that record each message.
        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            # An unchanged file: nothing is captured, return code stays 0.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # A reformatted file: counted in the summary, still no output.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file is expected to count as "left unchanged".
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled and one reformatted file, the code becomes 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure is expected on the err channel only; the code becomes 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Another reformatted file does not change the failure return code.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # An ignored path: no captured output, summary and code unchanged.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff switch the summary to "would be ..." wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
767
    def test_report_normal(self) -> None:
        # Drive a default black.Report() through a fixed sequence of done() /
        # failed() / path_ignored() calls and, after each step, check the
        # captured output, the summary string and the return code.
        report = black.Report()
        out_lines = []
        err_lines = []

        # Stand-ins for black.output._out / _err that record each message.
        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            # An unchanged file: nothing is captured, return code stays 0.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # A reformatted file is expected to be announced on the out channel.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file adds no output and counts as "left unchanged".
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled and one reformatted file, the code becomes 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure is expected on the err channel only; the code becomes 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Another reformatted file is announced; the failure code persists.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # An ignored path: no captured output, summary and code unchanged.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff switch the summary to "would be ..." wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
864
865     def test_lib2to3_parse(self) -> None:
866         with self.assertRaises(black.InvalidInput):
867             black.lib2to3_parse("invalid syntax")
868
869         straddling = "x + y"
870         black.lib2to3_parse(straddling)
871         black.lib2to3_parse(straddling, {TargetVersion.PY27})
872         black.lib2to3_parse(straddling, {TargetVersion.PY36})
873         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
874
875         py2_only = "print x"
876         black.lib2to3_parse(py2_only)
877         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
878         with self.assertRaises(black.InvalidInput):
879             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
880         with self.assertRaises(black.InvalidInput):
881             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
882
883         py3_only = "exec(x, end=y)"
884         black.lib2to3_parse(py3_only)
885         with self.assertRaises(black.InvalidInput):
886             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
887         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
888         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
889
890     def test_get_features_used_decorator(self) -> None:
891         # Test the feature detection of new decorator syntax
892         # since this makes some test cases of test_get_features_used()
893         # fails if it fails, this is tested first so that a useful case
894         # is identified
895         simples, relaxed = read_data("decorators")
896         # skip explanation comments at the top of the file
897         for simple_test in simples.split("##")[1:]:
898             node = black.lib2to3_parse(simple_test)
899             decorator = str(node.children[0].children[0]).strip()
900             self.assertNotIn(
901                 Feature.RELAXED_DECORATORS,
902                 black.get_features_used(node),
903                 msg=(
904                     f"decorator '{decorator}' follows python<=3.8 syntax"
905                     "but is detected as 3.9+"
906                     # f"The full node is\n{node!r}"
907                 ),
908             )
909         # skip the '# output' comment at the top of the output part
910         for relaxed_test in relaxed.split("##")[1:]:
911             node = black.lib2to3_parse(relaxed_test)
912             decorator = str(node.children[0].children[0]).strip()
913             self.assertIn(
914                 Feature.RELAXED_DECORATORS,
915                 black.get_features_used(node),
916                 msg=(
917                     f"decorator '{decorator}' uses python3.9+ syntax"
918                     "but is detected as python<=3.8"
919                     # f"The full node is\n{node!r}"
920                 ),
921             )
922
923     def test_get_features_used(self) -> None:
924         node = black.lib2to3_parse("def f(*, arg): ...\n")
925         self.assertEqual(black.get_features_used(node), set())
926         node = black.lib2to3_parse("def f(*, arg,): ...\n")
927         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
928         node = black.lib2to3_parse("f(*arg,)\n")
929         self.assertEqual(
930             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
931         )
932         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
933         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
934         node = black.lib2to3_parse("123_456\n")
935         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
936         node = black.lib2to3_parse("123456\n")
937         self.assertEqual(black.get_features_used(node), set())
938         source, expected = read_data("function")
939         node = black.lib2to3_parse(source)
940         expected_features = {
941             Feature.TRAILING_COMMA_IN_CALL,
942             Feature.TRAILING_COMMA_IN_DEF,
943             Feature.F_STRINGS,
944         }
945         self.assertEqual(black.get_features_used(node), expected_features)
946         node = black.lib2to3_parse(expected)
947         self.assertEqual(black.get_features_used(node), expected_features)
948         source, expected = read_data("expression")
949         node = black.lib2to3_parse(source)
950         self.assertEqual(black.get_features_used(node), set())
951         node = black.lib2to3_parse(expected)
952         self.assertEqual(black.get_features_used(node), set())
953
954     def test_get_future_imports(self) -> None:
955         node = black.lib2to3_parse("\n")
956         self.assertEqual(set(), black.get_future_imports(node))
957         node = black.lib2to3_parse("from __future__ import black\n")
958         self.assertEqual({"black"}, black.get_future_imports(node))
959         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
960         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
961         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
962         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
963         node = black.lib2to3_parse(
964             "from __future__ import multiple\nfrom __future__ import imports\n"
965         )
966         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
967         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
968         self.assertEqual({"black"}, black.get_future_imports(node))
969         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
970         self.assertEqual({"black"}, black.get_future_imports(node))
971         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
972         self.assertEqual(set(), black.get_future_imports(node))
973         node = black.lib2to3_parse("from some.module import black\n")
974         self.assertEqual(set(), black.get_future_imports(node))
975         node = black.lib2to3_parse(
976             "from __future__ import unicode_literals as _unicode_literals"
977         )
978         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
979         node = black.lib2to3_parse(
980             "from __future__ import unicode_literals as _lol, print"
981         )
982         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
983
984     def test_debug_visitor(self) -> None:
985         source, _ = read_data("debug_visitor.py")
986         expected, _ = read_data("debug_visitor.out")
987         out_lines = []
988         err_lines = []
989
990         def out(msg: str, **kwargs: Any) -> None:
991             out_lines.append(msg)
992
993         def err(msg: str, **kwargs: Any) -> None:
994             err_lines.append(msg)
995
996         with patch("black.debug.out", out):
997             DebugVisitor.show(source)
998         actual = "\n".join(out_lines) + "\n"
999         log_name = ""
1000         if expected != actual:
1001             log_name = black.dump_to_file(*out_lines)
1002         self.assertEqual(
1003             expected,
1004             actual,
1005             f"AST print out is different. Actual version dumped to {log_name}",
1006         )
1007
1008     def test_format_file_contents(self) -> None:
1009         empty = ""
1010         mode = DEFAULT_MODE
1011         with self.assertRaises(black.NothingChanged):
1012             black.format_file_contents(empty, mode=mode, fast=False)
1013         just_nl = "\n"
1014         with self.assertRaises(black.NothingChanged):
1015             black.format_file_contents(just_nl, mode=mode, fast=False)
1016         same = "j = [1, 2, 3]\n"
1017         with self.assertRaises(black.NothingChanged):
1018             black.format_file_contents(same, mode=mode, fast=False)
1019         different = "j = [1,2,3]"
1020         expected = same
1021         actual = black.format_file_contents(different, mode=mode, fast=False)
1022         self.assertEqual(expected, actual)
1023         invalid = "return if you can"
1024         with self.assertRaises(black.InvalidInput) as e:
1025             black.format_file_contents(invalid, mode=mode, fast=False)
1026         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1027
1028     def test_endmarker(self) -> None:
1029         n = black.lib2to3_parse("\n")
1030         self.assertEqual(n.type, black.syms.file_input)
1031         self.assertEqual(len(n.children), 1)
1032         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1033
1034     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1035     def test_assertFormatEqual(self) -> None:
1036         out_lines = []
1037         err_lines = []
1038
1039         def out(msg: str, **kwargs: Any) -> None:
1040             out_lines.append(msg)
1041
1042         def err(msg: str, **kwargs: Any) -> None:
1043             err_lines.append(msg)
1044
1045         with patch("black.output._out", out), patch("black.output._err", err):
1046             with self.assertRaises(AssertionError):
1047                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1048
1049         out_str = "".join(out_lines)
1050         self.assertTrue("Expected tree:" in out_str)
1051         self.assertTrue("Actual tree:" in out_str)
1052         self.assertEqual("".join(err_lines), "")
1053
1054     def test_cache_broken_file(self) -> None:
1055         mode = DEFAULT_MODE
1056         with cache_dir() as workspace:
1057             cache_file = get_cache_file(mode)
1058             with cache_file.open("w") as fobj:
1059                 fobj.write("this is not a pickle")
1060             self.assertEqual(black.read_cache(mode), {})
1061             src = (workspace / "test.py").resolve()
1062             with src.open("w") as fobj:
1063                 fobj.write("print('hello')")
1064             self.invokeBlack([str(src)])
1065             cache = black.read_cache(mode)
1066             self.assertIn(str(src), cache)
1067
1068     def test_cache_single_file_already_cached(self) -> None:
1069         mode = DEFAULT_MODE
1070         with cache_dir() as workspace:
1071             src = (workspace / "test.py").resolve()
1072             with src.open("w") as fobj:
1073                 fobj.write("print('hello')")
1074             black.write_cache({}, [src], mode)
1075             self.invokeBlack([str(src)])
1076             with src.open("r") as fobj:
1077                 self.assertEqual(fobj.read(), "print('hello')")
1078
1079     @event_loop()
1080     def test_cache_multiple_files(self) -> None:
1081         mode = DEFAULT_MODE
1082         with cache_dir() as workspace, patch(
1083             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1084         ):
1085             one = (workspace / "one.py").resolve()
1086             with one.open("w") as fobj:
1087                 fobj.write("print('hello')")
1088             two = (workspace / "two.py").resolve()
1089             with two.open("w") as fobj:
1090                 fobj.write("print('hello')")
1091             black.write_cache({}, [one], mode)
1092             self.invokeBlack([str(workspace)])
1093             with one.open("r") as fobj:
1094                 self.assertEqual(fobj.read(), "print('hello')")
1095             with two.open("r") as fobj:
1096                 self.assertEqual(fobj.read(), 'print("hello")\n')
1097             cache = black.read_cache(mode)
1098             self.assertIn(str(one), cache)
1099             self.assertIn(str(two), cache)
1100
1101     def test_no_cache_when_writeback_diff(self) -> None:
1102         mode = DEFAULT_MODE
1103         with cache_dir() as workspace:
1104             src = (workspace / "test.py").resolve()
1105             with src.open("w") as fobj:
1106                 fobj.write("print('hello')")
1107             with patch("black.read_cache") as read_cache, patch(
1108                 "black.write_cache"
1109             ) as write_cache:
1110                 self.invokeBlack([str(src), "--diff"])
1111                 cache_file = get_cache_file(mode)
1112                 self.assertFalse(cache_file.exists())
1113                 write_cache.assert_not_called()
1114                 read_cache.assert_not_called()
1115
1116     def test_no_cache_when_writeback_color_diff(self) -> None:
1117         mode = DEFAULT_MODE
1118         with cache_dir() as workspace:
1119             src = (workspace / "test.py").resolve()
1120             with src.open("w") as fobj:
1121                 fobj.write("print('hello')")
1122             with patch("black.read_cache") as read_cache, patch(
1123                 "black.write_cache"
1124             ) as write_cache:
1125                 self.invokeBlack([str(src), "--diff", "--color"])
1126                 cache_file = get_cache_file(mode)
1127                 self.assertFalse(cache_file.exists())
1128                 write_cache.assert_not_called()
1129                 read_cache.assert_not_called()
1130
1131     @event_loop()
1132     def test_output_locking_when_writeback_diff(self) -> None:
1133         with cache_dir() as workspace:
1134             for tag in range(0, 4):
1135                 src = (workspace / f"test{tag}.py").resolve()
1136                 with src.open("w") as fobj:
1137                     fobj.write("print('hello')")
1138             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1139                 self.invokeBlack(["--diff", str(workspace)], exit_code=0)
1140                 # this isn't quite doing what we want, but if it _isn't_
1141                 # called then we cannot be using the lock it provides
1142                 mgr.assert_called()
1143
1144     @event_loop()
1145     def test_output_locking_when_writeback_color_diff(self) -> None:
1146         with cache_dir() as workspace:
1147             for tag in range(0, 4):
1148                 src = (workspace / f"test{tag}.py").resolve()
1149                 with src.open("w") as fobj:
1150                     fobj.write("print('hello')")
1151             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1152                 self.invokeBlack(["--diff", "--color", str(workspace)], exit_code=0)
1153                 # this isn't quite doing what we want, but if it _isn't_
1154                 # called then we cannot be using the lock it provides
1155                 mgr.assert_called()
1156
1157     def test_no_cache_when_stdin(self) -> None:
1158         mode = DEFAULT_MODE
1159         with cache_dir():
1160             result = CliRunner().invoke(
1161                 black.main, ["-"], input=BytesIO(b"print('hello')")
1162             )
1163             self.assertEqual(result.exit_code, 0)
1164             cache_file = get_cache_file(mode)
1165             self.assertFalse(cache_file.exists())
1166
1167     def test_read_cache_no_cachefile(self) -> None:
1168         mode = DEFAULT_MODE
1169         with cache_dir():
1170             self.assertEqual(black.read_cache(mode), {})
1171
1172     def test_write_cache_read_cache(self) -> None:
1173         mode = DEFAULT_MODE
1174         with cache_dir() as workspace:
1175             src = (workspace / "test.py").resolve()
1176             src.touch()
1177             black.write_cache({}, [src], mode)
1178             cache = black.read_cache(mode)
1179             self.assertIn(str(src), cache)
1180             self.assertEqual(cache[str(src)], black.get_cache_info(src))
1181
1182     def test_filter_cached(self) -> None:
1183         with TemporaryDirectory() as workspace:
1184             path = Path(workspace)
1185             uncached = (path / "uncached").resolve()
1186             cached = (path / "cached").resolve()
1187             cached_but_changed = (path / "changed").resolve()
1188             uncached.touch()
1189             cached.touch()
1190             cached_but_changed.touch()
1191             cache = {
1192                 str(cached): black.get_cache_info(cached),
1193                 str(cached_but_changed): (0.0, 0),
1194             }
1195             todo, done = black.filter_cached(
1196                 cache, {uncached, cached, cached_but_changed}
1197             )
1198             self.assertEqual(todo, {uncached, cached_but_changed})
1199             self.assertEqual(done, {cached})
1200
1201     def test_write_cache_creates_directory_if_needed(self) -> None:
1202         mode = DEFAULT_MODE
1203         with cache_dir(exists=False) as workspace:
1204             self.assertFalse(workspace.exists())
1205             black.write_cache({}, [], mode)
1206             self.assertTrue(workspace.exists())
1207
1208     @event_loop()
1209     def test_failed_formatting_does_not_get_cached(self) -> None:
1210         mode = DEFAULT_MODE
1211         with cache_dir() as workspace, patch(
1212             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1213         ):
1214             failing = (workspace / "failing.py").resolve()
1215             with failing.open("w") as fobj:
1216                 fobj.write("not actually python")
1217             clean = (workspace / "clean.py").resolve()
1218             with clean.open("w") as fobj:
1219                 fobj.write('print("hello")\n')
1220             self.invokeBlack([str(workspace)], exit_code=123)
1221             cache = black.read_cache(mode)
1222             self.assertNotIn(str(failing), cache)
1223             self.assertIn(str(clean), cache)
1224
1225     def test_write_cache_write_fail(self) -> None:
1226         mode = DEFAULT_MODE
1227         with cache_dir(), patch.object(Path, "open") as mock:
1228             mock.side_effect = OSError
1229             black.write_cache({}, [], mode)
1230
1231     @event_loop()
1232     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
1233     def test_works_in_mono_process_only_environment(self) -> None:
1234         with cache_dir() as workspace:
1235             for f in [
1236                 (workspace / "one.py").resolve(),
1237                 (workspace / "two.py").resolve(),
1238             ]:
1239                 f.write_text('print("hello")\n')
1240             self.invokeBlack([str(workspace)])
1241
1242     @event_loop()
1243     def test_check_diff_use_together(self) -> None:
1244         with cache_dir():
1245             # Files which will be reformatted.
1246             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1247             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1248             # Files which will not be reformatted.
1249             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1250             self.invokeBlack([str(src2), "--diff", "--check"])
1251             # Multi file command.
1252             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1253
1254     def test_no_files(self) -> None:
1255         with cache_dir():
1256             # Without an argument, black exits with error code 0.
1257             self.invokeBlack([])
1258
1259     def test_broken_symlink(self) -> None:
1260         with cache_dir() as workspace:
1261             symlink = workspace / "broken_link.py"
1262             try:
1263                 symlink.symlink_to("nonexistent.py")
1264             except OSError as e:
1265                 self.skipTest(f"Can't create symlinks: {e}")
1266             self.invokeBlack([str(workspace.resolve())])
1267
1268     def test_read_cache_line_lengths(self) -> None:
1269         mode = DEFAULT_MODE
1270         short_mode = replace(DEFAULT_MODE, line_length=1)
1271         with cache_dir() as workspace:
1272             path = (workspace / "file.py").resolve()
1273             path.touch()
1274             black.write_cache({}, [path], mode)
1275             one = black.read_cache(mode)
1276             self.assertIn(str(path), one)
1277             two = black.read_cache(short_mode)
1278             self.assertNotIn(str(path), two)
1279
1280     def test_single_file_force_pyi(self) -> None:
1281         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1282         contents, expected = read_data("force_pyi")
1283         with cache_dir() as workspace:
1284             path = (workspace / "file.py").resolve()
1285             with open(path, "w") as fh:
1286                 fh.write(contents)
1287             self.invokeBlack([str(path), "--pyi"])
1288             with open(path, "r") as fh:
1289                 actual = fh.read()
1290             # verify cache with --pyi is separate
1291             pyi_cache = black.read_cache(pyi_mode)
1292             self.assertIn(str(path), pyi_cache)
1293             normal_cache = black.read_cache(DEFAULT_MODE)
1294             self.assertNotIn(str(path), normal_cache)
1295         self.assertFormatEqual(expected, actual)
1296         black.assert_equivalent(contents, actual)
1297         black.assert_stable(contents, actual, pyi_mode)
1298
1299     @event_loop()
1300     def test_multi_file_force_pyi(self) -> None:
1301         reg_mode = DEFAULT_MODE
1302         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1303         contents, expected = read_data("force_pyi")
1304         with cache_dir() as workspace:
1305             paths = [
1306                 (workspace / "file1.py").resolve(),
1307                 (workspace / "file2.py").resolve(),
1308             ]
1309             for path in paths:
1310                 with open(path, "w") as fh:
1311                     fh.write(contents)
1312             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1313             for path in paths:
1314                 with open(path, "r") as fh:
1315                     actual = fh.read()
1316                 self.assertEqual(actual, expected)
1317             # verify cache with --pyi is separate
1318             pyi_cache = black.read_cache(pyi_mode)
1319             normal_cache = black.read_cache(reg_mode)
1320             for path in paths:
1321                 self.assertIn(str(path), pyi_cache)
1322                 self.assertNotIn(str(path), normal_cache)
1323
1324     def test_pipe_force_pyi(self) -> None:
1325         source, expected = read_data("force_pyi")
1326         result = CliRunner().invoke(
1327             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1328         )
1329         self.assertEqual(result.exit_code, 0)
1330         actual = result.output
1331         self.assertFormatEqual(actual, expected)
1332
1333     def test_single_file_force_py36(self) -> None:
1334         reg_mode = DEFAULT_MODE
1335         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1336         source, expected = read_data("force_py36")
1337         with cache_dir() as workspace:
1338             path = (workspace / "file.py").resolve()
1339             with open(path, "w") as fh:
1340                 fh.write(source)
1341             self.invokeBlack([str(path), *PY36_ARGS])
1342             with open(path, "r") as fh:
1343                 actual = fh.read()
1344             # verify cache with --target-version is separate
1345             py36_cache = black.read_cache(py36_mode)
1346             self.assertIn(str(path), py36_cache)
1347             normal_cache = black.read_cache(reg_mode)
1348             self.assertNotIn(str(path), normal_cache)
1349         self.assertEqual(actual, expected)
1350
1351     @event_loop()
1352     def test_multi_file_force_py36(self) -> None:
1353         reg_mode = DEFAULT_MODE
1354         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1355         source, expected = read_data("force_py36")
1356         with cache_dir() as workspace:
1357             paths = [
1358                 (workspace / "file1.py").resolve(),
1359                 (workspace / "file2.py").resolve(),
1360             ]
1361             for path in paths:
1362                 with open(path, "w") as fh:
1363                     fh.write(source)
1364             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1365             for path in paths:
1366                 with open(path, "r") as fh:
1367                     actual = fh.read()
1368                 self.assertEqual(actual, expected)
1369             # verify cache with --target-version is separate
1370             pyi_cache = black.read_cache(py36_mode)
1371             normal_cache = black.read_cache(reg_mode)
1372             for path in paths:
1373                 self.assertIn(str(path), pyi_cache)
1374                 self.assertNotIn(str(path), normal_cache)
1375
1376     def test_pipe_force_py36(self) -> None:
1377         source, expected = read_data("force_py36")
1378         result = CliRunner().invoke(
1379             black.main,
1380             ["-", "-q", "--target-version=py36"],
1381             input=BytesIO(source.encode("utf8")),
1382         )
1383         self.assertEqual(result.exit_code, 0)
1384         actual = result.output
1385         self.assertFormatEqual(actual, expected)
1386
1387     def test_include_exclude(self) -> None:
1388         path = THIS_DIR / "data" / "include_exclude_tests"
1389         include = re.compile(r"\.pyi?$")
1390         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1391         report = black.Report()
1392         gitignore = PathSpec.from_lines("gitwildmatch", [])
1393         sources: List[Path] = []
1394         expected = [
1395             Path(path / "b/dont_exclude/a.py"),
1396             Path(path / "b/dont_exclude/a.pyi"),
1397         ]
1398         this_abs = THIS_DIR.resolve()
1399         sources.extend(
1400             black.gen_python_files(
1401                 path.iterdir(),
1402                 this_abs,
1403                 include,
1404                 exclude,
1405                 None,
1406                 None,
1407                 report,
1408                 gitignore,
1409             )
1410         )
1411         self.assertEqual(sorted(expected), sorted(sources))
1412
1413     def test_gitignore_used_as_default(self) -> None:
1414         path = Path(THIS_DIR / "data" / "include_exclude_tests")
1415         include = re.compile(r"\.pyi?$")
1416         extend_exclude = re.compile(r"/exclude/")
1417         src = str(path / "b/")
1418         report = black.Report()
1419         expected: List[Path] = [
1420             path / "b/.definitely_exclude/a.py",
1421             path / "b/.definitely_exclude/a.pyi",
1422         ]
1423         sources = list(
1424             black.get_sources(
1425                 ctx=FakeContext(),
1426                 src=(src,),
1427                 quiet=True,
1428                 verbose=False,
1429                 include=include,
1430                 exclude=None,
1431                 extend_exclude=extend_exclude,
1432                 force_exclude=None,
1433                 report=report,
1434                 stdin_filename=None,
1435             )
1436         )
1437         self.assertEqual(sorted(expected), sorted(sources))
1438
1439     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1440     def test_exclude_for_issue_1572(self) -> None:
1441         # Exclude shouldn't touch files that were explicitly given to Black through the
1442         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
1443         # https://github.com/psf/black/issues/1572
1444         path = THIS_DIR / "data" / "include_exclude_tests"
1445         include = ""
1446         exclude = r"/exclude/|a\.py"
1447         src = str(path / "b/exclude/a.py")
1448         report = black.Report()
1449         expected = [Path(path / "b/exclude/a.py")]
1450         sources = list(
1451             black.get_sources(
1452                 ctx=FakeContext(),
1453                 src=(src,),
1454                 quiet=True,
1455                 verbose=False,
1456                 include=re.compile(include),
1457                 exclude=re.compile(exclude),
1458                 extend_exclude=None,
1459                 force_exclude=None,
1460                 report=report,
1461                 stdin_filename=None,
1462             )
1463         )
1464         self.assertEqual(sorted(expected), sorted(sources))
1465
1466     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1467     def test_get_sources_with_stdin(self) -> None:
1468         include = ""
1469         exclude = r"/exclude/|a\.py"
1470         src = "-"
1471         report = black.Report()
1472         expected = [Path("-")]
1473         sources = list(
1474             black.get_sources(
1475                 ctx=FakeContext(),
1476                 src=(src,),
1477                 quiet=True,
1478                 verbose=False,
1479                 include=re.compile(include),
1480                 exclude=re.compile(exclude),
1481                 extend_exclude=None,
1482                 force_exclude=None,
1483                 report=report,
1484                 stdin_filename=None,
1485             )
1486         )
1487         self.assertEqual(sorted(expected), sorted(sources))
1488
1489     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1490     def test_get_sources_with_stdin_filename(self) -> None:
1491         include = ""
1492         exclude = r"/exclude/|a\.py"
1493         src = "-"
1494         report = black.Report()
1495         stdin_filename = str(THIS_DIR / "data/collections.py")
1496         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1497         sources = list(
1498             black.get_sources(
1499                 ctx=FakeContext(),
1500                 src=(src,),
1501                 quiet=True,
1502                 verbose=False,
1503                 include=re.compile(include),
1504                 exclude=re.compile(exclude),
1505                 extend_exclude=None,
1506                 force_exclude=None,
1507                 report=report,
1508                 stdin_filename=stdin_filename,
1509             )
1510         )
1511         self.assertEqual(sorted(expected), sorted(sources))
1512
1513     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1514     def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
1515         # Exclude shouldn't exclude stdin_filename since it is mimicking the
1516         # file being passed directly. This is the same as
1517         # test_exclude_for_issue_1572
1518         path = THIS_DIR / "data" / "include_exclude_tests"
1519         include = ""
1520         exclude = r"/exclude/|a\.py"
1521         src = "-"
1522         report = black.Report()
1523         stdin_filename = str(path / "b/exclude/a.py")
1524         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1525         sources = list(
1526             black.get_sources(
1527                 ctx=FakeContext(),
1528                 src=(src,),
1529                 quiet=True,
1530                 verbose=False,
1531                 include=re.compile(include),
1532                 exclude=re.compile(exclude),
1533                 extend_exclude=None,
1534                 force_exclude=None,
1535                 report=report,
1536                 stdin_filename=stdin_filename,
1537             )
1538         )
1539         self.assertEqual(sorted(expected), sorted(sources))
1540
1541     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1542     def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
1543         # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
1544         # file being passed directly. This is the same as
1545         # test_exclude_for_issue_1572
1546         path = THIS_DIR / "data" / "include_exclude_tests"
1547         include = ""
1548         extend_exclude = r"/exclude/|a\.py"
1549         src = "-"
1550         report = black.Report()
1551         stdin_filename = str(path / "b/exclude/a.py")
1552         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1553         sources = list(
1554             black.get_sources(
1555                 ctx=FakeContext(),
1556                 src=(src,),
1557                 quiet=True,
1558                 verbose=False,
1559                 include=re.compile(include),
1560                 exclude=re.compile(""),
1561                 extend_exclude=re.compile(extend_exclude),
1562                 force_exclude=None,
1563                 report=report,
1564                 stdin_filename=stdin_filename,
1565             )
1566         )
1567         self.assertEqual(sorted(expected), sorted(sources))
1568
1569     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1570     def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
1571         # Force exclude should exclude the file when passing it through
1572         # stdin_filename
1573         path = THIS_DIR / "data" / "include_exclude_tests"
1574         include = ""
1575         force_exclude = r"/exclude/|a\.py"
1576         src = "-"
1577         report = black.Report()
1578         stdin_filename = str(path / "b/exclude/a.py")
1579         sources = list(
1580             black.get_sources(
1581                 ctx=FakeContext(),
1582                 src=(src,),
1583                 quiet=True,
1584                 verbose=False,
1585                 include=re.compile(include),
1586                 exclude=re.compile(""),
1587                 extend_exclude=None,
1588                 force_exclude=re.compile(force_exclude),
1589                 report=report,
1590                 stdin_filename=stdin_filename,
1591             )
1592         )
1593         self.assertEqual([], sorted(sources))
1594
1595     def test_reformat_one_with_stdin(self) -> None:
1596         with patch(
1597             "black.format_stdin_to_stdout",
1598             return_value=lambda *args, **kwargs: black.Changed.YES,
1599         ) as fsts:
1600             report = MagicMock()
1601             path = Path("-")
1602             black.reformat_one(
1603                 path,
1604                 fast=True,
1605                 write_back=black.WriteBack.YES,
1606                 mode=DEFAULT_MODE,
1607                 report=report,
1608             )
1609             fsts.assert_called_once()
1610             report.done.assert_called_with(path, black.Changed.YES)
1611
1612     def test_reformat_one_with_stdin_filename(self) -> None:
1613         with patch(
1614             "black.format_stdin_to_stdout",
1615             return_value=lambda *args, **kwargs: black.Changed.YES,
1616         ) as fsts:
1617             report = MagicMock()
1618             p = "foo.py"
1619             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1620             expected = Path(p)
1621             black.reformat_one(
1622                 path,
1623                 fast=True,
1624                 write_back=black.WriteBack.YES,
1625                 mode=DEFAULT_MODE,
1626                 report=report,
1627             )
1628             fsts.assert_called_once_with(
1629                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1630             )
1631             # __BLACK_STDIN_FILENAME__ should have been stripped
1632             report.done.assert_called_with(expected, black.Changed.YES)
1633
1634     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1635         with patch(
1636             "black.format_stdin_to_stdout",
1637             return_value=lambda *args, **kwargs: black.Changed.YES,
1638         ) as fsts:
1639             report = MagicMock()
1640             p = "foo.pyi"
1641             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1642             expected = Path(p)
1643             black.reformat_one(
1644                 path,
1645                 fast=True,
1646                 write_back=black.WriteBack.YES,
1647                 mode=DEFAULT_MODE,
1648                 report=report,
1649             )
1650             fsts.assert_called_once_with(
1651                 fast=True,
1652                 write_back=black.WriteBack.YES,
1653                 mode=replace(DEFAULT_MODE, is_pyi=True),
1654             )
1655             # __BLACK_STDIN_FILENAME__ should have been stripped
1656             report.done.assert_called_with(expected, black.Changed.YES)
1657
1658     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1659         with patch(
1660             "black.format_stdin_to_stdout",
1661             return_value=lambda *args, **kwargs: black.Changed.YES,
1662         ) as fsts:
1663             report = MagicMock()
1664             # Even with an existing file, since we are forcing stdin, black
1665             # should output to stdout and not modify the file inplace
1666             p = Path(str(THIS_DIR / "data/collections.py"))
1667             # Make sure is_file actually returns True
1668             self.assertTrue(p.is_file())
1669             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1670             expected = Path(p)
1671             black.reformat_one(
1672                 path,
1673                 fast=True,
1674                 write_back=black.WriteBack.YES,
1675                 mode=DEFAULT_MODE,
1676                 report=report,
1677             )
1678             fsts.assert_called_once()
1679             # __BLACK_STDIN_FILENAME__ should have been stripped
1680             report.done.assert_called_with(expected, black.Changed.YES)
1681
1682     def test_gitignore_exclude(self) -> None:
1683         path = THIS_DIR / "data" / "include_exclude_tests"
1684         include = re.compile(r"\.pyi?$")
1685         exclude = re.compile(r"")
1686         report = black.Report()
1687         gitignore = PathSpec.from_lines(
1688             "gitwildmatch", ["exclude/", ".definitely_exclude"]
1689         )
1690         sources: List[Path] = []
1691         expected = [
1692             Path(path / "b/dont_exclude/a.py"),
1693             Path(path / "b/dont_exclude/a.pyi"),
1694         ]
1695         this_abs = THIS_DIR.resolve()
1696         sources.extend(
1697             black.gen_python_files(
1698                 path.iterdir(),
1699                 this_abs,
1700                 include,
1701                 exclude,
1702                 None,
1703                 None,
1704                 report,
1705                 gitignore,
1706             )
1707         )
1708         self.assertEqual(sorted(expected), sorted(sources))
1709
1710     def test_nested_gitignore(self) -> None:
1711         path = Path(THIS_DIR / "data" / "nested_gitignore_tests")
1712         include = re.compile(r"\.pyi?$")
1713         exclude = re.compile(r"")
1714         root_gitignore = black.files.get_gitignore(path)
1715         report = black.Report()
1716         expected: List[Path] = [
1717             Path(path / "x.py"),
1718             Path(path / "root/b.py"),
1719             Path(path / "root/c.py"),
1720             Path(path / "root/child/c.py"),
1721         ]
1722         this_abs = THIS_DIR.resolve()
1723         sources = list(
1724             black.gen_python_files(
1725                 path.iterdir(),
1726                 this_abs,
1727                 include,
1728                 exclude,
1729                 None,
1730                 None,
1731                 report,
1732                 root_gitignore,
1733             )
1734         )
1735         self.assertEqual(sorted(expected), sorted(sources))
1736
1737     def test_empty_include(self) -> None:
1738         path = THIS_DIR / "data" / "include_exclude_tests"
1739         report = black.Report()
1740         gitignore = PathSpec.from_lines("gitwildmatch", [])
1741         empty = re.compile(r"")
1742         sources: List[Path] = []
1743         expected = [
1744             Path(path / "b/exclude/a.pie"),
1745             Path(path / "b/exclude/a.py"),
1746             Path(path / "b/exclude/a.pyi"),
1747             Path(path / "b/dont_exclude/a.pie"),
1748             Path(path / "b/dont_exclude/a.py"),
1749             Path(path / "b/dont_exclude/a.pyi"),
1750             Path(path / "b/.definitely_exclude/a.pie"),
1751             Path(path / "b/.definitely_exclude/a.py"),
1752             Path(path / "b/.definitely_exclude/a.pyi"),
1753             Path(path / ".gitignore"),
1754             Path(path / "pyproject.toml"),
1755         ]
1756         this_abs = THIS_DIR.resolve()
1757         sources.extend(
1758             black.gen_python_files(
1759                 path.iterdir(),
1760                 this_abs,
1761                 empty,
1762                 re.compile(black.DEFAULT_EXCLUDES),
1763                 None,
1764                 None,
1765                 report,
1766                 gitignore,
1767             )
1768         )
1769         self.assertEqual(sorted(expected), sorted(sources))
1770
1771     def test_extend_exclude(self) -> None:
1772         path = THIS_DIR / "data" / "include_exclude_tests"
1773         report = black.Report()
1774         gitignore = PathSpec.from_lines("gitwildmatch", [])
1775         sources: List[Path] = []
1776         expected = [
1777             Path(path / "b/exclude/a.py"),
1778             Path(path / "b/dont_exclude/a.py"),
1779         ]
1780         this_abs = THIS_DIR.resolve()
1781         sources.extend(
1782             black.gen_python_files(
1783                 path.iterdir(),
1784                 this_abs,
1785                 re.compile(black.DEFAULT_INCLUDES),
1786                 re.compile(r"\.pyi$"),
1787                 re.compile(r"\.definitely_exclude"),
1788                 None,
1789                 report,
1790                 gitignore,
1791             )
1792         )
1793         self.assertEqual(sorted(expected), sorted(sources))
1794
1795     def test_invalid_cli_regex(self) -> None:
1796         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1797             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1798
1799     def test_required_version_matches_version(self) -> None:
1800         self.invokeBlack(
1801             ["--required-version", black.__version__], exit_code=0, ignore_config=True
1802         )
1803
1804     def test_required_version_does_not_match_version(self) -> None:
1805         self.invokeBlack(
1806             ["--required-version", "20.99b"], exit_code=1, ignore_config=True
1807         )
1808
1809     def test_preserves_line_endings(self) -> None:
1810         with TemporaryDirectory() as workspace:
1811             test_file = Path(workspace) / "test.py"
1812             for nl in ["\n", "\r\n"]:
1813                 contents = nl.join(["def f(  ):", "    pass"])
1814                 test_file.write_bytes(contents.encode())
1815                 ff(test_file, write_back=black.WriteBack.YES)
1816                 updated_contents: bytes = test_file.read_bytes()
1817                 self.assertIn(nl.encode(), updated_contents)
1818                 if nl == "\n":
1819                     self.assertNotIn(b"\r\n", updated_contents)
1820
1821     def test_preserves_line_endings_via_stdin(self) -> None:
1822         for nl in ["\n", "\r\n"]:
1823             contents = nl.join(["def f(  ):", "    pass"])
1824             runner = BlackRunner()
1825             result = runner.invoke(
1826                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1827             )
1828             self.assertEqual(result.exit_code, 0)
1829             output = result.stdout_bytes
1830             self.assertIn(nl.encode("utf8"), output)
1831             if nl == "\n":
1832                 self.assertNotIn(b"\r\n", output)
1833
1834     def test_assert_equivalent_different_asts(self) -> None:
1835         with self.assertRaises(AssertionError):
1836             black.assert_equivalent("{}", "None")
1837
1838     def test_symlink_out_of_root_directory(self) -> None:
1839         path = MagicMock()
1840         root = THIS_DIR.resolve()
1841         child = MagicMock()
1842         include = re.compile(black.DEFAULT_INCLUDES)
1843         exclude = re.compile(black.DEFAULT_EXCLUDES)
1844         report = black.Report()
1845         gitignore = PathSpec.from_lines("gitwildmatch", [])
1846         # `child` should behave like a symlink which resolved path is clearly
1847         # outside of the `root` directory.
1848         path.iterdir.return_value = [child]
1849         child.resolve.return_value = Path("/a/b/c")
1850         child.as_posix.return_value = "/a/b/c"
1851         child.is_symlink.return_value = True
1852         try:
1853             list(
1854                 black.gen_python_files(
1855                     path.iterdir(),
1856                     root,
1857                     include,
1858                     exclude,
1859                     None,
1860                     None,
1861                     report,
1862                     gitignore,
1863                 )
1864             )
1865         except ValueError as ve:
1866             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1867         path.iterdir.assert_called_once()
1868         child.resolve.assert_called_once()
1869         child.is_symlink.assert_called_once()
1870         # `child` should behave like a strange file which resolved path is clearly
1871         # outside of the `root` directory.
1872         child.is_symlink.return_value = False
1873         with self.assertRaises(ValueError):
1874             list(
1875                 black.gen_python_files(
1876                     path.iterdir(),
1877                     root,
1878                     include,
1879                     exclude,
1880                     None,
1881                     None,
1882                     report,
1883                     gitignore,
1884                 )
1885             )
1886         path.iterdir.assert_called()
1887         self.assertEqual(path.iterdir.call_count, 2)
1888         child.resolve.assert_called()
1889         self.assertEqual(child.resolve.call_count, 2)
1890         child.is_symlink.assert_called()
1891         self.assertEqual(child.is_symlink.call_count, 2)
1892
1893     def test_shhh_click(self) -> None:
1894         try:
1895             from click import _unicodefun  # type: ignore
1896         except ModuleNotFoundError:
1897             self.skipTest("Incompatible Click version")
1898         if not hasattr(_unicodefun, "_verify_python3_env"):
1899             self.skipTest("Incompatible Click version")
1900         # First, let's see if Click is crashing with a preferred ASCII charset.
1901         with patch("locale.getpreferredencoding") as gpe:
1902             gpe.return_value = "ASCII"
1903             with self.assertRaises(RuntimeError):
1904                 _unicodefun._verify_python3_env()
1905         # Now, let's silence Click...
1906         black.patch_click()
1907         # ...and confirm it's silent.
1908         with patch("locale.getpreferredencoding") as gpe:
1909             gpe.return_value = "ASCII"
1910             try:
1911                 _unicodefun._verify_python3_env()
1912             except RuntimeError as re:
1913                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1914
1915     def test_root_logger_not_used_directly(self) -> None:
1916         def fail(*args: Any, **kwargs: Any) -> None:
1917             self.fail("Record created with root logger")
1918
1919         with patch.multiple(
1920             logging.root,
1921             debug=fail,
1922             info=fail,
1923             warning=fail,
1924             error=fail,
1925             critical=fail,
1926             log=fail,
1927         ):
1928             ff(THIS_DIR / "util.py")
1929
1930     def test_invalid_config_return_code(self) -> None:
1931         tmp_file = Path(black.dump_to_file())
1932         try:
1933             tmp_config = Path(black.dump_to_file())
1934             tmp_config.unlink()
1935             args = ["--config", str(tmp_config), str(tmp_file)]
1936             self.invokeBlack(args, exit_code=2, ignore_config=False)
1937         finally:
1938             tmp_file.unlink()
1939
1940     def test_parse_pyproject_toml(self) -> None:
1941         test_toml_file = THIS_DIR / "test.toml"
1942         config = black.parse_pyproject_toml(str(test_toml_file))
1943         self.assertEqual(config["verbose"], 1)
1944         self.assertEqual(config["check"], "no")
1945         self.assertEqual(config["diff"], "y")
1946         self.assertEqual(config["color"], True)
1947         self.assertEqual(config["line_length"], 79)
1948         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1949         self.assertEqual(config["exclude"], r"\.pyi?$")
1950         self.assertEqual(config["include"], r"\.py?$")
1951
1952     def test_read_pyproject_toml(self) -> None:
1953         test_toml_file = THIS_DIR / "test.toml"
1954         fake_ctx = FakeContext()
1955         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1956         config = fake_ctx.default_map
1957         self.assertEqual(config["verbose"], "1")
1958         self.assertEqual(config["check"], "no")
1959         self.assertEqual(config["diff"], "y")
1960         self.assertEqual(config["color"], "True")
1961         self.assertEqual(config["line_length"], "79")
1962         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1963         self.assertEqual(config["exclude"], r"\.pyi?$")
1964         self.assertEqual(config["include"], r"\.py?$")
1965
1966     def test_find_project_root(self) -> None:
1967         with TemporaryDirectory() as workspace:
1968             root = Path(workspace)
1969             test_dir = root / "test"
1970             test_dir.mkdir()
1971
1972             src_dir = root / "src"
1973             src_dir.mkdir()
1974
1975             root_pyproject = root / "pyproject.toml"
1976             root_pyproject.touch()
1977             src_pyproject = src_dir / "pyproject.toml"
1978             src_pyproject.touch()
1979             src_python = src_dir / "foo.py"
1980             src_python.touch()
1981
1982             self.assertEqual(
1983                 black.find_project_root((src_dir, test_dir)), root.resolve()
1984             )
1985             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
1986             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
1987
1988     @patch(
1989         "black.files.find_user_pyproject_toml",
1990         black.files.find_user_pyproject_toml.__wrapped__,
1991     )
1992     def test_find_user_pyproject_toml_linux(self) -> None:
1993         if system() == "Windows":
1994             return
1995
1996         # Test if XDG_CONFIG_HOME is checked
1997         with TemporaryDirectory() as workspace:
1998             tmp_user_config = Path(workspace) / "black"
1999             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
2000                 self.assertEqual(
2001                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
2002                 )
2003
2004         # Test fallback for XDG_CONFIG_HOME
2005         with patch.dict("os.environ"):
2006             os.environ.pop("XDG_CONFIG_HOME", None)
2007             fallback_user_config = Path("~/.config").expanduser() / "black"
2008             self.assertEqual(
2009                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
2010             )
2011
2012     def test_find_user_pyproject_toml_windows(self) -> None:
2013         if system() != "Windows":
2014             return
2015
2016         user_config_path = Path.home() / ".black"
2017         self.assertEqual(
2018             black.files.find_user_pyproject_toml(), user_config_path.resolve()
2019         )
2020
2021     def test_bpo_33660_workaround(self) -> None:
2022         if system() == "Windows":
2023             return
2024
2025         # https://bugs.python.org/issue33660
2026
2027         old_cwd = Path.cwd()
2028         try:
2029             root = Path("/")
2030             os.chdir(str(root))
2031             path = Path("workspace") / "project"
2032             report = black.Report(verbose=True)
2033             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
2034             self.assertEqual(normalized_path, "workspace/project")
2035         finally:
2036             os.chdir(str(old_cwd))
2037
2038     def test_newline_comment_interaction(self) -> None:
2039         source = "class A:\\\r\n# type: ignore\n pass\n"
2040         output = black.format_str(source, mode=DEFAULT_MODE)
2041         black.assert_stable(source, output, mode=DEFAULT_MODE)
2042
2043     def test_bpo_2142_workaround(self) -> None:
2044
2045         # https://bugs.python.org/issue2142
2046
2047         source, _ = read_data("missing_final_newline.py")
2048         # read_data adds a trailing newline
2049         source = source.rstrip()
2050         expected, _ = read_data("missing_final_newline.diff")
2051         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
2052         diff_header = re.compile(
2053             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
2054             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
2055         )
2056         try:
2057             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
2058             self.assertEqual(result.exit_code, 0)
2059         finally:
2060             os.unlink(tmp_file)
2061         actual = result.output
2062         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
2063         self.assertEqual(actual, expected)
2064
2065     @pytest.mark.python2
2066     def test_docstring_reformat_for_py27(self) -> None:
2067         """
2068         Check that stripping trailing whitespace from Python 2 docstrings
2069         doesn't trigger a "not equivalent to source" error
2070         """
2071         source = (
2072             b'def foo():\r\n    """Testing\r\n    Testing """\r\n    print "Foo"\r\n'
2073         )
2074         expected = 'def foo():\n    """Testing\n    Testing"""\n    print "Foo"\n'
2075
2076         result = CliRunner().invoke(
2077             black.main,
2078             ["-", "-q", "--target-version=py27"],
2079             input=BytesIO(source),
2080         )
2081
2082         self.assertEqual(result.exit_code, 0)
2083         actual = result.output
2084         self.assertFormatEqual(actual, expected)
2085
2086     @staticmethod
2087     def compare_results(
2088         result: click.testing.Result, expected_value: str, expected_exit_code: int
2089     ) -> None:
2090         """Helper method to test the value and exit code of a click Result."""
2091         assert (
2092             result.output == expected_value
2093         ), "The output did not match the expected value."
2094         assert result.exit_code == expected_exit_code, "The exit code is incorrect."
2095
2096     def test_code_option(self) -> None:
2097         """Test the code option with no changes."""
2098         code = 'print("Hello world")\n'
2099         args = ["--code", code]
2100         result = CliRunner().invoke(black.main, args)
2101
2102         self.compare_results(result, code, 0)
2103
2104     def test_code_option_changed(self) -> None:
2105         """Test the code option when changes are required."""
2106         code = "print('hello world')"
2107         formatted = black.format_str(code, mode=DEFAULT_MODE)
2108
2109         args = ["--code", code]
2110         result = CliRunner().invoke(black.main, args)
2111
2112         self.compare_results(result, formatted, 0)
2113
2114     def test_code_option_check(self) -> None:
2115         """Test the code option when check is passed."""
2116         args = ["--check", "--code", 'print("Hello world")\n']
2117         result = CliRunner().invoke(black.main, args)
2118         self.compare_results(result, "", 0)
2119
2120     def test_code_option_check_changed(self) -> None:
2121         """Test the code option when changes are required, and check is passed."""
2122         args = ["--check", "--code", "print('hello world')"]
2123         result = CliRunner().invoke(black.main, args)
2124         self.compare_results(result, "", 1)
2125
2126     def test_code_option_diff(self) -> None:
2127         """Test the code option when diff is passed."""
2128         code = "print('hello world')"
2129         formatted = black.format_str(code, mode=DEFAULT_MODE)
2130         result_diff = diff(code, formatted, "STDIN", "STDOUT")
2131
2132         args = ["--diff", "--code", code]
2133         result = CliRunner().invoke(black.main, args)
2134
2135         # Remove time from diff
2136         output = DIFF_TIME.sub("", result.output)
2137
2138         assert output == result_diff, "The output did not match the expected value."
2139         assert result.exit_code == 0, "The exit code is incorrect."
2140
2141     def test_code_option_color_diff(self) -> None:
2142         """Test the code option when color and diff are passed."""
2143         code = "print('hello world')"
2144         formatted = black.format_str(code, mode=DEFAULT_MODE)
2145
2146         result_diff = diff(code, formatted, "STDIN", "STDOUT")
2147         result_diff = color_diff(result_diff)
2148
2149         args = ["--diff", "--color", "--code", code]
2150         result = CliRunner().invoke(black.main, args)
2151
2152         # Remove time from diff
2153         output = DIFF_TIME.sub("", result.output)
2154
2155         assert output == result_diff, "The output did not match the expected value."
2156         assert result.exit_code == 0, "The exit code is incorrect."
2157
2158     def test_code_option_safe(self) -> None:
2159         """Test that the code option throws an error when the sanity checks fail."""
2160         # Patch black.assert_equivalent to ensure the sanity checks fail
2161         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
2162             code = 'print("Hello world")'
2163             error_msg = f"{code}\nerror: cannot format <string>: \n"
2164
2165             args = ["--safe", "--code", code]
2166             result = CliRunner().invoke(black.main, args)
2167
2168             self.compare_results(result, error_msg, 123)
2169
2170     def test_code_option_fast(self) -> None:
2171         """Test that the code option ignores errors when the sanity checks fail."""
2172         # Patch black.assert_equivalent to ensure the sanity checks fail
2173         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
2174             code = 'print("Hello world")'
2175             formatted = black.format_str(code, mode=DEFAULT_MODE)
2176
2177             args = ["--fast", "--code", code]
2178             result = CliRunner().invoke(black.main, args)
2179
2180             self.compare_results(result, formatted, 0)
2181
2182     def test_code_option_config(self) -> None:
2183         """
2184         Test that the code option finds the pyproject.toml in the current directory.
2185         """
2186         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
2187             # Make sure we are in the project root with the pyproject file
2188             if not Path("tests").exists():
2189                 os.chdir("..")
2190
2191             args = ["--code", "print"]
2192             CliRunner().invoke(black.main, args)
2193
2194             pyproject_path = Path(Path().cwd(), "pyproject.toml").resolve()
2195             assert (
2196                 len(parse.mock_calls) >= 1
2197             ), "Expected config parse to be called with the current directory."
2198
2199             _, call_args, _ = parse.mock_calls[0]
2200             assert (
2201                 call_args[0].lower() == str(pyproject_path).lower()
2202             ), "Incorrect config loaded."
2203
2204     def test_code_option_parent_config(self) -> None:
2205         """
2206         Test that the code option finds the pyproject.toml in the parent directory.
2207         """
2208         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
2209             # Make sure we are in the tests directory
2210             if Path("tests").exists():
2211                 os.chdir("tests")
2212
2213             args = ["--code", "print"]
2214             CliRunner().invoke(black.main, args)
2215
2216             pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
2217             assert (
2218                 len(parse.mock_calls) >= 1
2219             ), "Expected config parse to be called with the current directory."
2220
2221             _, call_args, _ = parse.mock_calls[0]
2222             assert (
2223                 call_args[0].lower() == str(pyproject_path).lower()
2224             ), "Incorrect config loaded."
2225
2226
# Keep the lines of black's source file in memory; `tracefunc` indexes into
# this list by line number.
with open(black.__file__, mode="r", encoding="utf-8") as _bf:
    black_source_lines = list(_bf)
2229
2230
def tracefunc(frame: types.FrameType, event: str, arg: Any) -> Callable:
    """Show function calls `from black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    Only "call" events whose frame's filename contains `black/__init__.py`
    are printed, as `<indent><lineno>:<signature line>`.  All other events and
    frames are ignored.  The function always returns itself.
    """
    if event != "call":
        return tracefunc

    filename = frame.f_code.co_filename
    if "black/__init__.py" not in filename:
        # Line numbers of frames from other files are unrelated to
        # `black_source_lines`; indexing with them could raise IndexError.
        # Filtering first also avoids the costly `inspect.stack()` call.
        return tracefunc

    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1
    funcname = black_source_lines[func_sig_lineno].strip()
    # Skip over decorator lines until the first non-decorator line is found.
    while funcname.startswith("@"):
        func_sig_lineno += 1
        funcname = black_source_lines[func_sig_lineno].strip()

    # Indent by two spaces per stack level above a fixed baseline of 19 frames.
    stack = (len(inspect.stack()) - 19) * 2
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc
2251
2252
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main(module="test_black")