git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

f0a14aa2da4ead177d897ef9402a33ea1d0e46ea
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import multiprocessing
3 import asyncio
4 import logging
5 from concurrent.futures import ThreadPoolExecutor
6 from contextlib import contextmanager
7 from dataclasses import replace
8 import inspect
9 from io import BytesIO
10 import os
11 from pathlib import Path
12 from platform import system
13 import regex as re
14 import sys
15 from tempfile import TemporaryDirectory
16 import types
17 from typing import (
18     Any,
19     Callable,
20     Dict,
21     List,
22     Iterator,
23     TypeVar,
24 )
25 import pytest
26 import unittest
27 from unittest.mock import patch, MagicMock
28
29 import click
30 from click import unstyle
31 from click.testing import CliRunner
32
33 import black
34 from black import Feature, TargetVersion
35 from black.cache import get_cache_file
36 from black.debug import DebugVisitor
37 from black.output import diff, color_diff
38 from black.report import Report
39 import black.files
40
41 from pathspec import PathSpec
42
43 # Import other test classes
44 from tests.util import (
45     THIS_DIR,
46     read_data,
47     DETERMINISTIC_HEADER,
48     BlackBaseTestCase,
49     DEFAULT_MODE,
50     fs,
51     ff,
52     dump_to_stderr,
53 )
54
55
THIS_FILE = Path(__file__)
# Every target version from 3.6 upward that the tests exercise together.
PY36_VERSIONS = {
    TargetVersion.PY36,
    TargetVersion.PY37,
    TargetVersion.PY38,
    TargetVersion.PY39,
}
# The same versions spelled as ``--target-version=...`` command-line flags.
PY36_ARGS = ["--target-version=" + target.name.lower() for target in PY36_VERSIONS]
T = TypeVar("T")
R = TypeVar("R")

# Match the time output in a diff, but nothing else
# NOTE(review): `[\d-:...]` compiles with the `regex` module imported as `re`
# above; the stdlib `re` rejects the `\d-:` range — keep that in mind if the
# import ever changes.
DIFF_TIME = re.compile(r"\t[\d-:+\. ]+")
69
70
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Redirect Black's cache directory to a throwaway location.

    ``black.cache.CACHE_DIR`` is patched to point at a fresh temporary
    directory for the duration of the ``with`` block, and that path is
    yielded. With ``exists=False`` the yielded path is a ``new`` child of the
    temporary directory that has not been created.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
