git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

blib2to3: support unparenthesized walruses in more places (#2447)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import multiprocessing
3 import asyncio
4 import logging
5 from concurrent.futures import ThreadPoolExecutor
6 from contextlib import contextmanager
7 from dataclasses import replace
8 import inspect
9 import io
10 from io import BytesIO
11 import os
12 from pathlib import Path
13 from platform import system
14 import regex as re
15 import sys
16 from tempfile import TemporaryDirectory
17 import types
18 from typing import (
19     Any,
20     Callable,
21     Dict,
22     List,
23     Iterator,
24     TypeVar,
25 )
26 import pytest
27 import unittest
28 from unittest.mock import patch, MagicMock
29 from parameterized import parameterized
30
31 import click
32 from click import unstyle
33 from click.testing import CliRunner
34
35 import black
36 from black import Feature, TargetVersion
37 from black.cache import get_cache_file
38 from black.debug import DebugVisitor
39 from black.output import diff, color_diff
40 from black.report import Report
41 import black.files
42
43 from pathspec import PathSpec
44
45 # Import other test classes
46 from tests.util import (
47     THIS_DIR,
48     change_directory,
49     read_data,
50     DETERMINISTIC_HEADER,
51     BlackBaseTestCase,
52     DEFAULT_MODE,
53     fs,
54     ff,
55     dump_to_stderr,
56 )
57
58
THIS_FILE = Path(__file__)
# The Python 3.6+ target versions exercised by these tests.
PY36_VERSIONS = {
    TargetVersion.PY36,
    TargetVersion.PY37,
    TargetVersion.PY38,
    TargetVersion.PY39,
}
# One ``--target-version=<name>`` CLI flag per version above (set iteration order).
PY36_ARGS = [
    "--target-version=" + target.name.lower() for target in PY36_VERSIONS
]
T = TypeVar("T")
R = TypeVar("R")
69
# Match the time output in a diff, but nothing else.
# The hyphen inside the character class is escaped so the set is unambiguously
# {digits, "-", ":", "+", ".", " "}. An unescaped "\d-:" reads as a range that
# starts at a class escape, which the stdlib ``re`` rejects as a bad range.
DIFF_TIME = re.compile(r"\t[\d\-:+. ]+")
72
73
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch ``black.cache.CACHE_DIR`` to a throwaway directory for the block.

    With ``exists=True`` the yielded path is the temporary directory itself.
    With ``exists=False`` it is a ``new`` subdirectory that has not been created.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
82
83
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a freshly created asyncio event loop as current; close it on exit."""
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        fresh_loop.close()
94
95
class FakeContext(click.Context):
    """Stand-in click Context for calling functions that require one.

    ``click.Context.__init__`` is deliberately not called; the only state
    provided is an empty ``default_map``.
    """

    def __init__(self) -> None:
        self.default_map: Dict[str, Any] = dict()
101
102
class FakeParameter(click.Parameter):
    """Stand-in click Parameter for calling functions that require one.

    Construction is a no-op: ``click.Parameter.__init__`` is intentionally skipped.
    """

    def __init__(self) -> None:
        """Skip the base-class initializer entirely."""
108
109
class BlackRunner(CliRunner):
    """CliRunner that keeps STDOUT and STDERR apart when testing Black's CLI."""

    def __init__(self) -> None:
        # mix_stderr=False asks the base runner not to merge stderr into stdout.
        super(BlackRunner, self).__init__(mix_stderr=False)
115
116
117 class BlackTestCase(BlackBaseTestCase):
118     def invokeBlack(
119         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
120     ) -> None:
121         runner = BlackRunner()
122         if ignore_config:
123             args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
124         result = runner.invoke(black.main, args)
125         assert result.stdout_bytes is not None
126         assert result.stderr_bytes is not None
127         self.assertEqual(
128             result.exit_code,
129             exit_code,
130             msg=(
131                 f"Failed with args: {args}\n"
132                 f"stdout: {result.stdout_bytes.decode()!r}\n"
133                 f"stderr: {result.stderr_bytes.decode()!r}\n"
134                 f"exception: {result.exception}"
135             ),
136         )
137
138     @patch("black.dump_to_file", dump_to_stderr)
139     def test_empty(self) -> None:
140         source = expected = ""
141         actual = fs(source)
142         self.assertFormatEqual(expected, actual)
143         black.assert_equivalent(source, actual)
144         black.assert_stable(source, actual, DEFAULT_MODE)
145
146     def test_empty_ff(self) -> None:
147         expected = ""
148         tmp_file = Path(black.dump_to_file())
149         try:
150             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
151             with open(tmp_file, encoding="utf8") as f:
152                 actual = f.read()
153         finally:
154             os.unlink(tmp_file)
155         self.assertFormatEqual(expected, actual)
156
157     def test_piping(self) -> None:
158         source, expected = read_data("src/black/__init__", data=False)
159         result = BlackRunner().invoke(
160             black.main,
161             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
162             input=BytesIO(source.encode("utf8")),
163         )
164         self.assertEqual(result.exit_code, 0)
165         self.assertFormatEqual(expected, result.output)
166         if source != result.output:
167             black.assert_equivalent(source, result.output)
168             black.assert_stable(source, result.output, DEFAULT_MODE)
169
170     def test_piping_diff(self) -> None:
171         diff_header = re.compile(
172             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
173             r"\+\d\d\d\d"
174         )
175         source, _ = read_data("expression.py")
176         expected, _ = read_data("expression.diff")
177         config = THIS_DIR / "data" / "empty_pyproject.toml"
178         args = [
179             "-",
180             "--fast",
181             f"--line-length={black.DEFAULT_LINE_LENGTH}",
182             "--diff",
183             f"--config={config}",
184         ]
185         result = BlackRunner().invoke(
186             black.main, args, input=BytesIO(source.encode("utf8"))
187         )
188         self.assertEqual(result.exit_code, 0)
189         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
190         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
191         self.assertEqual(expected, actual)
192
193     def test_piping_diff_with_color(self) -> None:
194         source, _ = read_data("expression.py")
195         config = THIS_DIR / "data" / "empty_pyproject.toml"
196         args = [
197             "-",
198             "--fast",
199             f"--line-length={black.DEFAULT_LINE_LENGTH}",
200             "--diff",
201             "--color",
202             f"--config={config}",
203         ]
204         result = BlackRunner().invoke(
205             black.main, args, input=BytesIO(source.encode("utf8"))
206         )
207         actual = result.output
208         # Again, the contents are checked in a different test, so only look for colors.
209         self.assertIn("\033[1;37m", actual)
210         self.assertIn("\033[36m", actual)
211         self.assertIn("\033[32m", actual)
212         self.assertIn("\033[31m", actual)
213         self.assertIn("\033[0m", actual)
214
215     @patch("black.dump_to_file", dump_to_stderr)
216     def _test_wip(self) -> None:
217         source, expected = read_data("wip")
218         sys.settrace(tracefunc)
219         mode = replace(
220             DEFAULT_MODE,
221             experimental_string_processing=False,
222             target_versions={black.TargetVersion.PY38},
223         )
224         actual = fs(source, mode=mode)
225         sys.settrace(None)
226         self.assertFormatEqual(expected, actual)
227         black.assert_equivalent(source, actual)
228         black.assert_stable(source, actual, black.FileMode())
229
230     @unittest.expectedFailure
231     @patch("black.dump_to_file", dump_to_stderr)
232     def test_trailing_comma_optional_parens_stability1(self) -> None:
233         source, _expected = read_data("trailing_comma_optional_parens1")
234         actual = fs(source)
235         black.assert_stable(source, actual, DEFAULT_MODE)
236
237     @unittest.expectedFailure
238     @patch("black.dump_to_file", dump_to_stderr)
239     def test_trailing_comma_optional_parens_stability2(self) -> None:
240         source, _expected = read_data("trailing_comma_optional_parens2")
241         actual = fs(source)
242         black.assert_stable(source, actual, DEFAULT_MODE)
243
244     @unittest.expectedFailure
245     @patch("black.dump_to_file", dump_to_stderr)
246     def test_trailing_comma_optional_parens_stability3(self) -> None:
247         source, _expected = read_data("trailing_comma_optional_parens3")
248         actual = fs(source)
249         black.assert_stable(source, actual, DEFAULT_MODE)
250
251     @patch("black.dump_to_file", dump_to_stderr)
252     def test_trailing_comma_optional_parens_stability1_pass2(self) -> None:
253         source, _expected = read_data("trailing_comma_optional_parens1")
254         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
255         black.assert_stable(source, actual, DEFAULT_MODE)
256
257     @patch("black.dump_to_file", dump_to_stderr)
258     def test_trailing_comma_optional_parens_stability2_pass2(self) -> None:
259         source, _expected = read_data("trailing_comma_optional_parens2")
260         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
261         black.assert_stable(source, actual, DEFAULT_MODE)
262
263     @patch("black.dump_to_file", dump_to_stderr)
264     def test_trailing_comma_optional_parens_stability3_pass2(self) -> None:
265         source, _expected = read_data("trailing_comma_optional_parens3")
266         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
267         black.assert_stable(source, actual, DEFAULT_MODE)
268
269     @patch("black.dump_to_file", dump_to_stderr)
270     def test_pep_572(self) -> None:
271         source, expected = read_data("pep_572")
272         actual = fs(source)
273         self.assertFormatEqual(expected, actual)
274         black.assert_stable(source, actual, DEFAULT_MODE)
275         if sys.version_info >= (3, 8):
276             black.assert_equivalent(source, actual)
277
278     @patch("black.dump_to_file", dump_to_stderr)
279     def test_pep_572_remove_parens(self) -> None:
280         source, expected = read_data("pep_572_remove_parens")
281         actual = fs(source)
282         self.assertFormatEqual(expected, actual)
283         black.assert_stable(source, actual, DEFAULT_MODE)
284         if sys.version_info >= (3, 8):
285             black.assert_equivalent(source, actual)
286
287     @patch("black.dump_to_file", dump_to_stderr)
288     def test_pep_572_do_not_remove_parens(self) -> None:
289         source, expected = read_data("pep_572_do_not_remove_parens")
290         # the AST safety checks will fail, but that's expected, just make sure no
291         # parentheses are touched
292         actual = black.format_str(source, mode=DEFAULT_MODE)
293         self.assertFormatEqual(expected, actual)
294
295     def test_pep_572_version_detection(self) -> None:
296         source, _ = read_data("pep_572")
297         root = black.lib2to3_parse(source)
298         features = black.get_features_used(root)
299         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
300         versions = black.detect_target_versions(root)
301         self.assertIn(black.TargetVersion.PY38, versions)
302
303     @parameterized.expand([(3, 9), (3, 10)])
304     def test_pep_572_newer_syntax(self, major: int, minor: int) -> None:
305         source, expected = read_data(f"pep_572_py{major}{minor}")
306         actual = fs(source, mode=DEFAULT_MODE)
307         self.assertFormatEqual(expected, actual)
308         if sys.version_info >= (major, minor):
309             black.assert_equivalent(source, actual)
310
311     def test_expression_ff(self) -> None:
312         source, expected = read_data("expression")
313         tmp_file = Path(black.dump_to_file(source))
314         try:
315             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
316             with open(tmp_file, encoding="utf8") as f:
317                 actual = f.read()
318         finally:
319             os.unlink(tmp_file)
320         self.assertFormatEqual(expected, actual)
321         with patch("black.dump_to_file", dump_to_stderr):
322             black.assert_equivalent(source, actual)
323             black.assert_stable(source, actual, DEFAULT_MODE)
324
325     def test_expression_diff(self) -> None:
326         source, _ = read_data("expression.py")
327         config = THIS_DIR / "data" / "empty_pyproject.toml"
328         expected, _ = read_data("expression.diff")
329         tmp_file = Path(black.dump_to_file(source))
330         diff_header = re.compile(
331             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
332             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
333         )
334         try:
335             result = BlackRunner().invoke(
336                 black.main, ["--diff", str(tmp_file), f"--config={config}"]
337             )
338             self.assertEqual(result.exit_code, 0)
339         finally:
340             os.unlink(tmp_file)
341         actual = result.output
342         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
343         if expected != actual:
344             dump = black.dump_to_file(actual)
345             msg = (
346                 "Expected diff isn't equal to the actual. If you made changes to"
347                 " expression.py and this is an anticipated difference, overwrite"
348                 f" tests/data/expression.diff with {dump}"
349             )
350             self.assertEqual(expected, actual, msg)
351
352     def test_expression_diff_with_color(self) -> None:
353         source, _ = read_data("expression.py")
354         config = THIS_DIR / "data" / "empty_pyproject.toml"
355         expected, _ = read_data("expression.diff")
356         tmp_file = Path(black.dump_to_file(source))
357         try:
358             result = BlackRunner().invoke(
359                 black.main, ["--diff", "--color", str(tmp_file), f"--config={config}"]
360             )
361         finally:
362             os.unlink(tmp_file)
363         actual = result.output
364         # We check the contents of the diff in `test_expression_diff`. All
365         # we need to check here is that color codes exist in the result.
366         self.assertIn("\033[1;37m", actual)
367         self.assertIn("\033[36m", actual)
368         self.assertIn("\033[32m", actual)
369         self.assertIn("\033[31m", actual)
370         self.assertIn("\033[0m", actual)
371
372     @patch("black.dump_to_file", dump_to_stderr)
373     def test_pep_570(self) -> None:
374         source, expected = read_data("pep_570")
375         actual = fs(source)
376         self.assertFormatEqual(expected, actual)
377         black.assert_stable(source, actual, DEFAULT_MODE)
378         if sys.version_info >= (3, 8):
379             black.assert_equivalent(source, actual)
380
381     def test_detect_pos_only_arguments(self) -> None:
382         source, _ = read_data("pep_570")
383         root = black.lib2to3_parse(source)
384         features = black.get_features_used(root)
385         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
386         versions = black.detect_target_versions(root)
387         self.assertIn(black.TargetVersion.PY38, versions)
388
389     @patch("black.dump_to_file", dump_to_stderr)
390     def test_string_quotes(self) -> None:
391         source, expected = read_data("string_quotes")
392         mode = black.Mode(experimental_string_processing=True)
393         actual = fs(source, mode=mode)
394         self.assertFormatEqual(expected, actual)
395         black.assert_equivalent(source, actual)
396         black.assert_stable(source, actual, mode)
397         mode = replace(mode, string_normalization=False)
398         not_normalized = fs(source, mode=mode)
399         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
400         black.assert_equivalent(source, not_normalized)
401         black.assert_stable(source, not_normalized, mode=mode)
402
403     @patch("black.dump_to_file", dump_to_stderr)
404     def test_docstring_no_string_normalization(self) -> None:
405         """Like test_docstring but with string normalization off."""
406         source, expected = read_data("docstring_no_string_normalization")
407         mode = replace(DEFAULT_MODE, string_normalization=False)
408         actual = fs(source, mode=mode)
409         self.assertFormatEqual(expected, actual)
410         black.assert_equivalent(source, actual)
411         black.assert_stable(source, actual, mode)
412
413     def test_long_strings_flag_disabled(self) -> None:
414         """Tests for turning off the string processing logic."""
415         source, expected = read_data("long_strings_flag_disabled")
416         mode = replace(DEFAULT_MODE, experimental_string_processing=False)
417         actual = fs(source, mode=mode)
418         self.assertFormatEqual(expected, actual)
419         black.assert_stable(expected, actual, mode)
420
421     @patch("black.dump_to_file", dump_to_stderr)
422     def test_numeric_literals(self) -> None:
423         source, expected = read_data("numeric_literals")
424         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
425         actual = fs(source, mode=mode)
426         self.assertFormatEqual(expected, actual)
427         black.assert_equivalent(source, actual)
428         black.assert_stable(source, actual, mode)
429
430     @patch("black.dump_to_file", dump_to_stderr)
431     def test_numeric_literals_ignoring_underscores(self) -> None:
432         source, expected = read_data("numeric_literals_skip_underscores")
433         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
434         actual = fs(source, mode=mode)
435         self.assertFormatEqual(expected, actual)
436         black.assert_equivalent(source, actual)
437         black.assert_stable(source, actual, mode)
438
439     def test_skip_magic_trailing_comma(self) -> None:
440         source, _ = read_data("expression.py")
441         expected, _ = read_data("expression_skip_magic_trailing_comma.diff")
442         tmp_file = Path(black.dump_to_file(source))
443         diff_header = re.compile(
444             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
445             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
446         )
447         try:
448             result = BlackRunner().invoke(black.main, ["-C", "--diff", str(tmp_file)])
449             self.assertEqual(result.exit_code, 0)
450         finally:
451             os.unlink(tmp_file)
452         actual = result.output
453         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
454         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
455         if expected != actual:
456             dump = black.dump_to_file(actual)
457             msg = (
458                 "Expected diff isn't equal to the actual. If you made changes to"
459                 " expression.py and this is an anticipated difference, overwrite"
460                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
461             )
462             self.assertEqual(expected, actual, msg)
463
464     @pytest.mark.python2
465     @patch("black.dump_to_file", dump_to_stderr)
466     def test_python2_print_function(self) -> None:
467         source, expected = read_data("python2_print_function")
468         mode = replace(DEFAULT_MODE, target_versions={TargetVersion.PY27})
469         actual = fs(source, mode=mode)
470         self.assertFormatEqual(expected, actual)
471         black.assert_equivalent(source, actual)
472         black.assert_stable(source, actual, mode)
473
474     @patch("black.dump_to_file", dump_to_stderr)
475     def test_stub(self) -> None:
476         mode = replace(DEFAULT_MODE, is_pyi=True)
477         source, expected = read_data("stub.pyi")
478         actual = fs(source, mode=mode)
479         self.assertFormatEqual(expected, actual)
480         black.assert_stable(source, actual, mode)
481
482     @patch("black.dump_to_file", dump_to_stderr)
483     def test_async_as_identifier(self) -> None:
484         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
485         source, expected = read_data("async_as_identifier")
486         actual = fs(source)
487         self.assertFormatEqual(expected, actual)
488         major, minor = sys.version_info[:2]
489         if major < 3 or (major <= 3 and minor < 7):
490             black.assert_equivalent(source, actual)
491         black.assert_stable(source, actual, DEFAULT_MODE)
492         # ensure black can parse this when the target is 3.6
493         self.invokeBlack([str(source_path), "--target-version", "py36"])
494         # but not on 3.7, because async/await is no longer an identifier
495         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
496
497     @patch("black.dump_to_file", dump_to_stderr)
498     def test_python37(self) -> None:
499         source_path = (THIS_DIR / "data" / "python37.py").resolve()
500         source, expected = read_data("python37")
501         actual = fs(source)
502         self.assertFormatEqual(expected, actual)
503         major, minor = sys.version_info[:2]
504         if major > 3 or (major == 3 and minor >= 7):
505             black.assert_equivalent(source, actual)
506         black.assert_stable(source, actual, DEFAULT_MODE)
507         # ensure black can parse this when the target is 3.7
508         self.invokeBlack([str(source_path), "--target-version", "py37"])
509         # but not on 3.6, because we use async as a reserved keyword
510         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
511
512     @patch("black.dump_to_file", dump_to_stderr)
513     def test_python38(self) -> None:
514         source, expected = read_data("python38")
515         actual = fs(source)
516         self.assertFormatEqual(expected, actual)
517         major, minor = sys.version_info[:2]
518         if major > 3 or (major == 3 and minor >= 8):
519             black.assert_equivalent(source, actual)
520         black.assert_stable(source, actual, DEFAULT_MODE)
521
522     @patch("black.dump_to_file", dump_to_stderr)
523     def test_python39(self) -> None:
524         source, expected = read_data("python39")
525         actual = fs(source)
526         self.assertFormatEqual(expected, actual)
527         major, minor = sys.version_info[:2]
528         if major > 3 or (major == 3 and minor >= 9):
529             black.assert_equivalent(source, actual)
530         black.assert_stable(source, actual, DEFAULT_MODE)
531
532     def test_tab_comment_indentation(self) -> None:
533         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
534         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
535         self.assertFormatEqual(contents_spc, fs(contents_spc))
536         self.assertFormatEqual(contents_spc, fs(contents_tab))
537
538         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
539         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
540         self.assertFormatEqual(contents_spc, fs(contents_spc))
541         self.assertFormatEqual(contents_spc, fs(contents_tab))
542
543         # mixed tabs and spaces (valid Python 2 code)
544         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
545         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
546         self.assertFormatEqual(contents_spc, fs(contents_spc))
547         self.assertFormatEqual(contents_spc, fs(contents_tab))
548
549         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
550         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
551         self.assertFormatEqual(contents_spc, fs(contents_spc))
552         self.assertFormatEqual(contents_spc, fs(contents_tab))
553
554     def test_report_verbose(self) -> None:
555         report = Report(verbose=True)
556         out_lines = []
557         err_lines = []
558
559         def out(msg: str, **kwargs: Any) -> None:
560             out_lines.append(msg)
561
562         def err(msg: str, **kwargs: Any) -> None:
563             err_lines.append(msg)
564
565         with patch("black.output._out", out), patch("black.output._err", err):
566             report.done(Path("f1"), black.Changed.NO)
567             self.assertEqual(len(out_lines), 1)
568             self.assertEqual(len(err_lines), 0)
569             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
570             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
571             self.assertEqual(report.return_code, 0)
572             report.done(Path("f2"), black.Changed.YES)
573             self.assertEqual(len(out_lines), 2)
574             self.assertEqual(len(err_lines), 0)
575             self.assertEqual(out_lines[-1], "reformatted f2")
576             self.assertEqual(
577                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
578             )
579             report.done(Path("f3"), black.Changed.CACHED)
580             self.assertEqual(len(out_lines), 3)
581             self.assertEqual(len(err_lines), 0)
582             self.assertEqual(
583                 out_lines[-1], "f3 wasn't modified on disk since last run."
584             )
585             self.assertEqual(
586                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
587             )
588             self.assertEqual(report.return_code, 0)
589             report.check = True
590             self.assertEqual(report.return_code, 1)
591             report.check = False
592             report.failed(Path("e1"), "boom")
593             self.assertEqual(len(out_lines), 3)
594             self.assertEqual(len(err_lines), 1)
595             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
596             self.assertEqual(
597                 unstyle(str(report)),
598                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
599                 " reformat.",
600             )
601             self.assertEqual(report.return_code, 123)
602             report.done(Path("f3"), black.Changed.YES)
603             self.assertEqual(len(out_lines), 4)
604             self.assertEqual(len(err_lines), 1)
605             self.assertEqual(out_lines[-1], "reformatted f3")
606             self.assertEqual(
607                 unstyle(str(report)),
608                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
609                 " reformat.",
610             )
611             self.assertEqual(report.return_code, 123)
612             report.failed(Path("e2"), "boom")
613             self.assertEqual(len(out_lines), 4)
614             self.assertEqual(len(err_lines), 2)
615             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
616             self.assertEqual(
617                 unstyle(str(report)),
618                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
619                 " reformat.",
620             )
621             self.assertEqual(report.return_code, 123)
622             report.path_ignored(Path("wat"), "no match")
623             self.assertEqual(len(out_lines), 5)
624             self.assertEqual(len(err_lines), 2)
625             self.assertEqual(out_lines[-1], "wat ignored: no match")
626             self.assertEqual(
627                 unstyle(str(report)),
628                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
629                 " reformat.",
630             )
631             self.assertEqual(report.return_code, 123)
632             report.done(Path("f4"), black.Changed.NO)
633             self.assertEqual(len(out_lines), 6)
634             self.assertEqual(len(err_lines), 2)
635             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
636             self.assertEqual(
637                 unstyle(str(report)),
638                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
639                 " reformat.",
640             )
641             self.assertEqual(report.return_code, 123)
642             report.check = True
643             self.assertEqual(
644                 unstyle(str(report)),
645                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
646                 " would fail to reformat.",
647             )
648             report.check = False
649             report.diff = True
650             self.assertEqual(
651                 unstyle(str(report)),
652                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
653                 " would fail to reformat.",
654             )
655
    def test_report_quiet(self) -> None:
        """Drive a ``Report(quiet=True)`` through a sequence of results.

        Nothing is ever written through the patched ``black.output._out``;
        only ``failed()`` writes, through the patched ``black.output._err``.
        After every step the summary string and ``return_code`` are checked,
        including how toggling ``check`` and ``diff`` changes them.
        """
        report = Report(quiet=True)
        # Messages captured from the patched black.output._out / _err.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            # Unchanged file: silent, counted, exit code 0.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # Reformatted file: still silent because quiet=True.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cache hit is summarised as "left unchanged".
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled, a reformatted file turns the exit code into 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures are written to err even in quiet mode; exit code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # "f3" is reused on purpose. The summary below shows it is counted
            # again as reformatted; the earlier CACHED count is not replaced.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths change neither the captured output nor the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
749
    def test_report_normal(self) -> None:
        """Drive a default (non-quiet) ``black.Report`` through a sequence of results.

        Each reformatted file is echoed as ``reformatted <path>`` through the
        patched ``black.output._out``. Failures are written through the
        patched ``black.output._err``. Unchanged, cached and ignored paths
        produce no output. The summary string and ``return_code`` are checked
        after every step.
        """
        report = black.Report()
        # Messages captured from the patched black.output._out / _err.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            # Unchanged file: no output, exit code 0.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # Reformatted file: announced on out.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cache hit produces no output and counts as "left unchanged".
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled, a reformatted file turns the exit code into 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure is written to err and makes the exit code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # "f3" is reused on purpose. The summary below shows it is counted
            # again as reformatted; the earlier CACHED count is not replaced.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths change neither the captured output nor the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
846
847     def test_lib2to3_parse(self) -> None:
848         with self.assertRaises(black.InvalidInput):
849             black.lib2to3_parse("invalid syntax")
850
851         straddling = "x + y"
852         black.lib2to3_parse(straddling)
853         black.lib2to3_parse(straddling, {TargetVersion.PY27})
854         black.lib2to3_parse(straddling, {TargetVersion.PY36})
855         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
856
857         py2_only = "print x"
858         black.lib2to3_parse(py2_only)
859         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
860         with self.assertRaises(black.InvalidInput):
861             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
862         with self.assertRaises(black.InvalidInput):
863             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
864
865         py3_only = "exec(x, end=y)"
866         black.lib2to3_parse(py3_only)
867         with self.assertRaises(black.InvalidInput):
868             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
869         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
870         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
871
872     def test_get_features_used_decorator(self) -> None:
873         # Test the feature detection of new decorator syntax
874         # since this makes some test cases of test_get_features_used()
875         # fails if it fails, this is tested first so that a useful case
876         # is identified
877         simples, relaxed = read_data("decorators")
878         # skip explanation comments at the top of the file
879         for simple_test in simples.split("##")[1:]:
880             node = black.lib2to3_parse(simple_test)
881             decorator = str(node.children[0].children[0]).strip()
882             self.assertNotIn(
883                 Feature.RELAXED_DECORATORS,
884                 black.get_features_used(node),
885                 msg=(
886                     f"decorator '{decorator}' follows python<=3.8 syntax"
887                     "but is detected as 3.9+"
888                     # f"The full node is\n{node!r}"
889                 ),
890             )
891         # skip the '# output' comment at the top of the output part
892         for relaxed_test in relaxed.split("##")[1:]:
893             node = black.lib2to3_parse(relaxed_test)
894             decorator = str(node.children[0].children[0]).strip()
895             self.assertIn(
896                 Feature.RELAXED_DECORATORS,
897                 black.get_features_used(node),
898                 msg=(
899                     f"decorator '{decorator}' uses python3.9+ syntax"
900                     "but is detected as python<=3.8"
901                     # f"The full node is\n{node!r}"
902                 ),
903             )
904
905     def test_get_features_used(self) -> None:
906         node = black.lib2to3_parse("def f(*, arg): ...\n")
907         self.assertEqual(black.get_features_used(node), set())
908         node = black.lib2to3_parse("def f(*, arg,): ...\n")
909         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
910         node = black.lib2to3_parse("f(*arg,)\n")
911         self.assertEqual(
912             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
913         )
914         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
915         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
916         node = black.lib2to3_parse("123_456\n")
917         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
918         node = black.lib2to3_parse("123456\n")
919         self.assertEqual(black.get_features_used(node), set())
920         source, expected = read_data("function")
921         node = black.lib2to3_parse(source)
922         expected_features = {
923             Feature.TRAILING_COMMA_IN_CALL,
924             Feature.TRAILING_COMMA_IN_DEF,
925             Feature.F_STRINGS,
926         }
927         self.assertEqual(black.get_features_used(node), expected_features)
928         node = black.lib2to3_parse(expected)
929         self.assertEqual(black.get_features_used(node), expected_features)
930         source, expected = read_data("expression")
931         node = black.lib2to3_parse(source)
932         self.assertEqual(black.get_features_used(node), set())
933         node = black.lib2to3_parse(expected)
934         self.assertEqual(black.get_features_used(node), set())
935
936     def test_get_future_imports(self) -> None:
937         node = black.lib2to3_parse("\n")
938         self.assertEqual(set(), black.get_future_imports(node))
939         node = black.lib2to3_parse("from __future__ import black\n")
940         self.assertEqual({"black"}, black.get_future_imports(node))
941         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
942         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
943         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
944         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
945         node = black.lib2to3_parse(
946             "from __future__ import multiple\nfrom __future__ import imports\n"
947         )
948         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
949         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
950         self.assertEqual({"black"}, black.get_future_imports(node))
951         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
952         self.assertEqual({"black"}, black.get_future_imports(node))
953         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
954         self.assertEqual(set(), black.get_future_imports(node))
955         node = black.lib2to3_parse("from some.module import black\n")
956         self.assertEqual(set(), black.get_future_imports(node))
957         node = black.lib2to3_parse(
958             "from __future__ import unicode_literals as _unicode_literals"
959         )
960         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
961         node = black.lib2to3_parse(
962             "from __future__ import unicode_literals as _lol, print"
963         )
964         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
965
966     def test_debug_visitor(self) -> None:
967         source, _ = read_data("debug_visitor.py")
968         expected, _ = read_data("debug_visitor.out")
969         out_lines = []
970         err_lines = []
971
972         def out(msg: str, **kwargs: Any) -> None:
973             out_lines.append(msg)
974
975         def err(msg: str, **kwargs: Any) -> None:
976             err_lines.append(msg)
977
978         with patch("black.debug.out", out):
979             DebugVisitor.show(source)
980         actual = "\n".join(out_lines) + "\n"
981         log_name = ""
982         if expected != actual:
983             log_name = black.dump_to_file(*out_lines)
984         self.assertEqual(
985             expected,
986             actual,
987             f"AST print out is different. Actual version dumped to {log_name}",
988         )
989
990     def test_format_file_contents(self) -> None:
991         empty = ""
992         mode = DEFAULT_MODE
993         with self.assertRaises(black.NothingChanged):
994             black.format_file_contents(empty, mode=mode, fast=False)
995         just_nl = "\n"
996         with self.assertRaises(black.NothingChanged):
997             black.format_file_contents(just_nl, mode=mode, fast=False)
998         same = "j = [1, 2, 3]\n"
999         with self.assertRaises(black.NothingChanged):
1000             black.format_file_contents(same, mode=mode, fast=False)
1001         different = "j = [1,2,3]"
1002         expected = same
1003         actual = black.format_file_contents(different, mode=mode, fast=False)
1004         self.assertEqual(expected, actual)
1005         invalid = "return if you can"
1006         with self.assertRaises(black.InvalidInput) as e:
1007             black.format_file_contents(invalid, mode=mode, fast=False)
1008         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1009
1010     def test_endmarker(self) -> None:
1011         n = black.lib2to3_parse("\n")
1012         self.assertEqual(n.type, black.syms.file_input)
1013         self.assertEqual(len(n.children), 1)
1014         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1015
1016     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1017     def test_assertFormatEqual(self) -> None:
1018         out_lines = []
1019         err_lines = []
1020
1021         def out(msg: str, **kwargs: Any) -> None:
1022             out_lines.append(msg)
1023
1024         def err(msg: str, **kwargs: Any) -> None:
1025             err_lines.append(msg)
1026
1027         with patch("black.output._out", out), patch("black.output._err", err):
1028             with self.assertRaises(AssertionError):
1029                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1030
1031         out_str = "".join(out_lines)
1032         self.assertTrue("Expected tree:" in out_str)
1033         self.assertTrue("Actual tree:" in out_str)
1034         self.assertEqual("".join(err_lines), "")
1035
1036     def test_cache_broken_file(self) -> None:
1037         mode = DEFAULT_MODE
1038         with cache_dir() as workspace:
1039             cache_file = get_cache_file(mode)
1040             with cache_file.open("w") as fobj:
1041                 fobj.write("this is not a pickle")
1042             self.assertEqual(black.read_cache(mode), {})
1043             src = (workspace / "test.py").resolve()
1044             with src.open("w") as fobj:
1045                 fobj.write("print('hello')")
1046             self.invokeBlack([str(src)])
1047             cache = black.read_cache(mode)
1048             self.assertIn(str(src), cache)
1049
1050     def test_cache_single_file_already_cached(self) -> None:
1051         mode = DEFAULT_MODE
1052         with cache_dir() as workspace:
1053             src = (workspace / "test.py").resolve()
1054             with src.open("w") as fobj:
1055                 fobj.write("print('hello')")
1056             black.write_cache({}, [src], mode)
1057             self.invokeBlack([str(src)])
1058             with src.open("r") as fobj:
1059                 self.assertEqual(fobj.read(), "print('hello')")
1060
1061     @event_loop()
1062     def test_cache_multiple_files(self) -> None:
1063         mode = DEFAULT_MODE
1064         with cache_dir() as workspace, patch(
1065             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1066         ):
1067             one = (workspace / "one.py").resolve()
1068             with one.open("w") as fobj:
1069                 fobj.write("print('hello')")
1070             two = (workspace / "two.py").resolve()
1071             with two.open("w") as fobj:
1072                 fobj.write("print('hello')")
1073             black.write_cache({}, [one], mode)
1074             self.invokeBlack([str(workspace)])
1075             with one.open("r") as fobj:
1076                 self.assertEqual(fobj.read(), "print('hello')")
1077             with two.open("r") as fobj:
1078                 self.assertEqual(fobj.read(), 'print("hello")\n')
1079             cache = black.read_cache(mode)
1080             self.assertIn(str(one), cache)
1081             self.assertIn(str(two), cache)
1082
1083     def test_no_cache_when_writeback_diff(self) -> None:
1084         mode = DEFAULT_MODE
1085         with cache_dir() as workspace:
1086             src = (workspace / "test.py").resolve()
1087             with src.open("w") as fobj:
1088                 fobj.write("print('hello')")
1089             with patch("black.read_cache") as read_cache, patch(
1090                 "black.write_cache"
1091             ) as write_cache:
1092                 self.invokeBlack([str(src), "--diff"])
1093                 cache_file = get_cache_file(mode)
1094                 self.assertFalse(cache_file.exists())
1095                 write_cache.assert_not_called()
1096                 read_cache.assert_not_called()
1097
1098     def test_no_cache_when_writeback_color_diff(self) -> None:
1099         mode = DEFAULT_MODE
1100         with cache_dir() as workspace:
1101             src = (workspace / "test.py").resolve()
1102             with src.open("w") as fobj:
1103                 fobj.write("print('hello')")
1104             with patch("black.read_cache") as read_cache, patch(
1105                 "black.write_cache"
1106             ) as write_cache:
1107                 self.invokeBlack([str(src), "--diff", "--color"])
1108                 cache_file = get_cache_file(mode)
1109                 self.assertFalse(cache_file.exists())
1110                 write_cache.assert_not_called()
1111                 read_cache.assert_not_called()
1112
1113     @event_loop()
1114     def test_output_locking_when_writeback_diff(self) -> None:
1115         with cache_dir() as workspace:
1116             for tag in range(0, 4):
1117                 src = (workspace / f"test{tag}.py").resolve()
1118                 with src.open("w") as fobj:
1119                     fobj.write("print('hello')")
1120             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1121                 self.invokeBlack(["--diff", str(workspace)], exit_code=0)
1122                 # this isn't quite doing what we want, but if it _isn't_
1123                 # called then we cannot be using the lock it provides
1124                 mgr.assert_called()
1125
1126     @event_loop()
1127     def test_output_locking_when_writeback_color_diff(self) -> None:
1128         with cache_dir() as workspace:
1129             for tag in range(0, 4):
1130                 src = (workspace / f"test{tag}.py").resolve()
1131                 with src.open("w") as fobj:
1132                     fobj.write("print('hello')")
1133             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1134                 self.invokeBlack(["--diff", "--color", str(workspace)], exit_code=0)
1135                 # this isn't quite doing what we want, but if it _isn't_
1136                 # called then we cannot be using the lock it provides
1137                 mgr.assert_called()
1138
1139     def test_no_cache_when_stdin(self) -> None:
1140         mode = DEFAULT_MODE
1141         with cache_dir():
1142             result = CliRunner().invoke(
1143                 black.main, ["-"], input=BytesIO(b"print('hello')")
1144             )
1145             self.assertEqual(result.exit_code, 0)
1146             cache_file = get_cache_file(mode)
1147             self.assertFalse(cache_file.exists())
1148
1149     def test_read_cache_no_cachefile(self) -> None:
1150         mode = DEFAULT_MODE
1151         with cache_dir():
1152             self.assertEqual(black.read_cache(mode), {})
1153
1154     def test_write_cache_read_cache(self) -> None:
1155         mode = DEFAULT_MODE
1156         with cache_dir() as workspace:
1157             src = (workspace / "test.py").resolve()
1158             src.touch()
1159             black.write_cache({}, [src], mode)
1160             cache = black.read_cache(mode)
1161             self.assertIn(str(src), cache)
1162             self.assertEqual(cache[str(src)], black.get_cache_info(src))
1163
1164     def test_filter_cached(self) -> None:
1165         with TemporaryDirectory() as workspace:
1166             path = Path(workspace)
1167             uncached = (path / "uncached").resolve()
1168             cached = (path / "cached").resolve()
1169             cached_but_changed = (path / "changed").resolve()
1170             uncached.touch()
1171             cached.touch()
1172             cached_but_changed.touch()
1173             cache = {
1174                 str(cached): black.get_cache_info(cached),
1175                 str(cached_but_changed): (0.0, 0),
1176             }
1177             todo, done = black.filter_cached(
1178                 cache, {uncached, cached, cached_but_changed}
1179             )
1180             self.assertEqual(todo, {uncached, cached_but_changed})
1181             self.assertEqual(done, {cached})
1182
1183     def test_write_cache_creates_directory_if_needed(self) -> None:
1184         mode = DEFAULT_MODE
1185         with cache_dir(exists=False) as workspace:
1186             self.assertFalse(workspace.exists())
1187             black.write_cache({}, [], mode)
1188             self.assertTrue(workspace.exists())
1189
1190     @event_loop()
1191     def test_failed_formatting_does_not_get_cached(self) -> None:
1192         mode = DEFAULT_MODE
1193         with cache_dir() as workspace, patch(
1194             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1195         ):
1196             failing = (workspace / "failing.py").resolve()
1197             with failing.open("w") as fobj:
1198                 fobj.write("not actually python")
1199             clean = (workspace / "clean.py").resolve()
1200             with clean.open("w") as fobj:
1201                 fobj.write('print("hello")\n')
1202             self.invokeBlack([str(workspace)], exit_code=123)
1203             cache = black.read_cache(mode)
1204             self.assertNotIn(str(failing), cache)
1205             self.assertIn(str(clean), cache)
1206
1207     def test_write_cache_write_fail(self) -> None:
1208         mode = DEFAULT_MODE
1209         with cache_dir(), patch.object(Path, "open") as mock:
1210             mock.side_effect = OSError
1211             black.write_cache({}, [], mode)
1212
1213     @event_loop()
1214     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
1215     def test_works_in_mono_process_only_environment(self) -> None:
1216         with cache_dir() as workspace:
1217             for f in [
1218                 (workspace / "one.py").resolve(),
1219                 (workspace / "two.py").resolve(),
1220             ]:
1221                 f.write_text('print("hello")\n')
1222             self.invokeBlack([str(workspace)])
1223
1224     @event_loop()
1225     def test_check_diff_use_together(self) -> None:
1226         with cache_dir():
1227             # Files which will be reformatted.
1228             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1229             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1230             # Files which will not be reformatted.
1231             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1232             self.invokeBlack([str(src2), "--diff", "--check"])
1233             # Multi file command.
1234             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1235
1236     def test_no_files(self) -> None:
1237         with cache_dir():
1238             # Without an argument, black exits with error code 0.
1239             self.invokeBlack([])
1240
1241     def test_broken_symlink(self) -> None:
1242         with cache_dir() as workspace:
1243             symlink = workspace / "broken_link.py"
1244             try:
1245                 symlink.symlink_to("nonexistent.py")
1246             except OSError as e:
1247                 self.skipTest(f"Can't create symlinks: {e}")
1248             self.invokeBlack([str(workspace.resolve())])
1249
1250     def test_read_cache_line_lengths(self) -> None:
1251         mode = DEFAULT_MODE
1252         short_mode = replace(DEFAULT_MODE, line_length=1)
1253         with cache_dir() as workspace:
1254             path = (workspace / "file.py").resolve()
1255             path.touch()
1256             black.write_cache({}, [path], mode)
1257             one = black.read_cache(mode)
1258             self.assertIn(str(path), one)
1259             two = black.read_cache(short_mode)
1260             self.assertNotIn(str(path), two)
1261
1262     def test_single_file_force_pyi(self) -> None:
1263         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1264         contents, expected = read_data("force_pyi")
1265         with cache_dir() as workspace:
1266             path = (workspace / "file.py").resolve()
1267             with open(path, "w") as fh:
1268                 fh.write(contents)
1269             self.invokeBlack([str(path), "--pyi"])
1270             with open(path, "r") as fh:
1271                 actual = fh.read()
1272             # verify cache with --pyi is separate
1273             pyi_cache = black.read_cache(pyi_mode)
1274             self.assertIn(str(path), pyi_cache)
1275             normal_cache = black.read_cache(DEFAULT_MODE)
1276             self.assertNotIn(str(path), normal_cache)
1277         self.assertFormatEqual(expected, actual)
1278         black.assert_equivalent(contents, actual)
1279         black.assert_stable(contents, actual, pyi_mode)
1280
1281     @event_loop()
1282     def test_multi_file_force_pyi(self) -> None:
1283         reg_mode = DEFAULT_MODE
1284         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1285         contents, expected = read_data("force_pyi")
1286         with cache_dir() as workspace:
1287             paths = [
1288                 (workspace / "file1.py").resolve(),
1289                 (workspace / "file2.py").resolve(),
1290             ]
1291             for path in paths:
1292                 with open(path, "w") as fh:
1293                     fh.write(contents)
1294             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1295             for path in paths:
1296                 with open(path, "r") as fh:
1297                     actual = fh.read()
1298                 self.assertEqual(actual, expected)
1299             # verify cache with --pyi is separate
1300             pyi_cache = black.read_cache(pyi_mode)
1301             normal_cache = black.read_cache(reg_mode)
1302             for path in paths:
1303                 self.assertIn(str(path), pyi_cache)
1304                 self.assertNotIn(str(path), normal_cache)
1305
1306     def test_pipe_force_pyi(self) -> None:
1307         source, expected = read_data("force_pyi")
1308         result = CliRunner().invoke(
1309             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1310         )
1311         self.assertEqual(result.exit_code, 0)
1312         actual = result.output
1313         self.assertFormatEqual(actual, expected)
1314
1315     def test_single_file_force_py36(self) -> None:
1316         reg_mode = DEFAULT_MODE
1317         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1318         source, expected = read_data("force_py36")
1319         with cache_dir() as workspace:
1320             path = (workspace / "file.py").resolve()
1321             with open(path, "w") as fh:
1322                 fh.write(source)
1323             self.invokeBlack([str(path), *PY36_ARGS])
1324             with open(path, "r") as fh:
1325                 actual = fh.read()
1326             # verify cache with --target-version is separate
1327             py36_cache = black.read_cache(py36_mode)
1328             self.assertIn(str(path), py36_cache)
1329             normal_cache = black.read_cache(reg_mode)
1330             self.assertNotIn(str(path), normal_cache)
1331         self.assertEqual(actual, expected)
1332
1333     @event_loop()
1334     def test_multi_file_force_py36(self) -> None:
1335         reg_mode = DEFAULT_MODE
1336         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1337         source, expected = read_data("force_py36")
1338         with cache_dir() as workspace:
1339             paths = [
1340                 (workspace / "file1.py").resolve(),
1341                 (workspace / "file2.py").resolve(),
1342             ]
1343             for path in paths:
1344                 with open(path, "w") as fh:
1345                     fh.write(source)
1346             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1347             for path in paths:
1348                 with open(path, "r") as fh:
1349                     actual = fh.read()
1350                 self.assertEqual(actual, expected)
1351             # verify cache with --target-version is separate
1352             pyi_cache = black.read_cache(py36_mode)
1353             normal_cache = black.read_cache(reg_mode)
1354             for path in paths:
1355                 self.assertIn(str(path), pyi_cache)
1356                 self.assertNotIn(str(path), normal_cache)
1357
1358     def test_pipe_force_py36(self) -> None:
1359         source, expected = read_data("force_py36")
1360         result = CliRunner().invoke(
1361             black.main,
1362             ["-", "-q", "--target-version=py36"],
1363             input=BytesIO(source.encode("utf8")),
1364         )
1365         self.assertEqual(result.exit_code, 0)
1366         actual = result.output
1367         self.assertFormatEqual(actual, expected)
1368
1369     def test_include_exclude(self) -> None:
1370         path = THIS_DIR / "data" / "include_exclude_tests"
1371         include = re.compile(r"\.pyi?$")
1372         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1373         report = black.Report()
1374         gitignore = PathSpec.from_lines("gitwildmatch", [])
1375         sources: List[Path] = []
1376         expected = [
1377             Path(path / "b/dont_exclude/a.py"),
1378             Path(path / "b/dont_exclude/a.pyi"),
1379         ]
1380         this_abs = THIS_DIR.resolve()
1381         sources.extend(
1382             black.gen_python_files(
1383                 path.iterdir(),
1384                 this_abs,
1385                 include,
1386                 exclude,
1387                 None,
1388                 None,
1389                 report,
1390                 gitignore,
1391                 verbose=False,
1392                 quiet=False,
1393             )
1394         )
1395         self.assertEqual(sorted(expected), sorted(sources))
1396
1397     def test_gitignore_used_as_default(self) -> None:
1398         path = Path(THIS_DIR / "data" / "include_exclude_tests")
1399         include = re.compile(r"\.pyi?$")
1400         extend_exclude = re.compile(r"/exclude/")
1401         src = str(path / "b/")
1402         report = black.Report()
1403         expected: List[Path] = [
1404             path / "b/.definitely_exclude/a.py",
1405             path / "b/.definitely_exclude/a.pyi",
1406         ]
1407         sources = list(
1408             black.get_sources(
1409                 ctx=FakeContext(),
1410                 src=(src,),
1411                 quiet=True,
1412                 verbose=False,
1413                 include=include,
1414                 exclude=None,
1415                 extend_exclude=extend_exclude,
1416                 force_exclude=None,
1417                 report=report,
1418                 stdin_filename=None,
1419             )
1420         )
1421         self.assertEqual(sorted(expected), sorted(sources))
1422
1423     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1424     def test_exclude_for_issue_1572(self) -> None:
1425         # Exclude shouldn't touch files that were explicitly given to Black through the
1426         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
1427         # https://github.com/psf/black/issues/1572
1428         path = THIS_DIR / "data" / "include_exclude_tests"
1429         include = ""
1430         exclude = r"/exclude/|a\.py"
1431         src = str(path / "b/exclude/a.py")
1432         report = black.Report()
1433         expected = [Path(path / "b/exclude/a.py")]
1434         sources = list(
1435             black.get_sources(
1436                 ctx=FakeContext(),
1437                 src=(src,),
1438                 quiet=True,
1439                 verbose=False,
1440                 include=re.compile(include),
1441                 exclude=re.compile(exclude),
1442                 extend_exclude=None,
1443                 force_exclude=None,
1444                 report=report,
1445                 stdin_filename=None,
1446             )
1447         )
1448         self.assertEqual(sorted(expected), sorted(sources))
1449
1450     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1451     def test_get_sources_with_stdin(self) -> None:
1452         include = ""
1453         exclude = r"/exclude/|a\.py"
1454         src = "-"
1455         report = black.Report()
1456         expected = [Path("-")]
1457         sources = list(
1458             black.get_sources(
1459                 ctx=FakeContext(),
1460                 src=(src,),
1461                 quiet=True,
1462                 verbose=False,
1463                 include=re.compile(include),
1464                 exclude=re.compile(exclude),
1465                 extend_exclude=None,
1466                 force_exclude=None,
1467                 report=report,
1468                 stdin_filename=None,
1469             )
1470         )
1471         self.assertEqual(sorted(expected), sorted(sources))
1472
1473     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1474     def test_get_sources_with_stdin_filename(self) -> None:
1475         include = ""
1476         exclude = r"/exclude/|a\.py"
1477         src = "-"
1478         report = black.Report()
1479         stdin_filename = str(THIS_DIR / "data/collections.py")
1480         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1481         sources = list(
1482             black.get_sources(
1483                 ctx=FakeContext(),
1484                 src=(src,),
1485                 quiet=True,
1486                 verbose=False,
1487                 include=re.compile(include),
1488                 exclude=re.compile(exclude),
1489                 extend_exclude=None,
1490                 force_exclude=None,
1491                 report=report,
1492                 stdin_filename=stdin_filename,
1493             )
1494         )
1495         self.assertEqual(sorted(expected), sorted(sources))
1496
1497     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1498     def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
1499         # Exclude shouldn't exclude stdin_filename since it is mimicking the
1500         # file being passed directly. This is the same as
1501         # test_exclude_for_issue_1572
1502         path = THIS_DIR / "data" / "include_exclude_tests"
1503         include = ""
1504         exclude = r"/exclude/|a\.py"
1505         src = "-"
1506         report = black.Report()
1507         stdin_filename = str(path / "b/exclude/a.py")
1508         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1509         sources = list(
1510             black.get_sources(
1511                 ctx=FakeContext(),
1512                 src=(src,),
1513                 quiet=True,
1514                 verbose=False,
1515                 include=re.compile(include),
1516                 exclude=re.compile(exclude),
1517                 extend_exclude=None,
1518                 force_exclude=None,
1519                 report=report,
1520                 stdin_filename=stdin_filename,
1521             )
1522         )
1523         self.assertEqual(sorted(expected), sorted(sources))
1524
1525     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1526     def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
1527         # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
1528         # file being passed directly. This is the same as
1529         # test_exclude_for_issue_1572
1530         path = THIS_DIR / "data" / "include_exclude_tests"
1531         include = ""
1532         extend_exclude = r"/exclude/|a\.py"
1533         src = "-"
1534         report = black.Report()
1535         stdin_filename = str(path / "b/exclude/a.py")
1536         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1537         sources = list(
1538             black.get_sources(
1539                 ctx=FakeContext(),
1540                 src=(src,),
1541                 quiet=True,
1542                 verbose=False,
1543                 include=re.compile(include),
1544                 exclude=re.compile(""),
1545                 extend_exclude=re.compile(extend_exclude),
1546                 force_exclude=None,
1547                 report=report,
1548                 stdin_filename=stdin_filename,
1549             )
1550         )
1551         self.assertEqual(sorted(expected), sorted(sources))
1552
1553     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1554     def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
1555         # Force exclude should exclude the file when passing it through
1556         # stdin_filename
1557         path = THIS_DIR / "data" / "include_exclude_tests"
1558         include = ""
1559         force_exclude = r"/exclude/|a\.py"
1560         src = "-"
1561         report = black.Report()
1562         stdin_filename = str(path / "b/exclude/a.py")
1563         sources = list(
1564             black.get_sources(
1565                 ctx=FakeContext(),
1566                 src=(src,),
1567                 quiet=True,
1568                 verbose=False,
1569                 include=re.compile(include),
1570                 exclude=re.compile(""),
1571                 extend_exclude=None,
1572                 force_exclude=re.compile(force_exclude),
1573                 report=report,
1574                 stdin_filename=stdin_filename,
1575             )
1576         )
1577         self.assertEqual([], sorted(sources))
1578
1579     def test_reformat_one_with_stdin(self) -> None:
1580         with patch(
1581             "black.format_stdin_to_stdout",
1582             return_value=lambda *args, **kwargs: black.Changed.YES,
1583         ) as fsts:
1584             report = MagicMock()
1585             path = Path("-")
1586             black.reformat_one(
1587                 path,
1588                 fast=True,
1589                 write_back=black.WriteBack.YES,
1590                 mode=DEFAULT_MODE,
1591                 report=report,
1592             )
1593             fsts.assert_called_once()
1594             report.done.assert_called_with(path, black.Changed.YES)
1595
1596     def test_reformat_one_with_stdin_filename(self) -> None:
1597         with patch(
1598             "black.format_stdin_to_stdout",
1599             return_value=lambda *args, **kwargs: black.Changed.YES,
1600         ) as fsts:
1601             report = MagicMock()
1602             p = "foo.py"
1603             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1604             expected = Path(p)
1605             black.reformat_one(
1606                 path,
1607                 fast=True,
1608                 write_back=black.WriteBack.YES,
1609                 mode=DEFAULT_MODE,
1610                 report=report,
1611             )
1612             fsts.assert_called_once_with(
1613                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1614             )
1615             # __BLACK_STDIN_FILENAME__ should have been stripped
1616             report.done.assert_called_with(expected, black.Changed.YES)
1617
1618     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1619         with patch(
1620             "black.format_stdin_to_stdout",
1621             return_value=lambda *args, **kwargs: black.Changed.YES,
1622         ) as fsts:
1623             report = MagicMock()
1624             p = "foo.pyi"
1625             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1626             expected = Path(p)
1627             black.reformat_one(
1628                 path,
1629                 fast=True,
1630                 write_back=black.WriteBack.YES,
1631                 mode=DEFAULT_MODE,
1632                 report=report,
1633             )
1634             fsts.assert_called_once_with(
1635                 fast=True,
1636                 write_back=black.WriteBack.YES,
1637                 mode=replace(DEFAULT_MODE, is_pyi=True),
1638             )
1639             # __BLACK_STDIN_FILENAME__ should have been stripped
1640             report.done.assert_called_with(expected, black.Changed.YES)
1641
1642     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1643         with patch(
1644             "black.format_stdin_to_stdout",
1645             return_value=lambda *args, **kwargs: black.Changed.YES,
1646         ) as fsts:
1647             report = MagicMock()
1648             # Even with an existing file, since we are forcing stdin, black
1649             # should output to stdout and not modify the file inplace
1650             p = Path(str(THIS_DIR / "data/collections.py"))
1651             # Make sure is_file actually returns True
1652             self.assertTrue(p.is_file())
1653             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1654             expected = Path(p)
1655             black.reformat_one(
1656                 path,
1657                 fast=True,
1658                 write_back=black.WriteBack.YES,
1659                 mode=DEFAULT_MODE,
1660                 report=report,
1661             )
1662             fsts.assert_called_once()
1663             # __BLACK_STDIN_FILENAME__ should have been stripped
1664             report.done.assert_called_with(expected, black.Changed.YES)
1665
1666     def test_reformat_one_with_stdin_empty(self) -> None:
1667         output = io.StringIO()
1668         with patch("io.TextIOWrapper", lambda *args, **kwargs: output):
1669             try:
1670                 black.format_stdin_to_stdout(
1671                     fast=True,
1672                     content="",
1673                     write_back=black.WriteBack.YES,
1674                     mode=DEFAULT_MODE,
1675                 )
1676             except io.UnsupportedOperation:
1677                 pass  # StringIO does not support detach
1678             assert output.getvalue() == ""
1679
1680     def test_gitignore_exclude(self) -> None:
1681         path = THIS_DIR / "data" / "include_exclude_tests"
1682         include = re.compile(r"\.pyi?$")
1683         exclude = re.compile(r"")
1684         report = black.Report()
1685         gitignore = PathSpec.from_lines(
1686             "gitwildmatch", ["exclude/", ".definitely_exclude"]
1687         )
1688         sources: List[Path] = []
1689         expected = [
1690             Path(path / "b/dont_exclude/a.py"),
1691             Path(path / "b/dont_exclude/a.pyi"),
1692         ]
1693         this_abs = THIS_DIR.resolve()
1694         sources.extend(
1695             black.gen_python_files(
1696                 path.iterdir(),
1697                 this_abs,
1698                 include,
1699                 exclude,
1700                 None,
1701                 None,
1702                 report,
1703                 gitignore,
1704                 verbose=False,
1705                 quiet=False,
1706             )
1707         )
1708         self.assertEqual(sorted(expected), sorted(sources))
1709
1710     def test_nested_gitignore(self) -> None:
1711         path = Path(THIS_DIR / "data" / "nested_gitignore_tests")
1712         include = re.compile(r"\.pyi?$")
1713         exclude = re.compile(r"")
1714         root_gitignore = black.files.get_gitignore(path)
1715         report = black.Report()
1716         expected: List[Path] = [
1717             Path(path / "x.py"),
1718             Path(path / "root/b.py"),
1719             Path(path / "root/c.py"),
1720             Path(path / "root/child/c.py"),
1721         ]
1722         this_abs = THIS_DIR.resolve()
1723         sources = list(
1724             black.gen_python_files(
1725                 path.iterdir(),
1726                 this_abs,
1727                 include,
1728                 exclude,
1729                 None,
1730                 None,
1731                 report,
1732                 root_gitignore,
1733                 verbose=False,
1734                 quiet=False,
1735             )
1736         )
1737         self.assertEqual(sorted(expected), sorted(sources))
1738
1739     def test_invalid_gitignore(self) -> None:
1740         path = THIS_DIR / "data" / "invalid_gitignore_tests"
1741         empty_config = path / "pyproject.toml"
1742         result = BlackRunner().invoke(
1743             black.main, ["--verbose", "--config", str(empty_config), str(path)]
1744         )
1745         assert result.exit_code == 1
1746         assert result.stderr_bytes is not None
1747
1748         gitignore = path / ".gitignore"
1749         assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()
1750
1751     def test_invalid_nested_gitignore(self) -> None:
1752         path = THIS_DIR / "data" / "invalid_nested_gitignore_tests"
1753         empty_config = path / "pyproject.toml"
1754         result = BlackRunner().invoke(
1755             black.main, ["--verbose", "--config", str(empty_config), str(path)]
1756         )
1757         assert result.exit_code == 1
1758         assert result.stderr_bytes is not None
1759
1760         gitignore = path / "a" / ".gitignore"
1761         assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()
1762
1763     def test_empty_include(self) -> None:
1764         path = THIS_DIR / "data" / "include_exclude_tests"
1765         report = black.Report()
1766         gitignore = PathSpec.from_lines("gitwildmatch", [])
1767         empty = re.compile(r"")
1768         sources: List[Path] = []
1769         expected = [
1770             Path(path / "b/exclude/a.pie"),
1771             Path(path / "b/exclude/a.py"),
1772             Path(path / "b/exclude/a.pyi"),
1773             Path(path / "b/dont_exclude/a.pie"),
1774             Path(path / "b/dont_exclude/a.py"),
1775             Path(path / "b/dont_exclude/a.pyi"),
1776             Path(path / "b/.definitely_exclude/a.pie"),
1777             Path(path / "b/.definitely_exclude/a.py"),
1778             Path(path / "b/.definitely_exclude/a.pyi"),
1779             Path(path / ".gitignore"),
1780             Path(path / "pyproject.toml"),
1781         ]
1782         this_abs = THIS_DIR.resolve()
1783         sources.extend(
1784             black.gen_python_files(
1785                 path.iterdir(),
1786                 this_abs,
1787                 empty,
1788                 re.compile(black.DEFAULT_EXCLUDES),
1789                 None,
1790                 None,
1791                 report,
1792                 gitignore,
1793                 verbose=False,
1794                 quiet=False,
1795             )
1796         )
1797         self.assertEqual(sorted(expected), sorted(sources))
1798
1799     def test_extend_exclude(self) -> None:
1800         path = THIS_DIR / "data" / "include_exclude_tests"
1801         report = black.Report()
1802         gitignore = PathSpec.from_lines("gitwildmatch", [])
1803         sources: List[Path] = []
1804         expected = [
1805             Path(path / "b/exclude/a.py"),
1806             Path(path / "b/dont_exclude/a.py"),
1807         ]
1808         this_abs = THIS_DIR.resolve()
1809         sources.extend(
1810             black.gen_python_files(
1811                 path.iterdir(),
1812                 this_abs,
1813                 re.compile(black.DEFAULT_INCLUDES),
1814                 re.compile(r"\.pyi$"),
1815                 re.compile(r"\.definitely_exclude"),
1816                 None,
1817                 report,
1818                 gitignore,
1819                 verbose=False,
1820                 quiet=False,
1821             )
1822         )
1823         self.assertEqual(sorted(expected), sorted(sources))
1824
1825     def test_invalid_cli_regex(self) -> None:
1826         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1827             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1828
1829     def test_required_version_matches_version(self) -> None:
1830         self.invokeBlack(
1831             ["--required-version", black.__version__], exit_code=0, ignore_config=True
1832         )
1833
1834     def test_required_version_does_not_match_version(self) -> None:
1835         self.invokeBlack(
1836             ["--required-version", "20.99b"], exit_code=1, ignore_config=True
1837         )
1838
1839     def test_preserves_line_endings(self) -> None:
1840         with TemporaryDirectory() as workspace:
1841             test_file = Path(workspace) / "test.py"
1842             for nl in ["\n", "\r\n"]:
1843                 contents = nl.join(["def f(  ):", "    pass"])
1844                 test_file.write_bytes(contents.encode())
1845                 ff(test_file, write_back=black.WriteBack.YES)
1846                 updated_contents: bytes = test_file.read_bytes()
1847                 self.assertIn(nl.encode(), updated_contents)
1848                 if nl == "\n":
1849                     self.assertNotIn(b"\r\n", updated_contents)
1850
1851     def test_preserves_line_endings_via_stdin(self) -> None:
1852         for nl in ["\n", "\r\n"]:
1853             contents = nl.join(["def f(  ):", "    pass"])
1854             runner = BlackRunner()
1855             result = runner.invoke(
1856                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1857             )
1858             self.assertEqual(result.exit_code, 0)
1859             output = result.stdout_bytes
1860             self.assertIn(nl.encode("utf8"), output)
1861             if nl == "\n":
1862                 self.assertNotIn(b"\r\n", output)
1863
1864     def test_assert_equivalent_different_asts(self) -> None:
1865         with self.assertRaises(AssertionError):
1866             black.assert_equivalent("{}", "None")
1867
1868     def test_symlink_out_of_root_directory(self) -> None:
1869         path = MagicMock()
1870         root = THIS_DIR.resolve()
1871         child = MagicMock()
1872         include = re.compile(black.DEFAULT_INCLUDES)
1873         exclude = re.compile(black.DEFAULT_EXCLUDES)
1874         report = black.Report()
1875         gitignore = PathSpec.from_lines("gitwildmatch", [])
1876         # `child` should behave like a symlink which resolved path is clearly
1877         # outside of the `root` directory.
1878         path.iterdir.return_value = [child]
1879         child.resolve.return_value = Path("/a/b/c")
1880         child.as_posix.return_value = "/a/b/c"
1881         child.is_symlink.return_value = True
1882         try:
1883             list(
1884                 black.gen_python_files(
1885                     path.iterdir(),
1886                     root,
1887                     include,
1888                     exclude,
1889                     None,
1890                     None,
1891                     report,
1892                     gitignore,
1893                     verbose=False,
1894                     quiet=False,
1895                 )
1896             )
1897         except ValueError as ve:
1898             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1899         path.iterdir.assert_called_once()
1900         child.resolve.assert_called_once()
1901         child.is_symlink.assert_called_once()
1902         # `child` should behave like a strange file which resolved path is clearly
1903         # outside of the `root` directory.
1904         child.is_symlink.return_value = False
1905         with self.assertRaises(ValueError):
1906             list(
1907                 black.gen_python_files(
1908                     path.iterdir(),
1909                     root,
1910                     include,
1911                     exclude,
1912                     None,
1913                     None,
1914                     report,
1915                     gitignore,
1916                     verbose=False,
1917                     quiet=False,
1918                 )
1919             )
1920         path.iterdir.assert_called()
1921         self.assertEqual(path.iterdir.call_count, 2)
1922         child.resolve.assert_called()
1923         self.assertEqual(child.resolve.call_count, 2)
1924         child.is_symlink.assert_called()
1925         self.assertEqual(child.is_symlink.call_count, 2)
1926
1927     def test_shhh_click(self) -> None:
1928         try:
1929             from click import _unicodefun
1930         except ModuleNotFoundError:
1931             self.skipTest("Incompatible Click version")
1932         if not hasattr(_unicodefun, "_verify_python3_env"):
1933             self.skipTest("Incompatible Click version")
1934         # First, let's see if Click is crashing with a preferred ASCII charset.
1935         with patch("locale.getpreferredencoding") as gpe:
1936             gpe.return_value = "ASCII"
1937             with self.assertRaises(RuntimeError):
1938                 _unicodefun._verify_python3_env()  # type: ignore
1939         # Now, let's silence Click...
1940         black.patch_click()
1941         # ...and confirm it's silent.
1942         with patch("locale.getpreferredencoding") as gpe:
1943             gpe.return_value = "ASCII"
1944             try:
1945                 _unicodefun._verify_python3_env()  # type: ignore
1946             except RuntimeError as re:
1947                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1948
1949     def test_root_logger_not_used_directly(self) -> None:
1950         def fail(*args: Any, **kwargs: Any) -> None:
1951             self.fail("Record created with root logger")
1952
1953         with patch.multiple(
1954             logging.root,
1955             debug=fail,
1956             info=fail,
1957             warning=fail,
1958             error=fail,
1959             critical=fail,
1960             log=fail,
1961         ):
1962             ff(THIS_DIR / "util.py")
1963
1964     def test_invalid_config_return_code(self) -> None:
1965         tmp_file = Path(black.dump_to_file())
1966         try:
1967             tmp_config = Path(black.dump_to_file())
1968             tmp_config.unlink()
1969             args = ["--config", str(tmp_config), str(tmp_file)]
1970             self.invokeBlack(args, exit_code=2, ignore_config=False)
1971         finally:
1972             tmp_file.unlink()
1973
1974     def test_parse_pyproject_toml(self) -> None:
1975         test_toml_file = THIS_DIR / "test.toml"
1976         config = black.parse_pyproject_toml(str(test_toml_file))
1977         self.assertEqual(config["verbose"], 1)
1978         self.assertEqual(config["check"], "no")
1979         self.assertEqual(config["diff"], "y")
1980         self.assertEqual(config["color"], True)
1981         self.assertEqual(config["line_length"], 79)
1982         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1983         self.assertEqual(config["exclude"], r"\.pyi?$")
1984         self.assertEqual(config["include"], r"\.py?$")
1985
1986     def test_read_pyproject_toml(self) -> None:
1987         test_toml_file = THIS_DIR / "test.toml"
1988         fake_ctx = FakeContext()
1989         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1990         config = fake_ctx.default_map
1991         self.assertEqual(config["verbose"], "1")
1992         self.assertEqual(config["check"], "no")
1993         self.assertEqual(config["diff"], "y")
1994         self.assertEqual(config["color"], "True")
1995         self.assertEqual(config["line_length"], "79")
1996         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1997         self.assertEqual(config["exclude"], r"\.pyi?$")
1998         self.assertEqual(config["include"], r"\.py?$")
1999
2000     def test_find_project_root(self) -> None:
2001         with TemporaryDirectory() as workspace:
2002             root = Path(workspace)
2003             test_dir = root / "test"
2004             test_dir.mkdir()
2005
2006             src_dir = root / "src"
2007             src_dir.mkdir()
2008
2009             root_pyproject = root / "pyproject.toml"
2010             root_pyproject.touch()
2011             src_pyproject = src_dir / "pyproject.toml"
2012             src_pyproject.touch()
2013             src_python = src_dir / "foo.py"
2014             src_python.touch()
2015
2016             self.assertEqual(
2017                 black.find_project_root((src_dir, test_dir)), root.resolve()
2018             )
2019             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
2020             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
2021
2022     @patch(
2023         "black.files.find_user_pyproject_toml",
2024         black.files.find_user_pyproject_toml.__wrapped__,
2025     )
2026     def test_find_user_pyproject_toml_linux(self) -> None:
2027         if system() == "Windows":
2028             return
2029
2030         # Test if XDG_CONFIG_HOME is checked
2031         with TemporaryDirectory() as workspace:
2032             tmp_user_config = Path(workspace) / "black"
2033             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
2034                 self.assertEqual(
2035                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
2036                 )
2037
2038         # Test fallback for XDG_CONFIG_HOME
2039         with patch.dict("os.environ"):
2040             os.environ.pop("XDG_CONFIG_HOME", None)
2041             fallback_user_config = Path("~/.config").expanduser() / "black"
2042             self.assertEqual(
2043                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
2044             )
2045
2046     def test_find_user_pyproject_toml_windows(self) -> None:
2047         if system() != "Windows":
2048             return
2049
2050         user_config_path = Path.home() / ".black"
2051         self.assertEqual(
2052             black.files.find_user_pyproject_toml(), user_config_path.resolve()
2053         )
2054
2055     def test_bpo_33660_workaround(self) -> None:
2056         if system() == "Windows":
2057             return
2058
2059         # https://bugs.python.org/issue33660
2060         root = Path("/")
2061         with change_directory(root):
2062             path = Path("workspace") / "project"
2063             report = black.Report(verbose=True)
2064             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
2065             self.assertEqual(normalized_path, "workspace/project")
2066
2067     def test_newline_comment_interaction(self) -> None:
2068         source = "class A:\\\r\n# type: ignore\n pass\n"
2069         output = black.format_str(source, mode=DEFAULT_MODE)
2070         black.assert_stable(source, output, mode=DEFAULT_MODE)
2071
2072     def test_bpo_2142_workaround(self) -> None:
2073
2074         # https://bugs.python.org/issue2142
2075
2076         source, _ = read_data("missing_final_newline.py")
2077         # read_data adds a trailing newline
2078         source = source.rstrip()
2079         expected, _ = read_data("missing_final_newline.diff")
2080         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
2081         diff_header = re.compile(
2082             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
2083             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
2084         )
2085         try:
2086             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
2087             self.assertEqual(result.exit_code, 0)
2088         finally:
2089             os.unlink(tmp_file)
2090         actual = result.output
2091         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
2092         self.assertEqual(actual, expected)
2093
2094     @pytest.mark.python2
2095     def test_docstring_reformat_for_py27(self) -> None:
2096         """
2097         Check that stripping trailing whitespace from Python 2 docstrings
2098         doesn't trigger a "not equivalent to source" error
2099         """
2100         source = (
2101             b'def foo():\r\n    """Testing\r\n    Testing """\r\n    print "Foo"\r\n'
2102         )
2103         expected = 'def foo():\n    """Testing\n    Testing"""\n    print "Foo"\n'
2104
2105         result = CliRunner().invoke(
2106             black.main,
2107             ["-", "-q", "--target-version=py27"],
2108             input=BytesIO(source),
2109         )
2110
2111         self.assertEqual(result.exit_code, 0)
2112         actual = result.output
2113         self.assertFormatEqual(actual, expected)
2114
2115     @staticmethod
2116     def compare_results(
2117         result: click.testing.Result, expected_value: str, expected_exit_code: int
2118     ) -> None:
2119         """Helper method to test the value and exit code of a click Result."""
2120         assert (
2121             result.output == expected_value
2122         ), "The output did not match the expected value."
2123         assert result.exit_code == expected_exit_code, "The exit code is incorrect."
2124
2125     def test_code_option(self) -> None:
2126         """Test the code option with no changes."""
2127         code = 'print("Hello world")\n'
2128         args = ["--code", code]
2129         result = CliRunner().invoke(black.main, args)
2130
2131         self.compare_results(result, code, 0)
2132
2133     def test_code_option_changed(self) -> None:
2134         """Test the code option when changes are required."""
2135         code = "print('hello world')"
2136         formatted = black.format_str(code, mode=DEFAULT_MODE)
2137
2138         args = ["--code", code]
2139         result = CliRunner().invoke(black.main, args)
2140
2141         self.compare_results(result, formatted, 0)
2142
2143     def test_code_option_check(self) -> None:
2144         """Test the code option when check is passed."""
2145         args = ["--check", "--code", 'print("Hello world")\n']
2146         result = CliRunner().invoke(black.main, args)
2147         self.compare_results(result, "", 0)
2148
2149     def test_code_option_check_changed(self) -> None:
2150         """Test the code option when changes are required, and check is passed."""
2151         args = ["--check", "--code", "print('hello world')"]
2152         result = CliRunner().invoke(black.main, args)
2153         self.compare_results(result, "", 1)
2154
2155     def test_code_option_diff(self) -> None:
2156         """Test the code option when diff is passed."""
2157         code = "print('hello world')"
2158         formatted = black.format_str(code, mode=DEFAULT_MODE)
2159         result_diff = diff(code, formatted, "STDIN", "STDOUT")
2160
2161         args = ["--diff", "--code", code]
2162         result = CliRunner().invoke(black.main, args)
2163
2164         # Remove time from diff
2165         output = DIFF_TIME.sub("", result.output)
2166
2167         assert output == result_diff, "The output did not match the expected value."
2168         assert result.exit_code == 0, "The exit code is incorrect."
2169
2170     def test_code_option_color_diff(self) -> None:
2171         """Test the code option when color and diff are passed."""
2172         code = "print('hello world')"
2173         formatted = black.format_str(code, mode=DEFAULT_MODE)
2174
2175         result_diff = diff(code, formatted, "STDIN", "STDOUT")
2176         result_diff = color_diff(result_diff)
2177
2178         args = ["--diff", "--color", "--code", code]
2179         result = CliRunner().invoke(black.main, args)
2180
2181         # Remove time from diff
2182         output = DIFF_TIME.sub("", result.output)
2183
2184         assert output == result_diff, "The output did not match the expected value."
2185         assert result.exit_code == 0, "The exit code is incorrect."
2186
2187     def test_code_option_safe(self) -> None:
2188         """Test that the code option throws an error when the sanity checks fail."""
2189         # Patch black.assert_equivalent to ensure the sanity checks fail
2190         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
2191             code = 'print("Hello world")'
2192             error_msg = f"{code}\nerror: cannot format <string>: \n"
2193
2194             args = ["--safe", "--code", code]
2195             result = CliRunner().invoke(black.main, args)
2196
2197             self.compare_results(result, error_msg, 123)
2198
2199     def test_code_option_fast(self) -> None:
2200         """Test that the code option ignores errors when the sanity checks fail."""
2201         # Patch black.assert_equivalent to ensure the sanity checks fail
2202         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
2203             code = 'print("Hello world")'
2204             formatted = black.format_str(code, mode=DEFAULT_MODE)
2205
2206             args = ["--fast", "--code", code]
2207             result = CliRunner().invoke(black.main, args)
2208
2209             self.compare_results(result, formatted, 0)
2210
2211     def test_code_option_config(self) -> None:
2212         """
2213         Test that the code option finds the pyproject.toml in the current directory.
2214         """
2215         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
2216             args = ["--code", "print"]
2217             CliRunner().invoke(black.main, args)
2218
2219             pyproject_path = Path(Path().cwd(), "pyproject.toml").resolve()
2220             assert (
2221                 len(parse.mock_calls) >= 1
2222             ), "Expected config parse to be called with the current directory."
2223
2224             _, call_args, _ = parse.mock_calls[0]
2225             assert (
2226                 call_args[0].lower() == str(pyproject_path).lower()
2227             ), "Incorrect config loaded."
2228
2229     def test_code_option_parent_config(self) -> None:
2230         """
2231         Test that the code option finds the pyproject.toml in the parent directory.
2232         """
2233         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
2234             with change_directory(Path("tests")):
2235                 args = ["--code", "print"]
2236                 CliRunner().invoke(black.main, args)
2237
2238                 pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
2239                 assert (
2240                     len(parse.mock_calls) >= 1
2241                 ), "Expected config parse to be called with the current directory."
2242
2243                 _, call_args, _ = parse.mock_calls[0]
2244                 assert (
2245                     call_args[0].lower() == str(pyproject_path).lower()
2246                 ), "Incorrect config loaded."
2247
2248
# Newline-terminated lines of black's package source (black.__file__), read
# once at import time so tracefunc can look up the text of ``def`` lines.
with open(black.__file__, encoding="utf-8") as _bf:
    black_source_lines = list(_bf)
2251
2252
def tracefunc(frame: types.FrameType, event: str, arg: Any) -> Callable:
    """Show function calls `from black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    For every "call" event whose code lives in ``black/__init__.py``, this prints
    the frame's line number and the stripped text of the function's ``def``
    line, taken from ``black_source_lines``. Output is indented by stack depth.
    All other events and files are ignored. The function always returns
    itself, so it stays installed as the local trace function.

    Args:
        frame: the frame that triggered the trace event.
        event: the trace event name ("call", "line", "return", ...).
        arg: event-specific argument; unused here.
    """
    if event != "call":
        return tracefunc

    # Filter on the file *before* indexing ``black_source_lines``. Line numbers
    # from other modules can exceed its length, and an exception raised inside
    # a trace function disables tracing and surfaces in the traced program.
    # ``as_posix()`` lets the check also match Windows-style path separators.
    filename = frame.f_code.co_filename
    if "black/__init__.py" not in Path(filename).as_posix():
        return tracefunc

    # Indent by stack depth. The -19 offset presumably discounts the test
    # runner's own frames (NOTE(review): empirical value; confirm per runner).
    # context=0 avoids reading source context for every frame on the stack.
    stack = (len(inspect.stack(0)) - 19) * 2
    lineno = frame.f_lineno
    # The reported first line may be a decorator; walk forward to the ``def``
    # line, stopping safely if we run off the end of the source.
    func_sig_lineno = lineno - 1
    funcname = ""
    while 0 <= func_sig_lineno < len(black_source_lines):
        funcname = black_source_lines[func_sig_lineno].strip()
        if not funcname.startswith("@"):
            break
        func_sig_lineno += 1
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc
2273
2274
if __name__ == "__main__":
    # When this file is executed directly, hand control to unittest's
    # command-line runner, loading tests from the module named "test_black".
    unittest.main(module="test_black")