79
80
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a brand-new asyncio event loop as the current loop.

    The loop is closed when the ``with`` block exits, even on error. It is not
    unregistered afterwards, so it stays set as the current (now closed) loop.
    """
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        fresh_loop.close()
91
92
class FakeContext(click.Context):
    """Stand-in ``click.Context`` that only carries an empty ``default_map``.

    ``click.Context.__init__`` is not invoked, so no command object is needed
    to build one.
    """

    def __init__(self) -> None:
        defaults: Dict[str, Any] = {}
        self.default_map = defaults
98
99
class FakeParameter(click.Parameter):
    """Bare ``click.Parameter`` placeholder for helpers that require one.

    The base-class initializer is bypassed, so no instance attributes are set.
    """

    def __init__(self) -> None:
        """Intentionally empty: skip ``click.Parameter.__init__``."""
105
106
class BlackRunner(CliRunner):
    """``CliRunner`` that captures stderr separately from stdout.

    Tests can then inspect ``result.stdout_bytes`` and ``result.stderr_bytes``
    independently when invoking Black's CLI.
    """

    def __init__(self) -> None:
        # NOTE(review): the `mix_stderr` keyword is not accepted by newer click
        # releases (8.2+); this assumes an older click — confirm the pin.
        super().__init__(mix_stderr=False)
112
113
114 class BlackTestCase(BlackBaseTestCase):
115     def invokeBlack(
116         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
117     ) -> None:
118         runner = BlackRunner()
119         if ignore_config:
120             args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
121         result = runner.invoke(black.main, args)
122         self.assertEqual(
123             result.exit_code,
124             exit_code,
125             msg=(
126                 f"Failed with args: {args}\n"
127                 f"stdout: {result.stdout_bytes.decode()!r}\n"
128                 f"stderr: {result.stderr_bytes.decode()!r}\n"
129                 f"exception: {result.exception}"
130             ),
131         )
132
133     @patch("black.dump_to_file", dump_to_stderr)
134     def test_empty(self) -> None:
135         source = expected = ""
136         actual = fs(source)
137         self.assertFormatEqual(expected, actual)
138         black.assert_equivalent(source, actual)
139         black.assert_stable(source, actual, DEFAULT_MODE)
140
141     def test_empty_ff(self) -> None:
142         expected = ""
143         tmp_file = Path(black.dump_to_file())
144         try:
145             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
146             with open(tmp_file, encoding="utf8") as f:
147                 actual = f.read()
148         finally:
149             os.unlink(tmp_file)
150         self.assertFormatEqual(expected, actual)
151
152     def test_piping(self) -> None:
153         source, expected = read_data("src/black/__init__", data=False)
154         result = BlackRunner().invoke(
155             black.main,
156             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
157             input=BytesIO(source.encode("utf8")),
158         )
159         self.assertEqual(result.exit_code, 0)
160         self.assertFormatEqual(expected, result.output)
161         if source != result.output:
162             black.assert_equivalent(source, result.output)
163             black.assert_stable(source, result.output, DEFAULT_MODE)
164
165     def test_piping_diff(self) -> None:
166         diff_header = re.compile(
167             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
168             r"\+\d\d\d\d"
169         )
170         source, _ = read_data("expression.py")
171         expected, _ = read_data("expression.diff")
172         config = THIS_DIR / "data" / "empty_pyproject.toml"
173         args = [
174             "-",
175             "--fast",
176             f"--line-length={black.DEFAULT_LINE_LENGTH}",
177             "--diff",
178             f"--config={config}",
179         ]
180         result = BlackRunner().invoke(
181             black.main, args, input=BytesIO(source.encode("utf8"))
182         )
183         self.assertEqual(result.exit_code, 0)
184         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
185         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
186         self.assertEqual(expected, actual)
187
188     def test_piping_diff_with_color(self) -> None:
189         source, _ = read_data("expression.py")
190         config = THIS_DIR / "data" / "empty_pyproject.toml"
191         args = [
192             "-",
193             "--fast",
194             f"--line-length={black.DEFAULT_LINE_LENGTH}",
195             "--diff",
196             "--color",
197             f"--config={config}",
198         ]
199         result = BlackRunner().invoke(
200             black.main, args, input=BytesIO(source.encode("utf8"))
201         )
202         actual = result.output
203         # Again, the contents are checked in a different test, so only look for colors.
204         self.assertIn("\033[1;37m", actual)
205         self.assertIn("\033[36m", actual)
206         self.assertIn("\033[32m", actual)
207         self.assertIn("\033[31m", actual)
208         self.assertIn("\033[0m", actual)
209
210     @patch("black.dump_to_file", dump_to_stderr)
211     def _test_wip(self) -> None:
212         source, expected = read_data("wip")
213         sys.settrace(tracefunc)
214         mode = replace(
215             DEFAULT_MODE,
216             experimental_string_processing=False,
217             target_versions={black.TargetVersion.PY38},
218         )
219         actual = fs(source, mode=mode)
220         sys.settrace(None)
221         self.assertFormatEqual(expected, actual)
222         black.assert_equivalent(source, actual)
223         black.assert_stable(source, actual, black.FileMode())
224
225     @unittest.expectedFailure
226     @patch("black.dump_to_file", dump_to_stderr)
227     def test_trailing_comma_optional_parens_stability1(self) -> None:
228         source, _expected = read_data("trailing_comma_optional_parens1")
229         actual = fs(source)
230         black.assert_stable(source, actual, DEFAULT_MODE)
231
232     @unittest.expectedFailure
233     @patch("black.dump_to_file", dump_to_stderr)
234     def test_trailing_comma_optional_parens_stability2(self) -> None:
235         source, _expected = read_data("trailing_comma_optional_parens2")
236         actual = fs(source)
237         black.assert_stable(source, actual, DEFAULT_MODE)
238
239     @unittest.expectedFailure
240     @patch("black.dump_to_file", dump_to_stderr)
241     def test_trailing_comma_optional_parens_stability3(self) -> None:
242         source, _expected = read_data("trailing_comma_optional_parens3")
243         actual = fs(source)
244         black.assert_stable(source, actual, DEFAULT_MODE)
245
246     @patch("black.dump_to_file", dump_to_stderr)
247     def test_trailing_comma_optional_parens_stability1_pass2(self) -> None:
248         source, _expected = read_data("trailing_comma_optional_parens1")
249         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
250         black.assert_stable(source, actual, DEFAULT_MODE)
251
252     @patch("black.dump_to_file", dump_to_stderr)
253     def test_trailing_comma_optional_parens_stability2_pass2(self) -> None:
254         source, _expected = read_data("trailing_comma_optional_parens2")
255         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
256         black.assert_stable(source, actual, DEFAULT_MODE)
257
258     @patch("black.dump_to_file", dump_to_stderr)
259     def test_trailing_comma_optional_parens_stability3_pass2(self) -> None:
260         source, _expected = read_data("trailing_comma_optional_parens3")
261         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
262         black.assert_stable(source, actual, DEFAULT_MODE)
263
264     @patch("black.dump_to_file", dump_to_stderr)
265     def test_pep_572(self) -> None:
266         source, expected = read_data("pep_572")
267         actual = fs(source)
268         self.assertFormatEqual(expected, actual)
269         black.assert_stable(source, actual, DEFAULT_MODE)
270         if sys.version_info >= (3, 8):
271             black.assert_equivalent(source, actual)
272
273     @patch("black.dump_to_file", dump_to_stderr)
274     def test_pep_572_remove_parens(self) -> None:
275         source, expected = read_data("pep_572_remove_parens")
276         actual = fs(source)
277         self.assertFormatEqual(expected, actual)
278         black.assert_stable(source, actual, DEFAULT_MODE)
279         if sys.version_info >= (3, 8):
280             black.assert_equivalent(source, actual)
281
282     @patch("black.dump_to_file", dump_to_stderr)
283     def test_pep_572_do_not_remove_parens(self) -> None:
284         source, expected = read_data("pep_572_do_not_remove_parens")
285         # the AST safety checks will fail, but that's expected, just make sure no
286         # parentheses are touched
287         actual = black.format_str(source, mode=DEFAULT_MODE)
288         self.assertFormatEqual(expected, actual)
289
290     def test_pep_572_version_detection(self) -> None:
291         source, _ = read_data("pep_572")
292         root = black.lib2to3_parse(source)
293         features = black.get_features_used(root)
294         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
295         versions = black.detect_target_versions(root)
296         self.assertIn(black.TargetVersion.PY38, versions)
297
298     def test_expression_ff(self) -> None:
299         source, expected = read_data("expression")
300         tmp_file = Path(black.dump_to_file(source))
301         try:
302             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
303             with open(tmp_file, encoding="utf8") as f:
304                 actual = f.read()
305         finally:
306             os.unlink(tmp_file)
307         self.assertFormatEqual(expected, actual)
308         with patch("black.dump_to_file", dump_to_stderr):
309             black.assert_equivalent(source, actual)
310             black.assert_stable(source, actual, DEFAULT_MODE)
311
312     def test_expression_diff(self) -> None:
313         source, _ = read_data("expression.py")
314         config = THIS_DIR / "data" / "empty_pyproject.toml"
315         expected, _ = read_data("expression.diff")
316         tmp_file = Path(black.dump_to_file(source))
317         diff_header = re.compile(
318             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
319             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
320         )
321         try:
322             result = BlackRunner().invoke(
323                 black.main, ["--diff", str(tmp_file), f"--config={config}"]
324             )
325             self.assertEqual(result.exit_code, 0)
326         finally:
327             os.unlink(tmp_file)
328         actual = result.output
329         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
330         if expected != actual:
331             dump = black.dump_to_file(actual)
332             msg = (
333                 "Expected diff isn't equal to the actual. If you made changes to"
334                 " expression.py and this is an anticipated difference, overwrite"
335                 f" tests/data/expression.diff with {dump}"
336             )
337             self.assertEqual(expected, actual, msg)
338
339     def test_expression_diff_with_color(self) -> None:
340         source, _ = read_data("expression.py")
341         config = THIS_DIR / "data" / "empty_pyproject.toml"
342         expected, _ = read_data("expression.diff")
343         tmp_file = Path(black.dump_to_file(source))
344         try:
345             result = BlackRunner().invoke(
346                 black.main, ["--diff", "--color", str(tmp_file), f"--config={config}"]
347             )
348         finally:
349             os.unlink(tmp_file)
350         actual = result.output
351         # We check the contents of the diff in `test_expression_diff`. All
352         # we need to check here is that color codes exist in the result.
353         self.assertIn("\033[1;37m", actual)
354         self.assertIn("\033[36m", actual)
355         self.assertIn("\033[32m", actual)
356         self.assertIn("\033[31m", actual)
357         self.assertIn("\033[0m", actual)
358
359     @patch("black.dump_to_file", dump_to_stderr)
360     def test_pep_570(self) -> None:
361         source, expected = read_data("pep_570")
362         actual = fs(source)
363         self.assertFormatEqual(expected, actual)
364         black.assert_stable(source, actual, DEFAULT_MODE)
365         if sys.version_info >= (3, 8):
366             black.assert_equivalent(source, actual)
367
368     def test_detect_pos_only_arguments(self) -> None:
369         source, _ = read_data("pep_570")
370         root = black.lib2to3_parse(source)
371         features = black.get_features_used(root)
372         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
373         versions = black.detect_target_versions(root)
374         self.assertIn(black.TargetVersion.PY38, versions)
375
376     @patch("black.dump_to_file", dump_to_stderr)
377     def test_string_quotes(self) -> None:
378         source, expected = read_data("string_quotes")
379         mode = black.Mode(experimental_string_processing=True)
380         actual = fs(source, mode=mode)
381         self.assertFormatEqual(expected, actual)
382         black.assert_equivalent(source, actual)
383         black.assert_stable(source, actual, mode)
384         mode = replace(mode, string_normalization=False)
385         not_normalized = fs(source, mode=mode)
386         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
387         black.assert_equivalent(source, not_normalized)
388         black.assert_stable(source, not_normalized, mode=mode)
389
390     @patch("black.dump_to_file", dump_to_stderr)
391     def test_docstring_no_string_normalization(self) -> None:
392         """Like test_docstring but with string normalization off."""
393         source, expected = read_data("docstring_no_string_normalization")
394         mode = replace(DEFAULT_MODE, string_normalization=False)
395         actual = fs(source, mode=mode)
396         self.assertFormatEqual(expected, actual)
397         black.assert_equivalent(source, actual)
398         black.assert_stable(source, actual, mode)
399
400     def test_long_strings_flag_disabled(self) -> None:
401         """Tests for turning off the string processing logic."""
402         source, expected = read_data("long_strings_flag_disabled")
403         mode = replace(DEFAULT_MODE, experimental_string_processing=False)
404         actual = fs(source, mode=mode)
405         self.assertFormatEqual(expected, actual)
406         black.assert_stable(expected, actual, mode)
407
408     @patch("black.dump_to_file", dump_to_stderr)
409     def test_numeric_literals(self) -> None:
410         source, expected = read_data("numeric_literals")
411         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
412         actual = fs(source, mode=mode)
413         self.assertFormatEqual(expected, actual)
414         black.assert_equivalent(source, actual)
415         black.assert_stable(source, actual, mode)
416
417     @patch("black.dump_to_file", dump_to_stderr)
418     def test_numeric_literals_ignoring_underscores(self) -> None:
419         source, expected = read_data("numeric_literals_skip_underscores")
420         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
421         actual = fs(source, mode=mode)
422         self.assertFormatEqual(expected, actual)
423         black.assert_equivalent(source, actual)
424         black.assert_stable(source, actual, mode)
425
426     def test_skip_magic_trailing_comma(self) -> None:
427         source, _ = read_data("expression.py")
428         expected, _ = read_data("expression_skip_magic_trailing_comma.diff")
429         tmp_file = Path(black.dump_to_file(source))
430         diff_header = re.compile(
431             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
432             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
433         )
434         try:
435             result = BlackRunner().invoke(black.main, ["-C", "--diff", str(tmp_file)])
436             self.assertEqual(result.exit_code, 0)
437         finally:
438             os.unlink(tmp_file)
439         actual = result.output
440         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
441         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
442         if expected != actual:
443             dump = black.dump_to_file(actual)
444             msg = (
445                 "Expected diff isn't equal to the actual. If you made changes to"
446                 " expression.py and this is an anticipated difference, overwrite"
447                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
448             )
449             self.assertEqual(expected, actual, msg)
450
451     @pytest.mark.no_python2
452     def test_python2_should_fail_without_optional_install(self) -> None:
453         if sys.version_info < (3, 8):
454             self.skipTest(
455                 "Python 3.6 and 3.7 will install typed-ast to work and as such will be"
456                 " able to parse Python 2 syntax without explicitly specifying the"
457                 " python2 extra"
458             )
459
460         source = "x = 1234l"
461         tmp_file = Path(black.dump_to_file(source))
462         try:
463             runner = BlackRunner()
464             result = runner.invoke(black.main, [str(tmp_file)])
465             self.assertEqual(result.exit_code, 123)
466         finally:
467             os.unlink(tmp_file)
468         actual = (
469             result.stderr_bytes.decode()
470             .replace("\n", "")
471             .replace("\\n", "")
472             .replace("\\r", "")
473             .replace("\r", "")
474         )
475         msg = (
476             "The requested source code has invalid Python 3 syntax."
477             "If you are trying to format Python 2 files please reinstall Black"
478             " with the 'python2' extra: `python3 -m pip install black[python2]`."
479         )
480         self.assertIn(msg, actual)
481
482     @pytest.mark.python2
483     @patch("black.dump_to_file", dump_to_stderr)
484     def test_python2_print_function(self) -> None:
485         source, expected = read_data("python2_print_function")
486         mode = replace(DEFAULT_MODE, target_versions={TargetVersion.PY27})
487         actual = fs(source, mode=mode)
488         self.assertFormatEqual(expected, actual)
489         black.assert_equivalent(source, actual)
490         black.assert_stable(source, actual, mode)
491
492     @patch("black.dump_to_file", dump_to_stderr)
493     def test_stub(self) -> None:
494         mode = replace(DEFAULT_MODE, is_pyi=True)
495         source, expected = read_data("stub.pyi")
496         actual = fs(source, mode=mode)
497         self.assertFormatEqual(expected, actual)
498         black.assert_stable(source, actual, mode)
499
500     @patch("black.dump_to_file", dump_to_stderr)
501     def test_async_as_identifier(self) -> None:
502         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
503         source, expected = read_data("async_as_identifier")
504         actual = fs(source)
505         self.assertFormatEqual(expected, actual)
506         major, minor = sys.version_info[:2]
507         if major < 3 or (major <= 3 and minor < 7):
508             black.assert_equivalent(source, actual)
509         black.assert_stable(source, actual, DEFAULT_MODE)
510         # ensure black can parse this when the target is 3.6
511         self.invokeBlack([str(source_path), "--target-version", "py36"])
512         # but not on 3.7, because async/await is no longer an identifier
513         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
514
515     @patch("black.dump_to_file", dump_to_stderr)
516     def test_python37(self) -> None:
517         source_path = (THIS_DIR / "data" / "python37.py").resolve()
518         source, expected = read_data("python37")
519         actual = fs(source)
520         self.assertFormatEqual(expected, actual)
521         major, minor = sys.version_info[:2]
522         if major > 3 or (major == 3 and minor >= 7):
523             black.assert_equivalent(source, actual)
524         black.assert_stable(source, actual, DEFAULT_MODE)
525         # ensure black can parse this when the target is 3.7
526         self.invokeBlack([str(source_path), "--target-version", "py37"])
527         # but not on 3.6, because we use async as a reserved keyword
528         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
529
530     @patch("black.dump_to_file", dump_to_stderr)
531     def test_python38(self) -> None:
532         source, expected = read_data("python38")
533         actual = fs(source)
534         self.assertFormatEqual(expected, actual)
535         major, minor = sys.version_info[:2]
536         if major > 3 or (major == 3 and minor >= 8):
537             black.assert_equivalent(source, actual)
538         black.assert_stable(source, actual, DEFAULT_MODE)
539
540     @patch("black.dump_to_file", dump_to_stderr)
541     def test_python39(self) -> None:
542         source, expected = read_data("python39")
543         actual = fs(source)
544         self.assertFormatEqual(expected, actual)
545         major, minor = sys.version_info[:2]
546         if major > 3 or (major == 3 and minor >= 9):
547             black.assert_equivalent(source, actual)
548         black.assert_stable(source, actual, DEFAULT_MODE)
549
550     def test_tab_comment_indentation(self) -> None:
551         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
552         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
553         self.assertFormatEqual(contents_spc, fs(contents_spc))
554         self.assertFormatEqual(contents_spc, fs(contents_tab))
555
556         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
557         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
558         self.assertFormatEqual(contents_spc, fs(contents_spc))
559         self.assertFormatEqual(contents_spc, fs(contents_tab))
560
561         # mixed tabs and spaces (valid Python 2 code)
562         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
563         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
564         self.assertFormatEqual(contents_spc, fs(contents_spc))
565         self.assertFormatEqual(contents_spc, fs(contents_tab))
566
567         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
568         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
569         self.assertFormatEqual(contents_spc, fs(contents_spc))
570         self.assertFormatEqual(contents_spc, fs(contents_tab))
571
    def test_report_verbose(self) -> None:
        """Drive a verbose ``Report`` through each outcome and check its output.

        After every recorded result, assert on the captured stdout/stderr
        messages, the unstyled summary string, and the return code.
        """
        report = Report(verbose=True)
        # Messages are collected here because black.output._out/_err are patched below.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            # Unchanged file: verbose mode logs it on stdout; return code stays 0.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # Reformatted file: logged on stdout and counted separately.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # Cached file: logged on stdout and counted as unchanged.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, having a reformatted file makes the return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure goes to stderr and forces return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Reporting f3 again (now reformatted) adds to the reformatted count;
            # the earlier "unchanged" count is not decremented.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are logged in verbose mode but not counted in the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Check mode switches the summary to conditional ("would be") wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            # Diff mode (with check off) uses the same conditional wording.
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
673
674     def test_report_quiet(self) -> None:
675         report = Report(quiet=True)
676         out_lines = []
677         err_lines = []
678
679         def out(msg: str, **kwargs: Any) -> None:
680             out_lines.append(msg)
681
682         def err(msg: str, **kwargs: Any) -> None:
683             err_lines.append(msg)
684
685         with patch("black.output._out", out), patch("black.output._err", err):
686             report.done(Path("f1"), black.Changed.NO)
687             self.assertEqual(len(out_lines), 0)
688             self.assertEqual(len(err_lines), 0)
689             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
690             self.assertEqual(report.return_code, 0)
691             report.done(Path("f2"), black.Changed.YES)
692             self.assertEqual(len(out_lines), 0)
693             self.assertEqual(len(err_lines), 0)
694             self.assertEqual(
695                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
696             )
697             report.done(Path("f3"), black.Changed.CACHED)
698             self.assertEqual(len(out_lines), 0)
699             self.assertEqual(len(err_lines), 0)
700             self.assertEqual(
701                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
702             )
703             self.assertEqual(report.return_code, 0)
704             report.check = True
705             self.assertEqual(report.return_code, 1)
706             report.check = False
707             report.failed(Path("e1"), "boom")
708             self.assertEqual(len(out_lines), 0)
709             self.assertEqual(len(err_lines), 1)
710             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
711             self.assertEqual(
712                 unstyle(str(report)),
713                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
714                 " reformat.",
715             )
716             self.assertEqual(report.return_code, 123)
717             report.done(Path("f3"), black.Changed.YES)
718             self.assertEqual(len(out_lines), 0)
719             self.assertEqual(len(err_lines), 1)
720             self.assertEqual(
721                 unstyle(str(report)),
722                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
723                 " reformat.",
724             )
725             self.assertEqual(report.return_code, 123)
726             report.failed(Path("e2"), "boom")
727             self.assertEqual(len(out_lines), 0)
728             self.assertEqual(len(err_lines), 2)
729             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
730             self.assertEqual(
731                 unstyle(str(report)),
732                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
733                 " reformat.",
734             )
735             self.assertEqual(report.return_code, 123)
736             report.path_ignored(Path("wat"), "no match")
737             self.assertEqual(len(out_lines), 0)
738             self.assertEqual(len(err_lines), 2)
739             self.assertEqual(
740                 unstyle(str(report)),
741                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
742                 " reformat.",
743             )
744             self.assertEqual(report.return_code, 123)
745             report.done(Path("f4"), black.Changed.NO)
746             self.assertEqual(len(out_lines), 0)
747             self.assertEqual(len(err_lines), 2)
748             self.assertEqual(
749                 unstyle(str(report)),
750                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
751                 " reformat.",
752             )
753             self.assertEqual(report.return_code, 123)
754             report.check = True
755             self.assertEqual(
756                 unstyle(str(report)),
757                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
758                 " would fail to reformat.",
759             )
760             report.check = False
761             report.diff = True
762             self.assertEqual(
763                 unstyle(str(report)),
764                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
765                 " would fail to reformat.",
766             )
767
768     def test_report_normal(self) -> None:
769         report = black.Report()
770         out_lines = []
771         err_lines = []
772
773         def out(msg: str, **kwargs: Any) -> None:
774             out_lines.append(msg)
775
776         def err(msg: str, **kwargs: Any) -> None:
777             err_lines.append(msg)
778
779         with patch("black.output._out", out), patch("black.output._err", err):
780             report.done(Path("f1"), black.Changed.NO)
781             self.assertEqual(len(out_lines), 0)
782             self.assertEqual(len(err_lines), 0)
783             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
784             self.assertEqual(report.return_code, 0)
785             report.done(Path("f2"), black.Changed.YES)
786             self.assertEqual(len(out_lines), 1)
787             self.assertEqual(len(err_lines), 0)
788             self.assertEqual(out_lines[-1], "reformatted f2")
789             self.assertEqual(
790                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
791             )
792             report.done(Path("f3"), black.Changed.CACHED)
793             self.assertEqual(len(out_lines), 1)
794             self.assertEqual(len(err_lines), 0)
795             self.assertEqual(out_lines[-1], "reformatted f2")
796             self.assertEqual(
797                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
798             )
799             self.assertEqual(report.return_code, 0)
800             report.check = True
801             self.assertEqual(report.return_code, 1)
802             report.check = False
803             report.failed(Path("e1"), "boom")
804             self.assertEqual(len(out_lines), 1)
805             self.assertEqual(len(err_lines), 1)
806             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
807             self.assertEqual(
808                 unstyle(str(report)),
809                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
810                 " reformat.",
811             )
812             self.assertEqual(report.return_code, 123)
813             report.done(Path("f3"), black.Changed.YES)
814             self.assertEqual(len(out_lines), 2)
815             self.assertEqual(len(err_lines), 1)
816             self.assertEqual(out_lines[-1], "reformatted f3")
817             self.assertEqual(
818                 unstyle(str(report)),
819                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
820                 " reformat.",
821             )
822             self.assertEqual(report.return_code, 123)
823             report.failed(Path("e2"), "boom")
824             self.assertEqual(len(out_lines), 2)
825             self.assertEqual(len(err_lines), 2)
826             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
827             self.assertEqual(
828                 unstyle(str(report)),
829                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
830                 " reformat.",
831             )
832             self.assertEqual(report.return_code, 123)
833             report.path_ignored(Path("wat"), "no match")
834             self.assertEqual(len(out_lines), 2)
835             self.assertEqual(len(err_lines), 2)
836             self.assertEqual(
837                 unstyle(str(report)),
838                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
839                 " reformat.",
840             )
841             self.assertEqual(report.return_code, 123)
842             report.done(Path("f4"), black.Changed.NO)
843             self.assertEqual(len(out_lines), 2)
844             self.assertEqual(len(err_lines), 2)
845             self.assertEqual(
846                 unstyle(str(report)),
847                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
848                 " reformat.",
849             )
850             self.assertEqual(report.return_code, 123)
851             report.check = True
852             self.assertEqual(
853                 unstyle(str(report)),
854                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
855                 " would fail to reformat.",
856             )
857             report.check = False
858             report.diff = True
859             self.assertEqual(
860                 unstyle(str(report)),
861                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
862                 " would fail to reformat.",
863             )
864
865     def test_lib2to3_parse(self) -> None:
866         with self.assertRaises(black.InvalidInput):
867             black.lib2to3_parse("invalid syntax")
868
869         straddling = "x + y"
870         black.lib2to3_parse(straddling)
871         black.lib2to3_parse(straddling, {TargetVersion.PY27})
872         black.lib2to3_parse(straddling, {TargetVersion.PY36})
873         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
874
875         py2_only = "print x"
876         black.lib2to3_parse(py2_only)
877         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
878         with self.assertRaises(black.InvalidInput):
879             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
880         with self.assertRaises(black.InvalidInput):
881             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
882
883         py3_only = "exec(x, end=y)"
884         black.lib2to3_parse(py3_only)
885         with self.assertRaises(black.InvalidInput):
886             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
887         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
888         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
889
890     def test_get_features_used_decorator(self) -> None:
891         # Test the feature detection of new decorator syntax
892         # since this makes some test cases of test_get_features_used()
893         # fails if it fails, this is tested first so that a useful case
894         # is identified
895         simples, relaxed = read_data("decorators")
896         # skip explanation comments at the top of the file
897         for simple_test in simples.split("##")[1:]:
898             node = black.lib2to3_parse(simple_test)
899             decorator = str(node.children[0].children[0]).strip()
900             self.assertNotIn(
901                 Feature.RELAXED_DECORATORS,
902                 black.get_features_used(node),
903                 msg=(
904                     f"decorator '{decorator}' follows python<=3.8 syntax"
905                     "but is detected as 3.9+"
906                     # f"The full node is\n{node!r}"
907                 ),
908             )
909         # skip the '# output' comment at the top of the output part
910         for relaxed_test in relaxed.split("##")[1:]:
911             node = black.lib2to3_parse(relaxed_test)
912             decorator = str(node.children[0].children[0]).strip()
913             self.assertIn(
914                 Feature.RELAXED_DECORATORS,
915                 black.get_features_used(node),
916                 msg=(
917                     f"decorator '{decorator}' uses python3.9+ syntax"
918                     "but is detected as python<=3.8"
919                     # f"The full node is\n{node!r}"
920                 ),
921             )
922
923     def test_get_features_used(self) -> None:
924         node = black.lib2to3_parse("def f(*, arg): ...\n")
925         self.assertEqual(black.get_features_used(node), set())
926         node = black.lib2to3_parse("def f(*, arg,): ...\n")
927         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
928         node = black.lib2to3_parse("f(*arg,)\n")
929         self.assertEqual(
930             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
931         )
932         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
933         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
934         node = black.lib2to3_parse("123_456\n")
935         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
936         node = black.lib2to3_parse("123456\n")
937         self.assertEqual(black.get_features_used(node), set())
938         source, expected = read_data("function")
939         node = black.lib2to3_parse(source)
940         expected_features = {
941             Feature.TRAILING_COMMA_IN_CALL,
942             Feature.TRAILING_COMMA_IN_DEF,
943             Feature.F_STRINGS,
944         }
945         self.assertEqual(black.get_features_used(node), expected_features)
946         node = black.lib2to3_parse(expected)
947         self.assertEqual(black.get_features_used(node), expected_features)
948         source, expected = read_data("expression")
949         node = black.lib2to3_parse(source)
950         self.assertEqual(black.get_features_used(node), set())
951         node = black.lib2to3_parse(expected)
952         self.assertEqual(black.get_features_used(node), set())
953
954     def test_get_future_imports(self) -> None:
955         node = black.lib2to3_parse("\n")
956         self.assertEqual(set(), black.get_future_imports(node))
957         node = black.lib2to3_parse("from __future__ import black\n")
958         self.assertEqual({"black"}, black.get_future_imports(node))
959         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
960         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
961         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
962         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
963         node = black.lib2to3_parse(
964             "from __future__ import multiple\nfrom __future__ import imports\n"
965         )
966         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
967         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
968         self.assertEqual({"black"}, black.get_future_imports(node))
969         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
970         self.assertEqual({"black"}, black.get_future_imports(node))
971         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
972         self.assertEqual(set(), black.get_future_imports(node))
973         node = black.lib2to3_parse("from some.module import black\n")
974         self.assertEqual(set(), black.get_future_imports(node))
975         node = black.lib2to3_parse(
976             "from __future__ import unicode_literals as _unicode_literals"
977         )
978         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
979         node = black.lib2to3_parse(
980             "from __future__ import unicode_literals as _lol, print"
981         )
982         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
983
984     def test_debug_visitor(self) -> None:
985         source, _ = read_data("debug_visitor.py")
986         expected, _ = read_data("debug_visitor.out")
987         out_lines = []
988         err_lines = []
989
990         def out(msg: str, **kwargs: Any) -> None:
991             out_lines.append(msg)
992
993         def err(msg: str, **kwargs: Any) -> None:
994             err_lines.append(msg)
995
996         with patch("black.debug.out", out):
997             DebugVisitor.show(source)
998         actual = "\n".join(out_lines) + "\n"
999         log_name = ""
1000         if expected != actual:
1001             log_name = black.dump_to_file(*out_lines)
1002         self.assertEqual(
1003             expected,
1004             actual,
1005             f"AST print out is different. Actual version dumped to {log_name}",
1006         )
1007
1008     def test_format_file_contents(self) -> None:
1009         empty = ""
1010         mode = DEFAULT_MODE
1011         with self.assertRaises(black.NothingChanged):
1012             black.format_file_contents(empty, mode=mode, fast=False)
1013         just_nl = "\n"
1014         with self.assertRaises(black.NothingChanged):
1015             black.format_file_contents(just_nl, mode=mode, fast=False)
1016         same = "j = [1, 2, 3]\n"
1017         with self.assertRaises(black.NothingChanged):
1018             black.format_file_contents(same, mode=mode, fast=False)
1019         different = "j = [1,2,3]"
1020         expected = same
1021         actual = black.format_file_contents(different, mode=mode, fast=False)
1022         self.assertEqual(expected, actual)
1023         invalid = "return if you can"
1024         with self.assertRaises(black.InvalidInput) as e:
1025             black.format_file_contents(invalid, mode=mode, fast=False)
1026         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1027
1028     def test_endmarker(self) -> None:
1029         n = black.lib2to3_parse("\n")
1030         self.assertEqual(n.type, black.syms.file_input)
1031         self.assertEqual(len(n.children), 1)
1032         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1033
1034     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1035     def test_assertFormatEqual(self) -> None:
1036         out_lines = []
1037         err_lines = []
1038
1039         def out(msg: str, **kwargs: Any) -> None:
1040             out_lines.append(msg)
1041
1042         def err(msg: str, **kwargs: Any) -> None:
1043             err_lines.append(msg)
1044
1045         with patch("black.output._out", out), patch("black.output._err", err):
1046             with self.assertRaises(AssertionError):
1047                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1048
1049         out_str = "".join(out_lines)
1050         self.assertTrue("Expected tree:" in out_str)
1051         self.assertTrue("Actual tree:" in out_str)
1052         self.assertEqual("".join(err_lines), "")
1053
1054     def test_cache_broken_file(self) -> None:
1055         mode = DEFAULT_MODE
1056         with cache_dir() as workspace:
1057             cache_file = get_cache_file(mode)
1058             with cache_file.open("w") as fobj:
1059                 fobj.write("this is not a pickle")
1060             self.assertEqual(black.read_cache(mode), {})
1061             src = (workspace / "test.py").resolve()
1062             with src.open("w") as fobj:
1063                 fobj.write("print('hello')")
1064             self.invokeBlack([str(src)])
1065             cache = black.read_cache(mode)
1066             self.assertIn(str(src), cache)
1067
1068     def test_cache_single_file_already_cached(self) -> None:
1069         mode = DEFAULT_MODE
1070         with cache_dir() as workspace:
1071             src = (workspace / "test.py").resolve()
1072             with src.open("w") as fobj:
1073                 fobj.write("print('hello')")
1074             black.write_cache({}, [src], mode)
1075             self.invokeBlack([str(src)])
1076             with src.open("r") as fobj:
1077                 self.assertEqual(fobj.read(), "print('hello')")
1078
1079     @event_loop()
1080     def test_cache_multiple_files(self) -> None:
1081         mode = DEFAULT_MODE
1082         with cache_dir() as workspace, patch(
1083             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1084         ):
1085             one = (workspace / "one.py").resolve()
1086             with one.open("w") as fobj:
1087                 fobj.write("print('hello')")
1088             two = (workspace / "two.py").resolve()
1089             with two.open("w") as fobj:
1090                 fobj.write("print('hello')")
1091             black.write_cache({}, [one], mode)
1092             self.invokeBlack([str(workspace)])
1093             with one.open("r") as fobj:
1094                 self.assertEqual(fobj.read(), "print('hello')")
1095             with two.open("r") as fobj:
1096                 self.assertEqual(fobj.read(), 'print("hello")\n')
1097             cache = black.read_cache(mode)
1098             self.assertIn(str(one), cache)
1099             self.assertIn(str(two), cache)
1100
1101     def test_no_cache_when_writeback_diff(self) -> None:
1102         mode = DEFAULT_MODE
1103         with cache_dir() as workspace:
1104             src = (workspace / "test.py").resolve()
1105             with src.open("w") as fobj:
1106                 fobj.write("print('hello')")
1107             with patch("black.read_cache") as read_cache, patch(
1108                 "black.write_cache"
1109             ) as write_cache:
1110                 self.invokeBlack([str(src), "--diff"])
1111                 cache_file = get_cache_file(mode)
1112                 self.assertFalse(cache_file.exists())
1113                 write_cache.assert_not_called()
1114                 read_cache.assert_not_called()
1115
1116     def test_no_cache_when_writeback_color_diff(self) -> None:
1117         mode = DEFAULT_MODE
1118         with cache_dir() as workspace:
1119             src = (workspace / "test.py").resolve()
1120             with src.open("w") as fobj:
1121                 fobj.write("print('hello')")
1122             with patch("black.read_cache") as read_cache, patch(
1123                 "black.write_cache"
1124             ) as write_cache:
1125                 self.invokeBlack([str(src), "--diff", "--color"])
1126                 cache_file = get_cache_file(mode)
1127                 self.assertFalse(cache_file.exists())
1128                 write_cache.assert_not_called()
1129                 read_cache.assert_not_called()
1130
1131     @event_loop()
1132     def test_output_locking_when_writeback_diff(self) -> None:
1133         with cache_dir() as workspace:
1134             for tag in range(0, 4):
1135                 src = (workspace / f"test{tag}.py").resolve()
1136                 with src.open("w") as fobj:
1137                     fobj.write("print('hello')")
1138             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1139                 self.invokeBlack(["--diff", str(workspace)], exit_code=0)
1140                 # this isn't quite doing what we want, but if it _isn't_
1141                 # called then we cannot be using the lock it provides
1142                 mgr.assert_called()
1143
1144     @event_loop()
1145     def test_output_locking_when_writeback_color_diff(self) -> None:
1146         with cache_dir() as workspace:
1147             for tag in range(0, 4):
1148                 src = (workspace / f"test{tag}.py").resolve()
1149                 with src.open("w") as fobj:
1150                     fobj.write("print('hello')")
1151             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1152                 self.invokeBlack(["--diff", "--color", str(workspace)], exit_code=0)
1153                 # this isn't quite doing what we want, but if it _isn't_
1154                 # called then we cannot be using the lock it provides
1155                 mgr.assert_called()
1156
1157     def test_no_cache_when_stdin(self) -> None:
1158         mode = DEFAULT_MODE
1159         with cache_dir():
1160             result = CliRunner().invoke(
1161                 black.main, ["-"], input=BytesIO(b"print('hello')")
1162             )
1163             self.assertEqual(result.exit_code, 0)
1164             cache_file = get_cache_file(mode)
1165             self.assertFalse(cache_file.exists())
1166
1167     def test_read_cache_no_cachefile(self) -> None:
1168         mode = DEFAULT_MODE
1169         with cache_dir():
1170             self.assertEqual(black.read_cache(mode), {})
1171
1172     def test_write_cache_read_cache(self) -> None:
1173         mode = DEFAULT_MODE
1174         with cache_dir() as workspace:
1175             src = (workspace / "test.py").resolve()
1176             src.touch()
1177             black.write_cache({}, [src], mode)
1178             cache = black.read_cache(mode)
1179             self.assertIn(str(src), cache)
1180             self.assertEqual(cache[str(src)], black.get_cache_info(src))
1181
1182     def test_filter_cached(self) -> None:
1183         with TemporaryDirectory() as workspace:
1184             path = Path(workspace)
1185             uncached = (path / "uncached").resolve()
1186             cached = (path / "cached").resolve()
1187             cached_but_changed = (path / "changed").resolve()
1188             uncached.touch()
1189             cached.touch()
1190             cached_but_changed.touch()
1191             cache = {
1192                 str(cached): black.get_cache_info(cached),
1193                 str(cached_but_changed): (0.0, 0),
1194             }
1195             todo, done = black.filter_cached(
1196                 cache, {uncached, cached, cached_but_changed}
1197             )
1198             self.assertEqual(todo, {uncached, cached_but_changed})
1199             self.assertEqual(done, {cached})
1200
1201     def test_write_cache_creates_directory_if_needed(self) -> None:
1202         mode = DEFAULT_MODE
1203         with cache_dir(exists=False) as workspace:
1204             self.assertFalse(workspace.exists())
1205             black.write_cache({}, [], mode)
1206             self.assertTrue(workspace.exists())
1207
1208     @event_loop()
1209     def test_failed_formatting_does_not_get_cached(self) -> None:
1210         mode = DEFAULT_MODE
1211         with cache_dir() as workspace, patch(
1212             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1213         ):
1214             failing = (workspace / "failing.py").resolve()
1215             with failing.open("w") as fobj:
1216                 fobj.write("not actually python")
1217             clean = (workspace / "clean.py").resolve()
1218             with clean.open("w") as fobj:
1219                 fobj.write('print("hello")\n')
1220             self.invokeBlack([str(workspace)], exit_code=123)
1221             cache = black.read_cache(mode)
1222             self.assertNotIn(str(failing), cache)
1223             self.assertIn(str(clean), cache)
1224
1225     def test_write_cache_write_fail(self) -> None:
1226         mode = DEFAULT_MODE
1227         with cache_dir(), patch.object(Path, "open") as mock:
1228             mock.side_effect = OSError
1229             black.write_cache({}, [], mode)
1230
1231     @event_loop()
1232     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
1233     def test_works_in_mono_process_only_environment(self) -> None:
1234         with cache_dir() as workspace:
1235             for f in [
1236                 (workspace / "one.py").resolve(),
1237                 (workspace / "two.py").resolve(),
1238             ]:
1239                 f.write_text('print("hello")\n')
1240             self.invokeBlack([str(workspace)])
1241
1242     @event_loop()
1243     def test_check_diff_use_together(self) -> None:
1244         with cache_dir():
1245             # Files which will be reformatted.
1246             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1247             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1248             # Files which will not be reformatted.
1249             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1250             self.invokeBlack([str(src2), "--diff", "--check"])
1251             # Multi file command.
1252             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1253
1254     def test_no_files(self) -> None:
1255         with cache_dir():
1256             # Without an argument, black exits with error code 0.
1257             self.invokeBlack([])
1258
1259     def test_broken_symlink(self) -> None:
1260         with cache_dir() as workspace:
1261             symlink = workspace / "broken_link.py"
1262             try:
1263                 symlink.symlink_to("nonexistent.py")
1264             except OSError as e:
1265                 self.skipTest(f"Can't create symlinks: {e}")
1266             self.invokeBlack([str(workspace.resolve())])
1267
1268     def test_read_cache_line_lengths(self) -> None:
1269         mode = DEFAULT_MODE
1270         short_mode = replace(DEFAULT_MODE, line_length=1)
1271         with cache_dir() as workspace:
1272             path = (workspace / "file.py").resolve()
1273             path.touch()
1274             black.write_cache({}, [path], mode)
1275             one = black.read_cache(mode)
1276             self.assertIn(str(path), one)
1277             two = black.read_cache(short_mode)
1278             self.assertNotIn(str(path), two)
1279
1280     def test_single_file_force_pyi(self) -> None:
1281         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1282         contents, expected = read_data("force_pyi")
1283         with cache_dir() as workspace:
1284             path = (workspace / "file.py").resolve()
1285             with open(path, "w") as fh:
1286                 fh.write(contents)
1287             self.invokeBlack([str(path), "--pyi"])
1288             with open(path, "r") as fh:
1289                 actual = fh.read()
1290             # verify cache with --pyi is separate
1291             pyi_cache = black.read_cache(pyi_mode)
1292             self.assertIn(str(path), pyi_cache)
1293             normal_cache = black.read_cache(DEFAULT_MODE)
1294             self.assertNotIn(str(path), normal_cache)
1295         self.assertFormatEqual(expected, actual)
1296         black.assert_equivalent(contents, actual)
1297         black.assert_stable(contents, actual, pyi_mode)
1298
1299     @event_loop()
1300     def test_multi_file_force_pyi(self) -> None:
1301         reg_mode = DEFAULT_MODE
1302         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1303         contents, expected = read_data("force_pyi")
1304         with cache_dir() as workspace:
1305             paths = [
1306                 (workspace / "file1.py").resolve(),
1307                 (workspace / "file2.py").resolve(),
1308             ]
1309             for path in paths:
1310                 with open(path, "w") as fh:
1311                     fh.write(contents)
1312             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1313             for path in paths:
1314                 with open(path, "r") as fh:
1315                     actual = fh.read()
1316                 self.assertEqual(actual, expected)
1317             # verify cache with --pyi is separate
1318             pyi_cache = black.read_cache(pyi_mode)
1319             normal_cache = black.read_cache(reg_mode)
1320             for path in paths:
1321                 self.assertIn(str(path), pyi_cache)
1322                 self.assertNotIn(str(path), normal_cache)
1323
1324     def test_pipe_force_pyi(self) -> None:
1325         source, expected = read_data("force_pyi")
1326         result = CliRunner().invoke(
1327             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1328         )
1329         self.assertEqual(result.exit_code, 0)
1330         actual = result.output
1331         self.assertFormatEqual(actual, expected)
1332
1333     def test_single_file_force_py36(self) -> None:
1334         reg_mode = DEFAULT_MODE
1335         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1336         source, expected = read_data("force_py36")
1337         with cache_dir() as workspace:
1338             path = (workspace / "file.py").resolve()
1339             with open(path, "w") as fh:
1340                 fh.write(source)
1341             self.invokeBlack([str(path), *PY36_ARGS])
1342             with open(path, "r") as fh:
1343                 actual = fh.read()
1344             # verify cache with --target-version is separate
1345             py36_cache = black.read_cache(py36_mode)
1346             self.assertIn(str(path), py36_cache)
1347             normal_cache = black.read_cache(reg_mode)
1348             self.assertNotIn(str(path), normal_cache)
1349         self.assertEqual(actual, expected)
1350
1351     @event_loop()
1352     def test_multi_file_force_py36(self) -> None:
1353         reg_mode = DEFAULT_MODE
1354         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1355         source, expected = read_data("force_py36")
1356         with cache_dir() as workspace:
1357             paths = [
1358                 (workspace / "file1.py").resolve(),
1359                 (workspace / "file2.py").resolve(),
1360             ]
1361             for path in paths:
1362                 with open(path, "w") as fh:
1363                     fh.write(source)
1364             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1365             for path in paths:
1366                 with open(path, "r") as fh:
1367                     actual = fh.read()
1368                 self.assertEqual(actual, expected)
1369             # verify cache with --target-version is separate
1370             pyi_cache = black.read_cache(py36_mode)
1371             normal_cache = black.read_cache(reg_mode)
1372             for path in paths:
1373                 self.assertIn(str(path), pyi_cache)
1374                 self.assertNotIn(str(path), normal_cache)
1375
1376     def test_pipe_force_py36(self) -> None:
1377         source, expected = read_data("force_py36")
1378         result = CliRunner().invoke(
1379             black.main,
1380             ["-", "-q", "--target-version=py36"],
1381             input=BytesIO(source.encode("utf8")),
1382         )
1383         self.assertEqual(result.exit_code, 0)
1384         actual = result.output
1385         self.assertFormatEqual(actual, expected)
1386
1387     def test_include_exclude(self) -> None:
1388         path = THIS_DIR / "data" / "include_exclude_tests"
1389         include = re.compile(r"\.pyi?$")
1390         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1391         report = black.Report()
1392         gitignore = PathSpec.from_lines("gitwildmatch", [])
1393         sources: List[Path] = []
1394         expected = [
1395             Path(path / "b/dont_exclude/a.py"),
1396             Path(path / "b/dont_exclude/a.pyi"),
1397         ]
1398         this_abs = THIS_DIR.resolve()
1399         sources.extend(
1400             black.gen_python_files(
1401                 path.iterdir(),
1402                 this_abs,
1403                 include,
1404                 exclude,
1405                 None,
1406                 None,
1407                 report,
1408                 gitignore,
1409             )
1410         )
1411         self.assertEqual(sorted(expected), sorted(sources))
1412
1413     def test_gitignore_used_as_default(self) -> None:
1414         path = Path(THIS_DIR / "data" / "include_exclude_tests")
1415         include = re.compile(r"\.pyi?$")
1416         extend_exclude = re.compile(r"/exclude/")
1417         src = str(path / "b/")
1418         report = black.Report()
1419         expected: List[Path] = [
1420             path / "b/.definitely_exclude/a.py",
1421             path / "b/.definitely_exclude/a.pyi",
1422         ]
1423         sources = list(
1424             black.get_sources(
1425                 ctx=FakeContext(),
1426                 src=(src,),
1427                 quiet=True,
1428                 verbose=False,
1429                 include=include,
1430                 exclude=None,
1431                 extend_exclude=extend_exclude,
1432                 force_exclude=None,
1433                 report=report,
1434                 stdin_filename=None,
1435             )
1436         )
1437         self.assertEqual(sorted(expected), sorted(sources))
1438
1439     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1440     def test_exclude_for_issue_1572(self) -> None:
1441         # Exclude shouldn't touch files that were explicitly given to Black through the
1442         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
1443         # https://github.com/psf/black/issues/1572
1444         path = THIS_DIR / "data" / "include_exclude_tests"
1445         include = ""
1446         exclude = r"/exclude/|a\.py"
1447         src = str(path / "b/exclude/a.py")
1448         report = black.Report()
1449         expected = [Path(path / "b/exclude/a.py")]
1450         sources = list(
1451             black.get_sources(
1452                 ctx=FakeContext(),
1453                 src=(src,),
1454                 quiet=True,
1455                 verbose=False,
1456                 include=re.compile(include),
1457                 exclude=re.compile(exclude),
1458                 extend_exclude=None,
1459                 force_exclude=None,
1460                 report=report,
1461                 stdin_filename=None,
1462             )
1463         )
1464         self.assertEqual(sorted(expected), sorted(sources))
1465
1466     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1467     def test_get_sources_with_stdin(self) -> None:
1468         include = ""
1469         exclude = r"/exclude/|a\.py"
1470         src = "-"
1471         report = black.Report()
1472         expected = [Path("-")]
1473         sources = list(
1474             black.get_sources(
1475                 ctx=FakeContext(),
1476                 src=(src,),
1477                 quiet=True,
1478                 verbose=False,
1479                 include=re.compile(include),
1480                 exclude=re.compile(exclude),
1481                 extend_exclude=None,
1482                 force_exclude=None,
1483                 report=report,
1484                 stdin_filename=None,
1485             )
1486         )
1487         self.assertEqual(sorted(expected), sorted(sources))
1488
1489     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1490     def test_get_sources_with_stdin_filename(self) -> None:
1491         include = ""
1492         exclude = r"/exclude/|a\.py"
1493         src = "-"
1494         report = black.Report()
1495         stdin_filename = str(THIS_DIR / "data/collections.py")
1496         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1497         sources = list(
1498             black.get_sources(
1499                 ctx=FakeContext(),
1500                 src=(src,),
1501                 quiet=True,
1502                 verbose=False,
1503                 include=re.compile(include),
1504                 exclude=re.compile(exclude),
1505                 extend_exclude=None,
1506                 force_exclude=None,
1507                 report=report,
1508                 stdin_filename=stdin_filename,
1509             )
1510         )
1511         self.assertEqual(sorted(expected), sorted(sources))
1512
1513     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1514     def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
1515         # Exclude shouldn't exclude stdin_filename since it is mimicking the
1516         # file being passed directly. This is the same as
1517         # test_exclude_for_issue_1572
1518         path = THIS_DIR / "data" / "include_exclude_tests"
1519         include = ""
1520         exclude = r"/exclude/|a\.py"
1521         src = "-"
1522         report = black.Report()
1523         stdin_filename = str(path / "b/exclude/a.py")
1524         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1525         sources = list(
1526             black.get_sources(
1527                 ctx=FakeContext(),
1528                 src=(src,),
1529                 quiet=True,
1530                 verbose=False,
1531                 include=re.compile(include),
1532                 exclude=re.compile(exclude),
1533                 extend_exclude=None,
1534                 force_exclude=None,
1535                 report=report,
1536                 stdin_filename=stdin_filename,
1537             )
1538         )
1539         self.assertEqual(sorted(expected), sorted(sources))
1540
1541     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1542     def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
1543         # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
1544         # file being passed directly. This is the same as
1545         # test_exclude_for_issue_1572
1546         path = THIS_DIR / "data" / "include_exclude_tests"
1547         include = ""
1548         extend_exclude = r"/exclude/|a\.py"
1549         src = "-"
1550         report = black.Report()
1551         stdin_filename = str(path / "b/exclude/a.py")
1552         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1553         sources = list(
1554             black.get_sources(
1555                 ctx=FakeContext(),
1556                 src=(src,),
1557                 quiet=True,
1558                 verbose=False,
1559                 include=re.compile(include),
1560                 exclude=re.compile(""),
1561                 extend_exclude=re.compile(extend_exclude),
1562                 force_exclude=None,
1563                 report=report,
1564                 stdin_filename=stdin_filename,
1565             )
1566         )
1567         self.assertEqual(sorted(expected), sorted(sources))
1568
1569     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1570     def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
1571         # Force exclude should exclude the file when passing it through
1572         # stdin_filename
1573         path = THIS_DIR / "data" / "include_exclude_tests"
1574         include = ""
1575         force_exclude = r"/exclude/|a\.py"
1576         src = "-"
1577         report = black.Report()
1578         stdin_filename = str(path / "b/exclude/a.py")
1579         sources = list(
1580             black.get_sources(
1581                 ctx=FakeContext(),
1582                 src=(src,),
1583                 quiet=True,
1584                 verbose=False,
1585                 include=re.compile(include),
1586                 exclude=re.compile(""),
1587                 extend_exclude=None,
1588                 force_exclude=re.compile(force_exclude),
1589                 report=report,
1590                 stdin_filename=stdin_filename,
1591             )
1592         )
1593         self.assertEqual([], sorted(sources))
1594
1595     def test_reformat_one_with_stdin(self) -> None:
1596         with patch(
1597             "black.format_stdin_to_stdout",
1598             return_value=lambda *args, **kwargs: black.Changed.YES,
1599         ) as fsts:
1600             report = MagicMock()
1601             path = Path("-")
1602             black.reformat_one(
1603                 path,
1604                 fast=True,
1605                 write_back=black.WriteBack.YES,
1606                 mode=DEFAULT_MODE,
1607                 report=report,
1608             )
1609             fsts.assert_called_once()
1610             report.done.assert_called_with(path, black.Changed.YES)
1611
1612     def test_reformat_one_with_stdin_filename(self) -> None:
1613         with patch(
1614             "black.format_stdin_to_stdout",
1615             return_value=lambda *args, **kwargs: black.Changed.YES,
1616         ) as fsts:
1617             report = MagicMock()
1618             p = "foo.py"
1619             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1620             expected = Path(p)
1621             black.reformat_one(
1622                 path,
1623                 fast=True,
1624                 write_back=black.WriteBack.YES,
1625                 mode=DEFAULT_MODE,
1626                 report=report,
1627             )
1628             fsts.assert_called_once_with(
1629                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1630             )
1631             # __BLACK_STDIN_FILENAME__ should have been stripped
1632             report.done.assert_called_with(expected, black.Changed.YES)
1633
1634     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1635         with patch(
1636             "black.format_stdin_to_stdout",
1637             return_value=lambda *args, **kwargs: black.Changed.YES,
1638         ) as fsts:
1639             report = MagicMock()
1640             p = "foo.pyi"
1641             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1642             expected = Path(p)
1643             black.reformat_one(
1644                 path,
1645                 fast=True,
1646                 write_back=black.WriteBack.YES,
1647                 mode=DEFAULT_MODE,
1648                 report=report,
1649             )
1650             fsts.assert_called_once_with(
1651                 fast=True,
1652                 write_back=black.WriteBack.YES,
1653                 mode=replace(DEFAULT_MODE, is_pyi=True),
1654             )
1655             # __BLACK_STDIN_FILENAME__ should have been stripped
1656             report.done.assert_called_with(expected, black.Changed.YES)
1657
1658     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1659         with patch(
1660             "black.format_stdin_to_stdout",
1661             return_value=lambda *args, **kwargs: black.Changed.YES,
1662         ) as fsts:
1663             report = MagicMock()
1664             # Even with an existing file, since we are forcing stdin, black
1665             # should output to stdout and not modify the file inplace
1666             p = Path(str(THIS_DIR / "data/collections.py"))
1667             # Make sure is_file actually returns True
1668             self.assertTrue(p.is_file())
1669             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1670             expected = Path(p)
1671             black.reformat_one(
1672                 path,
1673                 fast=True,
1674                 write_back=black.WriteBack.YES,
1675                 mode=DEFAULT_MODE,
1676                 report=report,
1677             )
1678             fsts.assert_called_once()
1679             # __BLACK_STDIN_FILENAME__ should have been stripped
1680             report.done.assert_called_with(expected, black.Changed.YES)
1681
1682     def test_gitignore_exclude(self) -> None:
1683         path = THIS_DIR / "data" / "include_exclude_tests"
1684         include = re.compile(r"\.pyi?$")
1685         exclude = re.compile(r"")
1686         report = black.Report()
1687         gitignore = PathSpec.from_lines(
1688             "gitwildmatch", ["exclude/", ".definitely_exclude"]
1689         )
1690         sources: List[Path] = []
1691         expected = [
1692             Path(path / "b/dont_exclude/a.py"),
1693             Path(path / "b/dont_exclude/a.pyi"),
1694         ]
1695         this_abs = THIS_DIR.resolve()
1696         sources.extend(
1697             black.gen_python_files(
1698                 path.iterdir(),
1699                 this_abs,
1700                 include,
1701                 exclude,
1702                 None,
1703                 None,
1704                 report,
1705                 gitignore,
1706             )
1707         )
1708         self.assertEqual(sorted(expected), sorted(sources))
1709
1710     def test_nested_gitignore(self) -> None:
1711         path = Path(THIS_DIR / "data" / "nested_gitignore_tests")
1712         include = re.compile(r"\.pyi?$")
1713         exclude = re.compile(r"")
1714         root_gitignore = black.files.get_gitignore(path)
1715         report = black.Report()
1716         expected: List[Path] = [
1717             Path(path / "x.py"),
1718             Path(path / "root/b.py"),
1719             Path(path / "root/c.py"),
1720             Path(path / "root/child/c.py"),
1721         ]
1722         this_abs = THIS_DIR.resolve()
1723         sources = list(
1724             black.gen_python_files(
1725                 path.iterdir(),
1726                 this_abs,
1727                 include,
1728                 exclude,
1729                 None,
1730                 None,
1731                 report,
1732                 root_gitignore,
1733             )
1734         )
1735         self.assertEqual(sorted(expected), sorted(sources))
1736
1737     def test_empty_include(self) -> None:
1738         path = THIS_DIR / "data" / "include_exclude_tests"
1739         report = black.Report()
1740         gitignore = PathSpec.from_lines("gitwildmatch", [])
1741         empty = re.compile(r"")
1742         sources: List[Path] = []
1743         expected = [
1744             Path(path / "b/exclude/a.pie"),
1745             Path(path / "b/exclude/a.py"),
1746             Path(path / "b/exclude/a.pyi"),
1747             Path(path / "b/dont_exclude/a.pie"),
1748             Path(path / "b/dont_exclude/a.py"),
1749             Path(path / "b/dont_exclude/a.pyi"),
1750             Path(path / "b/.definitely_exclude/a.pie"),
1751             Path(path / "b/.definitely_exclude/a.py"),
1752             Path(path / "b/.definitely_exclude/a.pyi"),
1753             Path(path / ".gitignore"),
1754             Path(path / "pyproject.toml"),
1755         ]
1756         this_abs = THIS_DIR.resolve()
1757         sources.extend(
1758             black.gen_python_files(
1759                 path.iterdir(),
1760                 this_abs,
1761                 empty,
1762                 re.compile(black.DEFAULT_EXCLUDES),
1763                 None,
1764                 None,
1765                 report,
1766                 gitignore,
1767             )
1768         )
1769         self.assertEqual(sorted(expected), sorted(sources))
1770
1771     def test_extend_exclude(self) -> None:
1772         path = THIS_DIR / "data" / "include_exclude_tests"
1773         report = black.Report()
1774         gitignore = PathSpec.from_lines("gitwildmatch", [])
1775         sources: List[Path] = []
1776         expected = [
1777             Path(path / "b/exclude/a.py"),
1778             Path(path / "b/dont_exclude/a.py"),
1779         ]
1780         this_abs = THIS_DIR.resolve()
1781         sources.extend(
1782             black.gen_python_files(
1783                 path.iterdir(),
1784                 this_abs,
1785                 re.compile(black.DEFAULT_INCLUDES),
1786                 re.compile(r"\.pyi$"),
1787                 re.compile(r"\.definitely_exclude"),
1788                 None,
1789                 report,
1790                 gitignore,
1791             )
1792         )
1793         self.assertEqual(sorted(expected), sorted(sources))
1794
1795     def test_invalid_cli_regex(self) -> None:
1796         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1797             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1798
1799     def test_preserves_line_endings(self) -> None:
1800         with TemporaryDirectory() as workspace:
1801             test_file = Path(workspace) / "test.py"
1802             for nl in ["\n", "\r\n"]:
1803                 contents = nl.join(["def f(  ):", "    pass"])
1804                 test_file.write_bytes(contents.encode())
1805                 ff(test_file, write_back=black.WriteBack.YES)
1806                 updated_contents: bytes = test_file.read_bytes()
1807                 self.assertIn(nl.encode(), updated_contents)
1808                 if nl == "\n":
1809                     self.assertNotIn(b"\r\n", updated_contents)
1810
1811     def test_preserves_line_endings_via_stdin(self) -> None:
1812         for nl in ["\n", "\r\n"]:
1813             contents = nl.join(["def f(  ):", "    pass"])
1814             runner = BlackRunner()
1815             result = runner.invoke(
1816                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1817             )
1818             self.assertEqual(result.exit_code, 0)
1819             output = result.stdout_bytes
1820             self.assertIn(nl.encode("utf8"), output)
1821             if nl == "\n":
1822                 self.assertNotIn(b"\r\n", output)
1823
1824     def test_assert_equivalent_different_asts(self) -> None:
1825         with self.assertRaises(AssertionError):
1826             black.assert_equivalent("{}", "None")
1827
1828     def test_symlink_out_of_root_directory(self) -> None:
1829         path = MagicMock()
1830         root = THIS_DIR.resolve()
1831         child = MagicMock()
1832         include = re.compile(black.DEFAULT_INCLUDES)
1833         exclude = re.compile(black.DEFAULT_EXCLUDES)
1834         report = black.Report()
1835         gitignore = PathSpec.from_lines("gitwildmatch", [])
1836         # `child` should behave like a symlink which resolved path is clearly
1837         # outside of the `root` directory.
1838         path.iterdir.return_value = [child]
1839         child.resolve.return_value = Path("/a/b/c")
1840         child.as_posix.return_value = "/a/b/c"
1841         child.is_symlink.return_value = True
1842         try:
1843             list(
1844                 black.gen_python_files(
1845                     path.iterdir(),
1846                     root,
1847                     include,
1848                     exclude,
1849                     None,
1850                     None,
1851                     report,
1852                     gitignore,
1853                 )
1854             )
1855         except ValueError as ve:
1856             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1857         path.iterdir.assert_called_once()
1858         child.resolve.assert_called_once()
1859         child.is_symlink.assert_called_once()
1860         # `child` should behave like a strange file which resolved path is clearly
1861         # outside of the `root` directory.
1862         child.is_symlink.return_value = False
1863         with self.assertRaises(ValueError):
1864             list(
1865                 black.gen_python_files(
1866                     path.iterdir(),
1867                     root,
1868                     include,
1869                     exclude,
1870                     None,
1871                     None,
1872                     report,
1873                     gitignore,
1874                 )
1875             )
1876         path.iterdir.assert_called()
1877         self.assertEqual(path.iterdir.call_count, 2)
1878         child.resolve.assert_called()
1879         self.assertEqual(child.resolve.call_count, 2)
1880         child.is_symlink.assert_called()
1881         self.assertEqual(child.is_symlink.call_count, 2)
1882
1883     def test_shhh_click(self) -> None:
1884         try:
1885             from click import _unicodefun  # type: ignore
1886         except ModuleNotFoundError:
1887             self.skipTest("Incompatible Click version")
1888         if not hasattr(_unicodefun, "_verify_python3_env"):
1889             self.skipTest("Incompatible Click version")
1890         # First, let's see if Click is crashing with a preferred ASCII charset.
1891         with patch("locale.getpreferredencoding") as gpe:
1892             gpe.return_value = "ASCII"
1893             with self.assertRaises(RuntimeError):
1894                 _unicodefun._verify_python3_env()
1895         # Now, let's silence Click...
1896         black.patch_click()
1897         # ...and confirm it's silent.
1898         with patch("locale.getpreferredencoding") as gpe:
1899             gpe.return_value = "ASCII"
1900             try:
1901                 _unicodefun._verify_python3_env()
1902             except RuntimeError as re:
1903                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1904
1905     def test_root_logger_not_used_directly(self) -> None:
1906         def fail(*args: Any, **kwargs: Any) -> None:
1907             self.fail("Record created with root logger")
1908
1909         with patch.multiple(
1910             logging.root,
1911             debug=fail,
1912             info=fail,
1913             warning=fail,
1914             error=fail,
1915             critical=fail,
1916             log=fail,
1917         ):
1918             ff(THIS_DIR / "util.py")
1919
1920     def test_invalid_config_return_code(self) -> None:
1921         tmp_file = Path(black.dump_to_file())
1922         try:
1923             tmp_config = Path(black.dump_to_file())
1924             tmp_config.unlink()
1925             args = ["--config", str(tmp_config), str(tmp_file)]
1926             self.invokeBlack(args, exit_code=2, ignore_config=False)
1927         finally:
1928             tmp_file.unlink()
1929
1930     def test_parse_pyproject_toml(self) -> None:
1931         test_toml_file = THIS_DIR / "test.toml"
1932         config = black.parse_pyproject_toml(str(test_toml_file))
1933         self.assertEqual(config["verbose"], 1)
1934         self.assertEqual(config["check"], "no")
1935         self.assertEqual(config["diff"], "y")
1936         self.assertEqual(config["color"], True)
1937         self.assertEqual(config["line_length"], 79)
1938         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1939         self.assertEqual(config["exclude"], r"\.pyi?$")
1940         self.assertEqual(config["include"], r"\.py?$")
1941
1942     def test_read_pyproject_toml(self) -> None:
1943         test_toml_file = THIS_DIR / "test.toml"
1944         fake_ctx = FakeContext()
1945         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1946         config = fake_ctx.default_map
1947         self.assertEqual(config["verbose"], "1")
1948         self.assertEqual(config["check"], "no")
1949         self.assertEqual(config["diff"], "y")
1950         self.assertEqual(config["color"], "True")
1951         self.assertEqual(config["line_length"], "79")
1952         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1953         self.assertEqual(config["exclude"], r"\.pyi?$")
1954         self.assertEqual(config["include"], r"\.py?$")
1955
1956     def test_find_project_root(self) -> None:
1957         with TemporaryDirectory() as workspace:
1958             root = Path(workspace)
1959             test_dir = root / "test"
1960             test_dir.mkdir()
1961
1962             src_dir = root / "src"
1963             src_dir.mkdir()
1964
1965             root_pyproject = root / "pyproject.toml"
1966             root_pyproject.touch()
1967             src_pyproject = src_dir / "pyproject.toml"
1968             src_pyproject.touch()
1969             src_python = src_dir / "foo.py"
1970             src_python.touch()
1971
1972             self.assertEqual(
1973                 black.find_project_root((src_dir, test_dir)), root.resolve()
1974             )
1975             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
1976             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
1977
1978     @patch(
1979         "black.files.find_user_pyproject_toml",
1980         black.files.find_user_pyproject_toml.__wrapped__,
1981     )
1982     def test_find_user_pyproject_toml_linux(self) -> None:
1983         if system() == "Windows":
1984             return
1985
1986         # Test if XDG_CONFIG_HOME is checked
1987         with TemporaryDirectory() as workspace:
1988             tmp_user_config = Path(workspace) / "black"
1989             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1990                 self.assertEqual(
1991                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
1992                 )
1993
1994         # Test fallback for XDG_CONFIG_HOME
1995         with patch.dict("os.environ"):
1996             os.environ.pop("XDG_CONFIG_HOME", None)
1997             fallback_user_config = Path("~/.config").expanduser() / "black"
1998             self.assertEqual(
1999                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
2000             )
2001
2002     def test_find_user_pyproject_toml_windows(self) -> None:
2003         if system() != "Windows":
2004             return
2005
2006         user_config_path = Path.home() / ".black"
2007         self.assertEqual(
2008             black.files.find_user_pyproject_toml(), user_config_path.resolve()
2009         )
2010
2011     def test_bpo_33660_workaround(self) -> None:
2012         if system() == "Windows":
2013             return
2014
2015         # https://bugs.python.org/issue33660
2016
2017         old_cwd = Path.cwd()
2018         try:
2019             root = Path("/")
2020             os.chdir(str(root))
2021             path = Path("workspace") / "project"
2022             report = black.Report(verbose=True)
2023             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
2024             self.assertEqual(normalized_path, "workspace/project")
2025         finally:
2026             os.chdir(str(old_cwd))
2027
2028     def test_newline_comment_interaction(self) -> None:
2029         source = "class A:\\\r\n# type: ignore\n pass\n"
2030         output = black.format_str(source, mode=DEFAULT_MODE)
2031         black.assert_stable(source, output, mode=DEFAULT_MODE)
2032
2033     def test_bpo_2142_workaround(self) -> None:
2034
2035         # https://bugs.python.org/issue2142
2036
2037         source, _ = read_data("missing_final_newline.py")
2038         # read_data adds a trailing newline
2039         source = source.rstrip()
2040         expected, _ = read_data("missing_final_newline.diff")
2041         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
2042         diff_header = re.compile(
2043             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
2044             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
2045         )
2046         try:
2047             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
2048             self.assertEqual(result.exit_code, 0)
2049         finally:
2050             os.unlink(tmp_file)
2051         actual = result.output
2052         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
2053         self.assertEqual(actual, expected)
2054
2055     @pytest.mark.python2
2056     def test_docstring_reformat_for_py27(self) -> None:
2057         """
2058         Check that stripping trailing whitespace from Python 2 docstrings
2059         doesn't trigger a "not equivalent to source" error
2060         """
2061         source = (
2062             b'def foo():\r\n    """Testing\r\n    Testing """\r\n    print "Foo"\r\n'
2063         )
2064         expected = 'def foo():\n    """Testing\n    Testing"""\n    print "Foo"\n'
2065
2066         result = CliRunner().invoke(
2067             black.main,
2068             ["-", "-q", "--target-version=py27"],
2069             input=BytesIO(source),
2070         )
2071
2072         self.assertEqual(result.exit_code, 0)
2073         actual = result.output
2074         self.assertFormatEqual(actual, expected)
2075
2076     @staticmethod
2077     def compare_results(
2078         result: click.testing.Result, expected_value: str, expected_exit_code: int
2079     ) -> None:
2080         """Helper method to test the value and exit code of a click Result."""
2081         assert (
2082             result.output == expected_value
2083         ), "The output did not match the expected value."
2084         assert result.exit_code == expected_exit_code, "The exit code is incorrect."
2085
2086     def test_code_option(self) -> None:
2087         """Test the code option with no changes."""
2088         code = 'print("Hello world")\n'
2089         args = ["--code", code]
2090         result = CliRunner().invoke(black.main, args)
2091
2092         self.compare_results(result, code, 0)
2093
2094     def test_code_option_changed(self) -> None:
2095         """Test the code option when changes are required."""
2096         code = "print('hello world')"
2097         formatted = black.format_str(code, mode=DEFAULT_MODE)
2098
2099         args = ["--code", code]
2100         result = CliRunner().invoke(black.main, args)
2101
2102         self.compare_results(result, formatted, 0)
2103
2104     def test_code_option_check(self) -> None:
2105         """Test the code option when check is passed."""
2106         args = ["--check", "--code", 'print("Hello world")\n']
2107         result = CliRunner().invoke(black.main, args)
2108         self.compare_results(result, "", 0)
2109
2110     def test_code_option_check_changed(self) -> None:
2111         """Test the code option when changes are required, and check is passed."""
2112         args = ["--check", "--code", "print('hello world')"]
2113         result = CliRunner().invoke(black.main, args)
2114         self.compare_results(result, "", 1)
2115
2116     def test_code_option_diff(self) -> None:
2117         """Test the code option when diff is passed."""
2118         code = "print('hello world')"
2119         formatted = black.format_str(code, mode=DEFAULT_MODE)
2120         result_diff = diff(code, formatted, "STDIN", "STDOUT")
2121
2122         args = ["--diff", "--code", code]
2123         result = CliRunner().invoke(black.main, args)
2124
2125         # Remove time from diff
2126         output = DIFF_TIME.sub("", result.output)
2127
2128         assert output == result_diff, "The output did not match the expected value."
2129         assert result.exit_code == 0, "The exit code is incorrect."
2130
2131     def test_code_option_color_diff(self) -> None:
2132         """Test the code option when color and diff are passed."""
2133         code = "print('hello world')"
2134         formatted = black.format_str(code, mode=DEFAULT_MODE)
2135
2136         result_diff = diff(code, formatted, "STDIN", "STDOUT")
2137         result_diff = color_diff(result_diff)
2138
2139         args = ["--diff", "--color", "--code", code]
2140         result = CliRunner().invoke(black.main, args)
2141
2142         # Remove time from diff
2143         output = DIFF_TIME.sub("", result.output)
2144
2145         assert output == result_diff, "The output did not match the expected value."
2146         assert result.exit_code == 0, "The exit code is incorrect."
2147
2148     def test_code_option_safe(self) -> None:
2149         """Test that the code option throws an error when the sanity checks fail."""
2150         # Patch black.assert_equivalent to ensure the sanity checks fail
2151         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
2152             code = 'print("Hello world")'
2153             error_msg = f"{code}\nerror: cannot format <string>: \n"
2154
2155             args = ["--safe", "--code", code]
2156             result = CliRunner().invoke(black.main, args)
2157
2158             self.compare_results(result, error_msg, 123)
2159
2160     def test_code_option_fast(self) -> None:
2161         """Test that the code option ignores errors when the sanity checks fail."""
2162         # Patch black.assert_equivalent to ensure the sanity checks fail
2163         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
2164             code = 'print("Hello world")'
2165             formatted = black.format_str(code, mode=DEFAULT_MODE)
2166
2167             args = ["--fast", "--code", code]
2168             result = CliRunner().invoke(black.main, args)
2169
2170             self.compare_results(result, formatted, 0)
2171
2172     def test_code_option_config(self) -> None:
2173         """
2174         Test that the code option finds the pyproject.toml in the current directory.
2175         """
2176         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
2177             # Make sure we are in the project root with the pyproject file
2178             if not Path("tests").exists():
2179                 os.chdir("..")
2180
2181             args = ["--code", "print"]
2182             CliRunner().invoke(black.main, args)
2183
2184             pyproject_path = Path(Path().cwd(), "pyproject.toml").resolve()
2185             assert (
2186                 len(parse.mock_calls) >= 1
2187             ), "Expected config parse to be called with the current directory."
2188
2189             _, call_args, _ = parse.mock_calls[0]
2190             assert (
2191                 call_args[0].lower() == str(pyproject_path).lower()
2192             ), "Incorrect config loaded."
2193
2194     def test_code_option_parent_config(self) -> None:
2195         """
2196         Test that the code option finds the pyproject.toml in the parent directory.
2197         """
2198         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
2199             # Make sure we are in the tests directory
2200             if Path("tests").exists():
2201                 os.chdir("tests")
2202
2203             args = ["--code", "print"]
2204             CliRunner().invoke(black.main, args)
2205
2206             pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
2207             assert (
2208                 len(parse.mock_calls) >= 1
2209             ), "Expected config parse to be called with the current directory."
2210
2211             _, call_args, _ = parse.mock_calls[0]
2212             assert (
2213                 call_args[0].lower() == str(pyproject_path).lower()
2214             ), "Incorrect config loaded."
2215
2216
with open(black.__file__, encoding="utf-8") as _bf:
    # Keep black's own source lines (newlines included) so tracefunc can look
    # up the text of the line a traced call starts on.
    black_source_lines = list(_bf)
2219
2220
def tracefunc(frame: types.FrameType, event: str, arg: Any) -> Callable:
    """Show function calls `from black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    For every "call" event in a code object whose filename contains
    ``black/__init__.py``, print ``<lineno>:<source line>``, indented according
    to the current stack depth. Events from any other file are ignored. The
    function always returns itself so tracing continues into the new frame.
    """
    if event != "call":
        return tracefunc

    # Check the file before anything else: the lookups below index into
    # black's own source lines, so they are only meaningful (and only in range)
    # for frames from black/__init__.py. Separators are normalised so the match
    # also works with Windows paths.
    filename = frame.f_code.co_filename.replace("\\", "/")
    if "black/__init__.py" not in filename:
        return tracefunc

    # Two spaces of indentation per stack level. The offset of 19 presumably
    # discounts test-runner frames below the code of interest.
    # NOTE(review): magic number; adjust if the indentation looks off.
    # context=0 skips loading source context for each frame, since only the
    # frame count is used here.
    stack = (len(inspect.stack(0)) - 19) * 2
    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1
    funcname = black_source_lines[func_sig_lineno].strip()
    # Step past decorator lines so the signature line is what gets printed.
    while funcname.startswith("@") and func_sig_lineno + 1 < len(black_source_lines):
        func_sig_lineno += 1
        funcname = black_source_lines[func_sig_lineno].strip()
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc
2241
2242
if __name__ == "__main__":
    # `unittest.main` is an alias of `unittest.TestProgram`; constructing it
    # parses the command line and runs the tests from the named module.
    unittest.TestProgram(module="test_black")