git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Maintainers += Richard Si (aka ichard26) (#2149)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import multiprocessing
3 import asyncio
4 import logging
5 from concurrent.futures import ThreadPoolExecutor
6 from contextlib import contextmanager
7 from dataclasses import replace
8 import inspect
9 from io import BytesIO, TextIOWrapper
10 import os
11 from pathlib import Path
12 from platform import system
13 import regex as re
14 import sys
15 from tempfile import TemporaryDirectory
16 import types
17 from typing import (
18     Any,
19     BinaryIO,
20     Callable,
21     Dict,
22     Generator,
23     List,
24     Iterator,
25     TypeVar,
26 )
27 import unittest
28 from unittest.mock import patch, MagicMock
29
30 import click
31 from click import unstyle
32 from click.testing import CliRunner
33
34 import black
35 from black import Feature, TargetVersion
36
37 from pathspec import PathSpec
38
39 # Import other test classes
40 from tests.util import (
41     THIS_DIR,
42     read_data,
43     DETERMINISTIC_HEADER,
44     BlackBaseTestCase,
45     DEFAULT_MODE,
46     fs,
47     ff,
48     dump_to_stderr,
49 )
50 from .test_primer import PrimerCLITests  # noqa: F401
51
52
# Path of this test module.
THIS_FILE = Path(__file__)
# Target versions from which the PY36_ARGS command-line flags below are built.
PY36_VERSIONS = {
    TargetVersion.PY36,
    TargetVersion.PY37,
    TargetVersion.PY38,
    TargetVersion.PY39,
}
# One ``--target-version=<name>`` flag per member of PY36_VERSIONS, using the
# lower-cased enum member name.
PY36_ARGS = [f"--target-version={version.name.lower()}" for version in PY36_VERSIONS]
# Generic type variables.
T = TypeVar("T")
R = TypeVar("R")
63
64
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch ``black.CACHE_DIR`` to a location inside a fresh temporary directory.

    When ``exists`` is true the yielded path is the temporary directory
    itself; otherwise it is a ``new`` child path that this helper does not
    create. The patch is active only while the context is open.
    """
    with TemporaryDirectory() as workspace:
        base = Path(workspace)
        target = base if exists else base / "new"
        with patch("black.CACHE_DIR", target):
            yield target
73
74
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a brand-new asyncio event loop as current and close it on exit.

    The loop is created through the active event loop policy. Nothing is
    yielded to the caller; the loop is closed even if the body raises.
    """
    loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        yield
    finally:
        loop.close()
85
86
class FakeContext(click.Context):
    """Stand-in click Context for calling functions that require one.

    The parent initializer is not invoked; only an empty ``default_map`` is
    set on the instance.
    """

    def __init__(self) -> None:
        self.default_map: Dict[str, Any] = dict()
92
93
class FakeParameter(click.Parameter):
    """Stand-in click Parameter for calling functions that require one."""

    def __init__(self) -> None:
        """Do nothing; the parent initializer is not invoked."""
99
100
class BlackRunner(CliRunner):
    """CliRunner variant that keeps stderr separate from stdout.

    This is a hack that can be removed once we depend on Click 7.x"""

    def __init__(self) -> None:
        # Raw capture buffers, followed by the byte snapshots that
        # ``isolation`` fills in when the isolated block ends.
        self.stderrbuf, self.stdoutbuf = BytesIO(), BytesIO()
        self.stdout_bytes, self.stderr_bytes = b"", b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Wrap the parent isolation, redirecting ``sys.stderr`` to our own buffer.

        On exit, the bytes written to stdout and stderr are stored in
        ``stdout_bytes`` / ``stderr_bytes`` and the previous ``sys.stderr`` is
        put back.
        """
        with super().isolation(*args, **kwargs) as output:
            saved_stderr = sys.stderr
            try:
                sys.stderr = TextIOWrapper(self.stderrbuf, encoding=self.charset)
                yield output
            finally:
                # Snapshot both streams before stderr is restored.
                self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                self.stderr_bytes = sys.stderr.buffer.getvalue()  # type: ignore
                sys.stderr = saved_stderr
124
125
126 class BlackTestCase(BlackBaseTestCase):
127     def invokeBlack(
128         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
129     ) -> None:
130         runner = BlackRunner()
131         if ignore_config:
132             args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
133         result = runner.invoke(black.main, args)
134         self.assertEqual(
135             result.exit_code,
136             exit_code,
137             msg=(
138                 f"Failed with args: {args}\n"
139                 f"stdout: {runner.stdout_bytes.decode()!r}\n"
140                 f"stderr: {runner.stderr_bytes.decode()!r}\n"
141                 f"exception: {result.exception}"
142             ),
143         )
144
145     @patch("black.dump_to_file", dump_to_stderr)
146     def test_empty(self) -> None:
147         source = expected = ""
148         actual = fs(source)
149         self.assertFormatEqual(expected, actual)
150         black.assert_equivalent(source, actual)
151         black.assert_stable(source, actual, DEFAULT_MODE)
152
153     def test_empty_ff(self) -> None:
154         expected = ""
155         tmp_file = Path(black.dump_to_file())
156         try:
157             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
158             with open(tmp_file, encoding="utf8") as f:
159                 actual = f.read()
160         finally:
161             os.unlink(tmp_file)
162         self.assertFormatEqual(expected, actual)
163
164     def test_piping(self) -> None:
165         source, expected = read_data("src/black/__init__", data=False)
166         result = BlackRunner().invoke(
167             black.main,
168             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
169             input=BytesIO(source.encode("utf8")),
170         )
171         self.assertEqual(result.exit_code, 0)
172         self.assertFormatEqual(expected, result.output)
173         black.assert_equivalent(source, result.output)
174         black.assert_stable(source, result.output, DEFAULT_MODE)
175
176     def test_piping_diff(self) -> None:
177         diff_header = re.compile(
178             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
179             r"\+\d\d\d\d"
180         )
181         source, _ = read_data("expression.py")
182         expected, _ = read_data("expression.diff")
183         config = THIS_DIR / "data" / "empty_pyproject.toml"
184         args = [
185             "-",
186             "--fast",
187             f"--line-length={black.DEFAULT_LINE_LENGTH}",
188             "--diff",
189             f"--config={config}",
190         ]
191         result = BlackRunner().invoke(
192             black.main, args, input=BytesIO(source.encode("utf8"))
193         )
194         self.assertEqual(result.exit_code, 0)
195         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
196         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
197         self.assertEqual(expected, actual)
198
199     def test_piping_diff_with_color(self) -> None:
200         source, _ = read_data("expression.py")
201         config = THIS_DIR / "data" / "empty_pyproject.toml"
202         args = [
203             "-",
204             "--fast",
205             f"--line-length={black.DEFAULT_LINE_LENGTH}",
206             "--diff",
207             "--color",
208             f"--config={config}",
209         ]
210         result = BlackRunner().invoke(
211             black.main, args, input=BytesIO(source.encode("utf8"))
212         )
213         actual = result.output
214         # Again, the contents are checked in a different test, so only look for colors.
215         self.assertIn("\033[1;37m", actual)
216         self.assertIn("\033[36m", actual)
217         self.assertIn("\033[32m", actual)
218         self.assertIn("\033[31m", actual)
219         self.assertIn("\033[0m", actual)
220
221     @patch("black.dump_to_file", dump_to_stderr)
222     def _test_wip(self) -> None:
223         source, expected = read_data("wip")
224         sys.settrace(tracefunc)
225         mode = replace(
226             DEFAULT_MODE,
227             experimental_string_processing=False,
228             target_versions={black.TargetVersion.PY38},
229         )
230         actual = fs(source, mode=mode)
231         sys.settrace(None)
232         self.assertFormatEqual(expected, actual)
233         black.assert_equivalent(source, actual)
234         black.assert_stable(source, actual, black.FileMode())
235
236     @unittest.expectedFailure
237     @patch("black.dump_to_file", dump_to_stderr)
238     def test_trailing_comma_optional_parens_stability1(self) -> None:
239         source, _expected = read_data("trailing_comma_optional_parens1")
240         actual = fs(source)
241         black.assert_stable(source, actual, DEFAULT_MODE)
242
243     @unittest.expectedFailure
244     @patch("black.dump_to_file", dump_to_stderr)
245     def test_trailing_comma_optional_parens_stability2(self) -> None:
246         source, _expected = read_data("trailing_comma_optional_parens2")
247         actual = fs(source)
248         black.assert_stable(source, actual, DEFAULT_MODE)
249
250     @unittest.expectedFailure
251     @patch("black.dump_to_file", dump_to_stderr)
252     def test_trailing_comma_optional_parens_stability3(self) -> None:
253         source, _expected = read_data("trailing_comma_optional_parens3")
254         actual = fs(source)
255         black.assert_stable(source, actual, DEFAULT_MODE)
256
257     @patch("black.dump_to_file", dump_to_stderr)
258     def test_trailing_comma_optional_parens_stability1_pass2(self) -> None:
259         source, _expected = read_data("trailing_comma_optional_parens1")
260         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
261         black.assert_stable(source, actual, DEFAULT_MODE)
262
263     @patch("black.dump_to_file", dump_to_stderr)
264     def test_trailing_comma_optional_parens_stability2_pass2(self) -> None:
265         source, _expected = read_data("trailing_comma_optional_parens2")
266         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
267         black.assert_stable(source, actual, DEFAULT_MODE)
268
269     @patch("black.dump_to_file", dump_to_stderr)
270     def test_trailing_comma_optional_parens_stability3_pass2(self) -> None:
271         source, _expected = read_data("trailing_comma_optional_parens3")
272         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
273         black.assert_stable(source, actual, DEFAULT_MODE)
274
275     @patch("black.dump_to_file", dump_to_stderr)
276     def test_pep_572(self) -> None:
277         source, expected = read_data("pep_572")
278         actual = fs(source)
279         self.assertFormatEqual(expected, actual)
280         black.assert_stable(source, actual, DEFAULT_MODE)
281         if sys.version_info >= (3, 8):
282             black.assert_equivalent(source, actual)
283
284     @patch("black.dump_to_file", dump_to_stderr)
285     def test_pep_572_remove_parens(self) -> None:
286         source, expected = read_data("pep_572_remove_parens")
287         actual = fs(source)
288         self.assertFormatEqual(expected, actual)
289         black.assert_stable(source, actual, DEFAULT_MODE)
290         if sys.version_info >= (3, 8):
291             black.assert_equivalent(source, actual)
292
293     @patch("black.dump_to_file", dump_to_stderr)
294     def test_pep_572_do_not_remove_parens(self) -> None:
295         source, expected = read_data("pep_572_do_not_remove_parens")
296         # the AST safety checks will fail, but that's expected, just make sure no
297         # parentheses are touched
298         actual = black.format_str(source, mode=DEFAULT_MODE)
299         self.assertFormatEqual(expected, actual)
300
301     def test_pep_572_version_detection(self) -> None:
302         source, _ = read_data("pep_572")
303         root = black.lib2to3_parse(source)
304         features = black.get_features_used(root)
305         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
306         versions = black.detect_target_versions(root)
307         self.assertIn(black.TargetVersion.PY38, versions)
308
309     def test_expression_ff(self) -> None:
310         source, expected = read_data("expression")
311         tmp_file = Path(black.dump_to_file(source))
312         try:
313             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
314             with open(tmp_file, encoding="utf8") as f:
315                 actual = f.read()
316         finally:
317             os.unlink(tmp_file)
318         self.assertFormatEqual(expected, actual)
319         with patch("black.dump_to_file", dump_to_stderr):
320             black.assert_equivalent(source, actual)
321             black.assert_stable(source, actual, DEFAULT_MODE)
322
323     def test_expression_diff(self) -> None:
324         source, _ = read_data("expression.py")
325         config = THIS_DIR / "data" / "empty_pyproject.toml"
326         expected, _ = read_data("expression.diff")
327         tmp_file = Path(black.dump_to_file(source))
328         diff_header = re.compile(
329             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
330             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
331         )
332         try:
333             result = BlackRunner().invoke(
334                 black.main, ["--diff", str(tmp_file), f"--config={config}"]
335             )
336             self.assertEqual(result.exit_code, 0)
337         finally:
338             os.unlink(tmp_file)
339         actual = result.output
340         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
341         if expected != actual:
342             dump = black.dump_to_file(actual)
343             msg = (
344                 "Expected diff isn't equal to the actual. If you made changes to"
345                 " expression.py and this is an anticipated difference, overwrite"
346                 f" tests/data/expression.diff with {dump}"
347             )
348             self.assertEqual(expected, actual, msg)
349
350     def test_expression_diff_with_color(self) -> None:
351         source, _ = read_data("expression.py")
352         config = THIS_DIR / "data" / "empty_pyproject.toml"
353         expected, _ = read_data("expression.diff")
354         tmp_file = Path(black.dump_to_file(source))
355         try:
356             result = BlackRunner().invoke(
357                 black.main, ["--diff", "--color", str(tmp_file), f"--config={config}"]
358             )
359         finally:
360             os.unlink(tmp_file)
361         actual = result.output
362         # We check the contents of the diff in `test_expression_diff`. All
363         # we need to check here is that color codes exist in the result.
364         self.assertIn("\033[1;37m", actual)
365         self.assertIn("\033[36m", actual)
366         self.assertIn("\033[32m", actual)
367         self.assertIn("\033[31m", actual)
368         self.assertIn("\033[0m", actual)
369
370     @patch("black.dump_to_file", dump_to_stderr)
371     def test_pep_570(self) -> None:
372         source, expected = read_data("pep_570")
373         actual = fs(source)
374         self.assertFormatEqual(expected, actual)
375         black.assert_stable(source, actual, DEFAULT_MODE)
376         if sys.version_info >= (3, 8):
377             black.assert_equivalent(source, actual)
378
379     def test_detect_pos_only_arguments(self) -> None:
380         source, _ = read_data("pep_570")
381         root = black.lib2to3_parse(source)
382         features = black.get_features_used(root)
383         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
384         versions = black.detect_target_versions(root)
385         self.assertIn(black.TargetVersion.PY38, versions)
386
387     @patch("black.dump_to_file", dump_to_stderr)
388     def test_string_quotes(self) -> None:
389         source, expected = read_data("string_quotes")
390         mode = black.Mode(experimental_string_processing=True)
391         actual = fs(source, mode=mode)
392         self.assertFormatEqual(expected, actual)
393         black.assert_equivalent(source, actual)
394         black.assert_stable(source, actual, mode)
395         mode = replace(mode, string_normalization=False)
396         not_normalized = fs(source, mode=mode)
397         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
398         black.assert_equivalent(source, not_normalized)
399         black.assert_stable(source, not_normalized, mode=mode)
400
401     @patch("black.dump_to_file", dump_to_stderr)
402     def test_docstring_no_string_normalization(self) -> None:
403         """Like test_docstring but with string normalization off."""
404         source, expected = read_data("docstring_no_string_normalization")
405         mode = replace(DEFAULT_MODE, string_normalization=False)
406         actual = fs(source, mode=mode)
407         self.assertFormatEqual(expected, actual)
408         black.assert_equivalent(source, actual)
409         black.assert_stable(source, actual, mode)
410
411     def test_long_strings_flag_disabled(self) -> None:
412         """Tests for turning off the string processing logic."""
413         source, expected = read_data("long_strings_flag_disabled")
414         mode = replace(DEFAULT_MODE, experimental_string_processing=False)
415         actual = fs(source, mode=mode)
416         self.assertFormatEqual(expected, actual)
417         black.assert_stable(expected, actual, mode)
418
419     @patch("black.dump_to_file", dump_to_stderr)
420     def test_numeric_literals(self) -> None:
421         source, expected = read_data("numeric_literals")
422         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
423         actual = fs(source, mode=mode)
424         self.assertFormatEqual(expected, actual)
425         black.assert_equivalent(source, actual)
426         black.assert_stable(source, actual, mode)
427
428     @patch("black.dump_to_file", dump_to_stderr)
429     def test_numeric_literals_ignoring_underscores(self) -> None:
430         source, expected = read_data("numeric_literals_skip_underscores")
431         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
432         actual = fs(source, mode=mode)
433         self.assertFormatEqual(expected, actual)
434         black.assert_equivalent(source, actual)
435         black.assert_stable(source, actual, mode)
436
437     def test_skip_magic_trailing_comma(self) -> None:
438         source, _ = read_data("expression.py")
439         expected, _ = read_data("expression_skip_magic_trailing_comma.diff")
440         tmp_file = Path(black.dump_to_file(source))
441         diff_header = re.compile(
442             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
443             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
444         )
445         try:
446             result = BlackRunner().invoke(black.main, ["-C", "--diff", str(tmp_file)])
447             self.assertEqual(result.exit_code, 0)
448         finally:
449             os.unlink(tmp_file)
450         actual = result.output
451         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
452         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
453         if expected != actual:
454             dump = black.dump_to_file(actual)
455             msg = (
456                 "Expected diff isn't equal to the actual. If you made changes to"
457                 " expression.py and this is an anticipated difference, overwrite"
458                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
459             )
460             self.assertEqual(expected, actual, msg)
461
462     @patch("black.dump_to_file", dump_to_stderr)
463     def test_python2_print_function(self) -> None:
464         source, expected = read_data("python2_print_function")
465         mode = replace(DEFAULT_MODE, target_versions={TargetVersion.PY27})
466         actual = fs(source, mode=mode)
467         self.assertFormatEqual(expected, actual)
468         black.assert_equivalent(source, actual)
469         black.assert_stable(source, actual, mode)
470
471     @patch("black.dump_to_file", dump_to_stderr)
472     def test_stub(self) -> None:
473         mode = replace(DEFAULT_MODE, is_pyi=True)
474         source, expected = read_data("stub.pyi")
475         actual = fs(source, mode=mode)
476         self.assertFormatEqual(expected, actual)
477         black.assert_stable(source, actual, mode)
478
479     @patch("black.dump_to_file", dump_to_stderr)
480     def test_async_as_identifier(self) -> None:
481         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
482         source, expected = read_data("async_as_identifier")
483         actual = fs(source)
484         self.assertFormatEqual(expected, actual)
485         major, minor = sys.version_info[:2]
486         if major < 3 or (major <= 3 and minor < 7):
487             black.assert_equivalent(source, actual)
488         black.assert_stable(source, actual, DEFAULT_MODE)
489         # ensure black can parse this when the target is 3.6
490         self.invokeBlack([str(source_path), "--target-version", "py36"])
491         # but not on 3.7, because async/await is no longer an identifier
492         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
493
494     @patch("black.dump_to_file", dump_to_stderr)
495     def test_python37(self) -> None:
496         source_path = (THIS_DIR / "data" / "python37.py").resolve()
497         source, expected = read_data("python37")
498         actual = fs(source)
499         self.assertFormatEqual(expected, actual)
500         major, minor = sys.version_info[:2]
501         if major > 3 or (major == 3 and minor >= 7):
502             black.assert_equivalent(source, actual)
503         black.assert_stable(source, actual, DEFAULT_MODE)
504         # ensure black can parse this when the target is 3.7
505         self.invokeBlack([str(source_path), "--target-version", "py37"])
506         # but not on 3.6, because we use async as a reserved keyword
507         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
508
509     @patch("black.dump_to_file", dump_to_stderr)
510     def test_python38(self) -> None:
511         source, expected = read_data("python38")
512         actual = fs(source)
513         self.assertFormatEqual(expected, actual)
514         major, minor = sys.version_info[:2]
515         if major > 3 or (major == 3 and minor >= 8):
516             black.assert_equivalent(source, actual)
517         black.assert_stable(source, actual, DEFAULT_MODE)
518
519     @patch("black.dump_to_file", dump_to_stderr)
520     def test_python39(self) -> None:
521         source, expected = read_data("python39")
522         actual = fs(source)
523         self.assertFormatEqual(expected, actual)
524         major, minor = sys.version_info[:2]
525         if major > 3 or (major == 3 and minor >= 9):
526             black.assert_equivalent(source, actual)
527         black.assert_stable(source, actual, DEFAULT_MODE)
528
529     def test_tab_comment_indentation(self) -> None:
530         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
531         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
532         self.assertFormatEqual(contents_spc, fs(contents_spc))
533         self.assertFormatEqual(contents_spc, fs(contents_tab))
534
535         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
536         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
537         self.assertFormatEqual(contents_spc, fs(contents_spc))
538         self.assertFormatEqual(contents_spc, fs(contents_tab))
539
540         # mixed tabs and spaces (valid Python 2 code)
541         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
542         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
543         self.assertFormatEqual(contents_spc, fs(contents_spc))
544         self.assertFormatEqual(contents_spc, fs(contents_tab))
545
546         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
547         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
548         self.assertFormatEqual(contents_spc, fs(contents_spc))
549         self.assertFormatEqual(contents_spc, fs(contents_tab))
550
    def test_report_verbose(self) -> None:
        """Walk a verbose ``black.Report`` through a fixed sequence of events.

        ``black.out`` and ``black.err`` are patched with collectors so that
        after every event the test can check the number of messages emitted,
        the text of the latest message, the report summary string and the
        return code. The steps build on each other, so their order matters.
        """
        report = black.Report(verbose=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            # Collect the message instead of printing it; kwargs are ignored.
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            # Collect the message instead of printing it; kwargs are ignored.
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            # An unchanged file.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # A reformatted file.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file is counted among the unchanged ones.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With ``check`` set, the reformatted file makes the return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure goes to ``err`` and makes the return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Another reformatted file after a failure.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # A second failure.
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # An ignored path produces a message but leaves the summary as is.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # One more unchanged file.
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # With ``check`` set, the summary switches to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            # ``diff`` alone produces the same "would be" wording.
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
652
653     def test_report_quiet(self) -> None:
654         report = black.Report(quiet=True)
655         out_lines = []
656         err_lines = []
657
658         def out(msg: str, **kwargs: Any) -> None:
659             out_lines.append(msg)
660
661         def err(msg: str, **kwargs: Any) -> None:
662             err_lines.append(msg)
663
664         with patch("black.out", out), patch("black.err", err):
665             report.done(Path("f1"), black.Changed.NO)
666             self.assertEqual(len(out_lines), 0)
667             self.assertEqual(len(err_lines), 0)
668             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
669             self.assertEqual(report.return_code, 0)
670             report.done(Path("f2"), black.Changed.YES)
671             self.assertEqual(len(out_lines), 0)
672             self.assertEqual(len(err_lines), 0)
673             self.assertEqual(
674                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
675             )
676             report.done(Path("f3"), black.Changed.CACHED)
677             self.assertEqual(len(out_lines), 0)
678             self.assertEqual(len(err_lines), 0)
679             self.assertEqual(
680                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
681             )
682             self.assertEqual(report.return_code, 0)
683             report.check = True
684             self.assertEqual(report.return_code, 1)
685             report.check = False
686             report.failed(Path("e1"), "boom")
687             self.assertEqual(len(out_lines), 0)
688             self.assertEqual(len(err_lines), 1)
689             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
690             self.assertEqual(
691                 unstyle(str(report)),
692                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
693                 " reformat.",
694             )
695             self.assertEqual(report.return_code, 123)
696             report.done(Path("f3"), black.Changed.YES)
697             self.assertEqual(len(out_lines), 0)
698             self.assertEqual(len(err_lines), 1)
699             self.assertEqual(
700                 unstyle(str(report)),
701                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
702                 " reformat.",
703             )
704             self.assertEqual(report.return_code, 123)
705             report.failed(Path("e2"), "boom")
706             self.assertEqual(len(out_lines), 0)
707             self.assertEqual(len(err_lines), 2)
708             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
709             self.assertEqual(
710                 unstyle(str(report)),
711                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
712                 " reformat.",
713             )
714             self.assertEqual(report.return_code, 123)
715             report.path_ignored(Path("wat"), "no match")
716             self.assertEqual(len(out_lines), 0)
717             self.assertEqual(len(err_lines), 2)
718             self.assertEqual(
719                 unstyle(str(report)),
720                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
721                 " reformat.",
722             )
723             self.assertEqual(report.return_code, 123)
724             report.done(Path("f4"), black.Changed.NO)
725             self.assertEqual(len(out_lines), 0)
726             self.assertEqual(len(err_lines), 2)
727             self.assertEqual(
728                 unstyle(str(report)),
729                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
730                 " reformat.",
731             )
732             self.assertEqual(report.return_code, 123)
733             report.check = True
734             self.assertEqual(
735                 unstyle(str(report)),
736                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
737                 " would fail to reformat.",
738             )
739             report.check = False
740             report.diff = True
741             self.assertEqual(
742                 unstyle(str(report)),
743                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
744                 " would fail to reformat.",
745             )
746
747     def test_report_normal(self) -> None:
748         report = black.Report()
749         out_lines = []
750         err_lines = []
751
752         def out(msg: str, **kwargs: Any) -> None:
753             out_lines.append(msg)
754
755         def err(msg: str, **kwargs: Any) -> None:
756             err_lines.append(msg)
757
758         with patch("black.out", out), patch("black.err", err):
759             report.done(Path("f1"), black.Changed.NO)
760             self.assertEqual(len(out_lines), 0)
761             self.assertEqual(len(err_lines), 0)
762             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
763             self.assertEqual(report.return_code, 0)
764             report.done(Path("f2"), black.Changed.YES)
765             self.assertEqual(len(out_lines), 1)
766             self.assertEqual(len(err_lines), 0)
767             self.assertEqual(out_lines[-1], "reformatted f2")
768             self.assertEqual(
769                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
770             )
771             report.done(Path("f3"), black.Changed.CACHED)
772             self.assertEqual(len(out_lines), 1)
773             self.assertEqual(len(err_lines), 0)
774             self.assertEqual(out_lines[-1], "reformatted f2")
775             self.assertEqual(
776                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
777             )
778             self.assertEqual(report.return_code, 0)
779             report.check = True
780             self.assertEqual(report.return_code, 1)
781             report.check = False
782             report.failed(Path("e1"), "boom")
783             self.assertEqual(len(out_lines), 1)
784             self.assertEqual(len(err_lines), 1)
785             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
786             self.assertEqual(
787                 unstyle(str(report)),
788                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
789                 " reformat.",
790             )
791             self.assertEqual(report.return_code, 123)
792             report.done(Path("f3"), black.Changed.YES)
793             self.assertEqual(len(out_lines), 2)
794             self.assertEqual(len(err_lines), 1)
795             self.assertEqual(out_lines[-1], "reformatted f3")
796             self.assertEqual(
797                 unstyle(str(report)),
798                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
799                 " reformat.",
800             )
801             self.assertEqual(report.return_code, 123)
802             report.failed(Path("e2"), "boom")
803             self.assertEqual(len(out_lines), 2)
804             self.assertEqual(len(err_lines), 2)
805             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
806             self.assertEqual(
807                 unstyle(str(report)),
808                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
809                 " reformat.",
810             )
811             self.assertEqual(report.return_code, 123)
812             report.path_ignored(Path("wat"), "no match")
813             self.assertEqual(len(out_lines), 2)
814             self.assertEqual(len(err_lines), 2)
815             self.assertEqual(
816                 unstyle(str(report)),
817                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
818                 " reformat.",
819             )
820             self.assertEqual(report.return_code, 123)
821             report.done(Path("f4"), black.Changed.NO)
822             self.assertEqual(len(out_lines), 2)
823             self.assertEqual(len(err_lines), 2)
824             self.assertEqual(
825                 unstyle(str(report)),
826                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
827                 " reformat.",
828             )
829             self.assertEqual(report.return_code, 123)
830             report.check = True
831             self.assertEqual(
832                 unstyle(str(report)),
833                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
834                 " would fail to reformat.",
835             )
836             report.check = False
837             report.diff = True
838             self.assertEqual(
839                 unstyle(str(report)),
840                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
841                 " would fail to reformat.",
842             )
843
844     def test_lib2to3_parse(self) -> None:
845         with self.assertRaises(black.InvalidInput):
846             black.lib2to3_parse("invalid syntax")
847
848         straddling = "x + y"
849         black.lib2to3_parse(straddling)
850         black.lib2to3_parse(straddling, {TargetVersion.PY27})
851         black.lib2to3_parse(straddling, {TargetVersion.PY36})
852         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
853
854         py2_only = "print x"
855         black.lib2to3_parse(py2_only)
856         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
857         with self.assertRaises(black.InvalidInput):
858             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
859         with self.assertRaises(black.InvalidInput):
860             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
861
862         py3_only = "exec(x, end=y)"
863         black.lib2to3_parse(py3_only)
864         with self.assertRaises(black.InvalidInput):
865             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
866         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
867         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
868
869     def test_get_features_used_decorator(self) -> None:
870         # Test the feature detection of new decorator syntax
871         # since this makes some test cases of test_get_features_used()
872         # fails if it fails, this is tested first so that a useful case
873         # is identified
874         simples, relaxed = read_data("decorators")
875         # skip explanation comments at the top of the file
876         for simple_test in simples.split("##")[1:]:
877             node = black.lib2to3_parse(simple_test)
878             decorator = str(node.children[0].children[0]).strip()
879             self.assertNotIn(
880                 Feature.RELAXED_DECORATORS,
881                 black.get_features_used(node),
882                 msg=(
883                     f"decorator '{decorator}' follows python<=3.8 syntax"
884                     "but is detected as 3.9+"
885                     # f"The full node is\n{node!r}"
886                 ),
887             )
888         # skip the '# output' comment at the top of the output part
889         for relaxed_test in relaxed.split("##")[1:]:
890             node = black.lib2to3_parse(relaxed_test)
891             decorator = str(node.children[0].children[0]).strip()
892             self.assertIn(
893                 Feature.RELAXED_DECORATORS,
894                 black.get_features_used(node),
895                 msg=(
896                     f"decorator '{decorator}' uses python3.9+ syntax"
897                     "but is detected as python<=3.8"
898                     # f"The full node is\n{node!r}"
899                 ),
900             )
901
902     def test_get_features_used(self) -> None:
903         node = black.lib2to3_parse("def f(*, arg): ...\n")
904         self.assertEqual(black.get_features_used(node), set())
905         node = black.lib2to3_parse("def f(*, arg,): ...\n")
906         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
907         node = black.lib2to3_parse("f(*arg,)\n")
908         self.assertEqual(
909             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
910         )
911         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
912         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
913         node = black.lib2to3_parse("123_456\n")
914         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
915         node = black.lib2to3_parse("123456\n")
916         self.assertEqual(black.get_features_used(node), set())
917         source, expected = read_data("function")
918         node = black.lib2to3_parse(source)
919         expected_features = {
920             Feature.TRAILING_COMMA_IN_CALL,
921             Feature.TRAILING_COMMA_IN_DEF,
922             Feature.F_STRINGS,
923         }
924         self.assertEqual(black.get_features_used(node), expected_features)
925         node = black.lib2to3_parse(expected)
926         self.assertEqual(black.get_features_used(node), expected_features)
927         source, expected = read_data("expression")
928         node = black.lib2to3_parse(source)
929         self.assertEqual(black.get_features_used(node), set())
930         node = black.lib2to3_parse(expected)
931         self.assertEqual(black.get_features_used(node), set())
932
933     def test_get_future_imports(self) -> None:
934         node = black.lib2to3_parse("\n")
935         self.assertEqual(set(), black.get_future_imports(node))
936         node = black.lib2to3_parse("from __future__ import black\n")
937         self.assertEqual({"black"}, black.get_future_imports(node))
938         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
939         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
940         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
941         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
942         node = black.lib2to3_parse(
943             "from __future__ import multiple\nfrom __future__ import imports\n"
944         )
945         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
946         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
947         self.assertEqual({"black"}, black.get_future_imports(node))
948         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
949         self.assertEqual({"black"}, black.get_future_imports(node))
950         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
951         self.assertEqual(set(), black.get_future_imports(node))
952         node = black.lib2to3_parse("from some.module import black\n")
953         self.assertEqual(set(), black.get_future_imports(node))
954         node = black.lib2to3_parse(
955             "from __future__ import unicode_literals as _unicode_literals"
956         )
957         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
958         node = black.lib2to3_parse(
959             "from __future__ import unicode_literals as _lol, print"
960         )
961         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
962
963     def test_debug_visitor(self) -> None:
964         source, _ = read_data("debug_visitor.py")
965         expected, _ = read_data("debug_visitor.out")
966         out_lines = []
967         err_lines = []
968
969         def out(msg: str, **kwargs: Any) -> None:
970             out_lines.append(msg)
971
972         def err(msg: str, **kwargs: Any) -> None:
973             err_lines.append(msg)
974
975         with patch("black.out", out), patch("black.err", err):
976             black.DebugVisitor.show(source)
977         actual = "\n".join(out_lines) + "\n"
978         log_name = ""
979         if expected != actual:
980             log_name = black.dump_to_file(*out_lines)
981         self.assertEqual(
982             expected,
983             actual,
984             f"AST print out is different. Actual version dumped to {log_name}",
985         )
986
987     def test_format_file_contents(self) -> None:
988         empty = ""
989         mode = DEFAULT_MODE
990         with self.assertRaises(black.NothingChanged):
991             black.format_file_contents(empty, mode=mode, fast=False)
992         just_nl = "\n"
993         with self.assertRaises(black.NothingChanged):
994             black.format_file_contents(just_nl, mode=mode, fast=False)
995         same = "j = [1, 2, 3]\n"
996         with self.assertRaises(black.NothingChanged):
997             black.format_file_contents(same, mode=mode, fast=False)
998         different = "j = [1,2,3]"
999         expected = same
1000         actual = black.format_file_contents(different, mode=mode, fast=False)
1001         self.assertEqual(expected, actual)
1002         invalid = "return if you can"
1003         with self.assertRaises(black.InvalidInput) as e:
1004             black.format_file_contents(invalid, mode=mode, fast=False)
1005         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1006
1007     def test_endmarker(self) -> None:
1008         n = black.lib2to3_parse("\n")
1009         self.assertEqual(n.type, black.syms.file_input)
1010         self.assertEqual(len(n.children), 1)
1011         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1012
1013     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1014     def test_assertFormatEqual(self) -> None:
1015         out_lines = []
1016         err_lines = []
1017
1018         def out(msg: str, **kwargs: Any) -> None:
1019             out_lines.append(msg)
1020
1021         def err(msg: str, **kwargs: Any) -> None:
1022             err_lines.append(msg)
1023
1024         with patch("black.out", out), patch("black.err", err):
1025             with self.assertRaises(AssertionError):
1026                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1027
1028         out_str = "".join(out_lines)
1029         self.assertTrue("Expected tree:" in out_str)
1030         self.assertTrue("Actual tree:" in out_str)
1031         self.assertEqual("".join(err_lines), "")
1032
1033     def test_cache_broken_file(self) -> None:
1034         mode = DEFAULT_MODE
1035         with cache_dir() as workspace:
1036             cache_file = black.get_cache_file(mode)
1037             with cache_file.open("w") as fobj:
1038                 fobj.write("this is not a pickle")
1039             self.assertEqual(black.read_cache(mode), {})
1040             src = (workspace / "test.py").resolve()
1041             with src.open("w") as fobj:
1042                 fobj.write("print('hello')")
1043             self.invokeBlack([str(src)])
1044             cache = black.read_cache(mode)
1045             self.assertIn(str(src), cache)
1046
1047     def test_cache_single_file_already_cached(self) -> None:
1048         mode = DEFAULT_MODE
1049         with cache_dir() as workspace:
1050             src = (workspace / "test.py").resolve()
1051             with src.open("w") as fobj:
1052                 fobj.write("print('hello')")
1053             black.write_cache({}, [src], mode)
1054             self.invokeBlack([str(src)])
1055             with src.open("r") as fobj:
1056                 self.assertEqual(fobj.read(), "print('hello')")
1057
1058     @event_loop()
1059     def test_cache_multiple_files(self) -> None:
1060         mode = DEFAULT_MODE
1061         with cache_dir() as workspace, patch(
1062             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1063         ):
1064             one = (workspace / "one.py").resolve()
1065             with one.open("w") as fobj:
1066                 fobj.write("print('hello')")
1067             two = (workspace / "two.py").resolve()
1068             with two.open("w") as fobj:
1069                 fobj.write("print('hello')")
1070             black.write_cache({}, [one], mode)
1071             self.invokeBlack([str(workspace)])
1072             with one.open("r") as fobj:
1073                 self.assertEqual(fobj.read(), "print('hello')")
1074             with two.open("r") as fobj:
1075                 self.assertEqual(fobj.read(), 'print("hello")\n')
1076             cache = black.read_cache(mode)
1077             self.assertIn(str(one), cache)
1078             self.assertIn(str(two), cache)
1079
1080     def test_no_cache_when_writeback_diff(self) -> None:
1081         mode = DEFAULT_MODE
1082         with cache_dir() as workspace:
1083             src = (workspace / "test.py").resolve()
1084             with src.open("w") as fobj:
1085                 fobj.write("print('hello')")
1086             with patch("black.read_cache") as read_cache, patch(
1087                 "black.write_cache"
1088             ) as write_cache:
1089                 self.invokeBlack([str(src), "--diff"])
1090                 cache_file = black.get_cache_file(mode)
1091                 self.assertFalse(cache_file.exists())
1092                 write_cache.assert_not_called()
1093                 read_cache.assert_not_called()
1094
1095     def test_no_cache_when_writeback_color_diff(self) -> None:
1096         mode = DEFAULT_MODE
1097         with cache_dir() as workspace:
1098             src = (workspace / "test.py").resolve()
1099             with src.open("w") as fobj:
1100                 fobj.write("print('hello')")
1101             with patch("black.read_cache") as read_cache, patch(
1102                 "black.write_cache"
1103             ) as write_cache:
1104                 self.invokeBlack([str(src), "--diff", "--color"])
1105                 cache_file = black.get_cache_file(mode)
1106                 self.assertFalse(cache_file.exists())
1107                 write_cache.assert_not_called()
1108                 read_cache.assert_not_called()
1109
1110     @event_loop()
1111     def test_output_locking_when_writeback_diff(self) -> None:
1112         with cache_dir() as workspace:
1113             for tag in range(0, 4):
1114                 src = (workspace / f"test{tag}.py").resolve()
1115                 with src.open("w") as fobj:
1116                     fobj.write("print('hello')")
1117             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1118                 self.invokeBlack(["--diff", str(workspace)], exit_code=0)
1119                 # this isn't quite doing what we want, but if it _isn't_
1120                 # called then we cannot be using the lock it provides
1121                 mgr.assert_called()
1122
1123     @event_loop()
1124     def test_output_locking_when_writeback_color_diff(self) -> None:
1125         with cache_dir() as workspace:
1126             for tag in range(0, 4):
1127                 src = (workspace / f"test{tag}.py").resolve()
1128                 with src.open("w") as fobj:
1129                     fobj.write("print('hello')")
1130             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1131                 self.invokeBlack(["--diff", "--color", str(workspace)], exit_code=0)
1132                 # this isn't quite doing what we want, but if it _isn't_
1133                 # called then we cannot be using the lock it provides
1134                 mgr.assert_called()
1135
1136     def test_no_cache_when_stdin(self) -> None:
1137         mode = DEFAULT_MODE
1138         with cache_dir():
1139             result = CliRunner().invoke(
1140                 black.main, ["-"], input=BytesIO(b"print('hello')")
1141             )
1142             self.assertEqual(result.exit_code, 0)
1143             cache_file = black.get_cache_file(mode)
1144             self.assertFalse(cache_file.exists())
1145
1146     def test_read_cache_no_cachefile(self) -> None:
1147         mode = DEFAULT_MODE
1148         with cache_dir():
1149             self.assertEqual(black.read_cache(mode), {})
1150
1151     def test_write_cache_read_cache(self) -> None:
1152         mode = DEFAULT_MODE
1153         with cache_dir() as workspace:
1154             src = (workspace / "test.py").resolve()
1155             src.touch()
1156             black.write_cache({}, [src], mode)
1157             cache = black.read_cache(mode)
1158             self.assertIn(str(src), cache)
1159             self.assertEqual(cache[str(src)], black.get_cache_info(src))
1160
1161     def test_filter_cached(self) -> None:
1162         with TemporaryDirectory() as workspace:
1163             path = Path(workspace)
1164             uncached = (path / "uncached").resolve()
1165             cached = (path / "cached").resolve()
1166             cached_but_changed = (path / "changed").resolve()
1167             uncached.touch()
1168             cached.touch()
1169             cached_but_changed.touch()
1170             cache = {
1171                 str(cached): black.get_cache_info(cached),
1172                 str(cached_but_changed): (0.0, 0),
1173             }
1174             todo, done = black.filter_cached(
1175                 cache, {uncached, cached, cached_but_changed}
1176             )
1177             self.assertEqual(todo, {uncached, cached_but_changed})
1178             self.assertEqual(done, {cached})
1179
1180     def test_write_cache_creates_directory_if_needed(self) -> None:
1181         mode = DEFAULT_MODE
1182         with cache_dir(exists=False) as workspace:
1183             self.assertFalse(workspace.exists())
1184             black.write_cache({}, [], mode)
1185             self.assertTrue(workspace.exists())
1186
1187     @event_loop()
1188     def test_failed_formatting_does_not_get_cached(self) -> None:
1189         mode = DEFAULT_MODE
1190         with cache_dir() as workspace, patch(
1191             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1192         ):
1193             failing = (workspace / "failing.py").resolve()
1194             with failing.open("w") as fobj:
1195                 fobj.write("not actually python")
1196             clean = (workspace / "clean.py").resolve()
1197             with clean.open("w") as fobj:
1198                 fobj.write('print("hello")\n')
1199             self.invokeBlack([str(workspace)], exit_code=123)
1200             cache = black.read_cache(mode)
1201             self.assertNotIn(str(failing), cache)
1202             self.assertIn(str(clean), cache)
1203
1204     def test_write_cache_write_fail(self) -> None:
1205         mode = DEFAULT_MODE
1206         with cache_dir(), patch.object(Path, "open") as mock:
1207             mock.side_effect = OSError
1208             black.write_cache({}, [], mode)
1209
1210     @event_loop()
1211     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
1212     def test_works_in_mono_process_only_environment(self) -> None:
1213         with cache_dir() as workspace:
1214             for f in [
1215                 (workspace / "one.py").resolve(),
1216                 (workspace / "two.py").resolve(),
1217             ]:
1218                 f.write_text('print("hello")\n')
1219             self.invokeBlack([str(workspace)])
1220
1221     @event_loop()
1222     def test_check_diff_use_together(self) -> None:
1223         with cache_dir():
1224             # Files which will be reformatted.
1225             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1226             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1227             # Files which will not be reformatted.
1228             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1229             self.invokeBlack([str(src2), "--diff", "--check"])
1230             # Multi file command.
1231             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1232
1233     def test_no_files(self) -> None:
1234         with cache_dir():
1235             # Without an argument, black exits with error code 0.
1236             self.invokeBlack([])
1237
1238     def test_broken_symlink(self) -> None:
1239         with cache_dir() as workspace:
1240             symlink = workspace / "broken_link.py"
1241             try:
1242                 symlink.symlink_to("nonexistent.py")
1243             except OSError as e:
1244                 self.skipTest(f"Can't create symlinks: {e}")
1245             self.invokeBlack([str(workspace.resolve())])
1246
1247     def test_read_cache_line_lengths(self) -> None:
1248         mode = DEFAULT_MODE
1249         short_mode = replace(DEFAULT_MODE, line_length=1)
1250         with cache_dir() as workspace:
1251             path = (workspace / "file.py").resolve()
1252             path.touch()
1253             black.write_cache({}, [path], mode)
1254             one = black.read_cache(mode)
1255             self.assertIn(str(path), one)
1256             two = black.read_cache(short_mode)
1257             self.assertNotIn(str(path), two)
1258
1259     def test_single_file_force_pyi(self) -> None:
1260         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1261         contents, expected = read_data("force_pyi")
1262         with cache_dir() as workspace:
1263             path = (workspace / "file.py").resolve()
1264             with open(path, "w") as fh:
1265                 fh.write(contents)
1266             self.invokeBlack([str(path), "--pyi"])
1267             with open(path, "r") as fh:
1268                 actual = fh.read()
1269             # verify cache with --pyi is separate
1270             pyi_cache = black.read_cache(pyi_mode)
1271             self.assertIn(str(path), pyi_cache)
1272             normal_cache = black.read_cache(DEFAULT_MODE)
1273             self.assertNotIn(str(path), normal_cache)
1274         self.assertFormatEqual(expected, actual)
1275         black.assert_equivalent(contents, actual)
1276         black.assert_stable(contents, actual, pyi_mode)
1277
1278     @event_loop()
1279     def test_multi_file_force_pyi(self) -> None:
1280         reg_mode = DEFAULT_MODE
1281         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1282         contents, expected = read_data("force_pyi")
1283         with cache_dir() as workspace:
1284             paths = [
1285                 (workspace / "file1.py").resolve(),
1286                 (workspace / "file2.py").resolve(),
1287             ]
1288             for path in paths:
1289                 with open(path, "w") as fh:
1290                     fh.write(contents)
1291             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1292             for path in paths:
1293                 with open(path, "r") as fh:
1294                     actual = fh.read()
1295                 self.assertEqual(actual, expected)
1296             # verify cache with --pyi is separate
1297             pyi_cache = black.read_cache(pyi_mode)
1298             normal_cache = black.read_cache(reg_mode)
1299             for path in paths:
1300                 self.assertIn(str(path), pyi_cache)
1301                 self.assertNotIn(str(path), normal_cache)
1302
1303     def test_pipe_force_pyi(self) -> None:
1304         source, expected = read_data("force_pyi")
1305         result = CliRunner().invoke(
1306             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1307         )
1308         self.assertEqual(result.exit_code, 0)
1309         actual = result.output
1310         self.assertFormatEqual(actual, expected)
1311
1312     def test_single_file_force_py36(self) -> None:
1313         reg_mode = DEFAULT_MODE
1314         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1315         source, expected = read_data("force_py36")
1316         with cache_dir() as workspace:
1317             path = (workspace / "file.py").resolve()
1318             with open(path, "w") as fh:
1319                 fh.write(source)
1320             self.invokeBlack([str(path), *PY36_ARGS])
1321             with open(path, "r") as fh:
1322                 actual = fh.read()
1323             # verify cache with --target-version is separate
1324             py36_cache = black.read_cache(py36_mode)
1325             self.assertIn(str(path), py36_cache)
1326             normal_cache = black.read_cache(reg_mode)
1327             self.assertNotIn(str(path), normal_cache)
1328         self.assertEqual(actual, expected)
1329
1330     @event_loop()
1331     def test_multi_file_force_py36(self) -> None:
1332         reg_mode = DEFAULT_MODE
1333         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1334         source, expected = read_data("force_py36")
1335         with cache_dir() as workspace:
1336             paths = [
1337                 (workspace / "file1.py").resolve(),
1338                 (workspace / "file2.py").resolve(),
1339             ]
1340             for path in paths:
1341                 with open(path, "w") as fh:
1342                     fh.write(source)
1343             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1344             for path in paths:
1345                 with open(path, "r") as fh:
1346                     actual = fh.read()
1347                 self.assertEqual(actual, expected)
1348             # verify cache with --target-version is separate
1349             pyi_cache = black.read_cache(py36_mode)
1350             normal_cache = black.read_cache(reg_mode)
1351             for path in paths:
1352                 self.assertIn(str(path), pyi_cache)
1353                 self.assertNotIn(str(path), normal_cache)
1354
1355     def test_pipe_force_py36(self) -> None:
1356         source, expected = read_data("force_py36")
1357         result = CliRunner().invoke(
1358             black.main,
1359             ["-", "-q", "--target-version=py36"],
1360             input=BytesIO(source.encode("utf8")),
1361         )
1362         self.assertEqual(result.exit_code, 0)
1363         actual = result.output
1364         self.assertFormatEqual(actual, expected)
1365
1366     def test_include_exclude(self) -> None:
1367         path = THIS_DIR / "data" / "include_exclude_tests"
1368         include = re.compile(r"\.pyi?$")
1369         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1370         report = black.Report()
1371         gitignore = PathSpec.from_lines("gitwildmatch", [])
1372         sources: List[Path] = []
1373         expected = [
1374             Path(path / "b/dont_exclude/a.py"),
1375             Path(path / "b/dont_exclude/a.pyi"),
1376         ]
1377         this_abs = THIS_DIR.resolve()
1378         sources.extend(
1379             black.gen_python_files(
1380                 path.iterdir(),
1381                 this_abs,
1382                 include,
1383                 exclude,
1384                 None,
1385                 None,
1386                 report,
1387                 gitignore,
1388             )
1389         )
1390         self.assertEqual(sorted(expected), sorted(sources))
1391
1392     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1393     def test_exclude_for_issue_1572(self) -> None:
1394         # Exclude shouldn't touch files that were explicitly given to Black through the
1395         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
1396         # https://github.com/psf/black/issues/1572
1397         path = THIS_DIR / "data" / "include_exclude_tests"
1398         include = ""
1399         exclude = r"/exclude/|a\.py"
1400         src = str(path / "b/exclude/a.py")
1401         report = black.Report()
1402         expected = [Path(path / "b/exclude/a.py")]
1403         sources = list(
1404             black.get_sources(
1405                 ctx=FakeContext(),
1406                 src=(src,),
1407                 quiet=True,
1408                 verbose=False,
1409                 include=re.compile(include),
1410                 exclude=re.compile(exclude),
1411                 extend_exclude=None,
1412                 force_exclude=None,
1413                 report=report,
1414                 stdin_filename=None,
1415             )
1416         )
1417         self.assertEqual(sorted(expected), sorted(sources))
1418
1419     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1420     def test_get_sources_with_stdin(self) -> None:
1421         include = ""
1422         exclude = r"/exclude/|a\.py"
1423         src = "-"
1424         report = black.Report()
1425         expected = [Path("-")]
1426         sources = list(
1427             black.get_sources(
1428                 ctx=FakeContext(),
1429                 src=(src,),
1430                 quiet=True,
1431                 verbose=False,
1432                 include=re.compile(include),
1433                 exclude=re.compile(exclude),
1434                 extend_exclude=None,
1435                 force_exclude=None,
1436                 report=report,
1437                 stdin_filename=None,
1438             )
1439         )
1440         self.assertEqual(sorted(expected), sorted(sources))
1441
1442     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1443     def test_get_sources_with_stdin_filename(self) -> None:
1444         include = ""
1445         exclude = r"/exclude/|a\.py"
1446         src = "-"
1447         report = black.Report()
1448         stdin_filename = str(THIS_DIR / "data/collections.py")
1449         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1450         sources = list(
1451             black.get_sources(
1452                 ctx=FakeContext(),
1453                 src=(src,),
1454                 quiet=True,
1455                 verbose=False,
1456                 include=re.compile(include),
1457                 exclude=re.compile(exclude),
1458                 extend_exclude=None,
1459                 force_exclude=None,
1460                 report=report,
1461                 stdin_filename=stdin_filename,
1462             )
1463         )
1464         self.assertEqual(sorted(expected), sorted(sources))
1465
1466     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1467     def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
1468         # Exclude shouldn't exclude stdin_filename since it is mimicing the
1469         # file being passed directly. This is the same as
1470         # test_exclude_for_issue_1572
1471         path = THIS_DIR / "data" / "include_exclude_tests"
1472         include = ""
1473         exclude = r"/exclude/|a\.py"
1474         src = "-"
1475         report = black.Report()
1476         stdin_filename = str(path / "b/exclude/a.py")
1477         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1478         sources = list(
1479             black.get_sources(
1480                 ctx=FakeContext(),
1481                 src=(src,),
1482                 quiet=True,
1483                 verbose=False,
1484                 include=re.compile(include),
1485                 exclude=re.compile(exclude),
1486                 extend_exclude=None,
1487                 force_exclude=None,
1488                 report=report,
1489                 stdin_filename=stdin_filename,
1490             )
1491         )
1492         self.assertEqual(sorted(expected), sorted(sources))
1493
1494     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1495     def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
1496         # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
1497         # file being passed directly. This is the same as
1498         # test_exclude_for_issue_1572
1499         path = THIS_DIR / "data" / "include_exclude_tests"
1500         include = ""
1501         extend_exclude = r"/exclude/|a\.py"
1502         src = "-"
1503         report = black.Report()
1504         stdin_filename = str(path / "b/exclude/a.py")
1505         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1506         sources = list(
1507             black.get_sources(
1508                 ctx=FakeContext(),
1509                 src=(src,),
1510                 quiet=True,
1511                 verbose=False,
1512                 include=re.compile(include),
1513                 exclude=re.compile(""),
1514                 extend_exclude=re.compile(extend_exclude),
1515                 force_exclude=None,
1516                 report=report,
1517                 stdin_filename=stdin_filename,
1518             )
1519         )
1520         self.assertEqual(sorted(expected), sorted(sources))
1521
1522     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1523     def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
1524         # Force exclude should exclude the file when passing it through
1525         # stdin_filename
1526         path = THIS_DIR / "data" / "include_exclude_tests"
1527         include = ""
1528         force_exclude = r"/exclude/|a\.py"
1529         src = "-"
1530         report = black.Report()
1531         stdin_filename = str(path / "b/exclude/a.py")
1532         sources = list(
1533             black.get_sources(
1534                 ctx=FakeContext(),
1535                 src=(src,),
1536                 quiet=True,
1537                 verbose=False,
1538                 include=re.compile(include),
1539                 exclude=re.compile(""),
1540                 extend_exclude=None,
1541                 force_exclude=re.compile(force_exclude),
1542                 report=report,
1543                 stdin_filename=stdin_filename,
1544             )
1545         )
1546         self.assertEqual([], sorted(sources))
1547
1548     def test_reformat_one_with_stdin(self) -> None:
1549         with patch(
1550             "black.format_stdin_to_stdout",
1551             return_value=lambda *args, **kwargs: black.Changed.YES,
1552         ) as fsts:
1553             report = MagicMock()
1554             path = Path("-")
1555             black.reformat_one(
1556                 path,
1557                 fast=True,
1558                 write_back=black.WriteBack.YES,
1559                 mode=DEFAULT_MODE,
1560                 report=report,
1561             )
1562             fsts.assert_called_once()
1563             report.done.assert_called_with(path, black.Changed.YES)
1564
1565     def test_reformat_one_with_stdin_filename(self) -> None:
1566         with patch(
1567             "black.format_stdin_to_stdout",
1568             return_value=lambda *args, **kwargs: black.Changed.YES,
1569         ) as fsts:
1570             report = MagicMock()
1571             p = "foo.py"
1572             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1573             expected = Path(p)
1574             black.reformat_one(
1575                 path,
1576                 fast=True,
1577                 write_back=black.WriteBack.YES,
1578                 mode=DEFAULT_MODE,
1579                 report=report,
1580             )
1581             fsts.assert_called_once()
1582             # __BLACK_STDIN_FILENAME__ should have been striped
1583             report.done.assert_called_with(expected, black.Changed.YES)
1584
1585     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1586         with patch(
1587             "black.format_stdin_to_stdout",
1588             return_value=lambda *args, **kwargs: black.Changed.YES,
1589         ) as fsts:
1590             report = MagicMock()
1591             # Even with an existing file, since we are forcing stdin, black
1592             # should output to stdout and not modify the file inplace
1593             p = Path(str(THIS_DIR / "data/collections.py"))
1594             # Make sure is_file actually returns True
1595             self.assertTrue(p.is_file())
1596             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1597             expected = Path(p)
1598             black.reformat_one(
1599                 path,
1600                 fast=True,
1601                 write_back=black.WriteBack.YES,
1602                 mode=DEFAULT_MODE,
1603                 report=report,
1604             )
1605             fsts.assert_called_once()
1606             # __BLACK_STDIN_FILENAME__ should have been striped
1607             report.done.assert_called_with(expected, black.Changed.YES)
1608
1609     def test_gitignore_exclude(self) -> None:
1610         path = THIS_DIR / "data" / "include_exclude_tests"
1611         include = re.compile(r"\.pyi?$")
1612         exclude = re.compile(r"")
1613         report = black.Report()
1614         gitignore = PathSpec.from_lines(
1615             "gitwildmatch", ["exclude/", ".definitely_exclude"]
1616         )
1617         sources: List[Path] = []
1618         expected = [
1619             Path(path / "b/dont_exclude/a.py"),
1620             Path(path / "b/dont_exclude/a.pyi"),
1621         ]
1622         this_abs = THIS_DIR.resolve()
1623         sources.extend(
1624             black.gen_python_files(
1625                 path.iterdir(),
1626                 this_abs,
1627                 include,
1628                 exclude,
1629                 None,
1630                 None,
1631                 report,
1632                 gitignore,
1633             )
1634         )
1635         self.assertEqual(sorted(expected), sorted(sources))
1636
1637     def test_empty_include(self) -> None:
1638         path = THIS_DIR / "data" / "include_exclude_tests"
1639         report = black.Report()
1640         gitignore = PathSpec.from_lines("gitwildmatch", [])
1641         empty = re.compile(r"")
1642         sources: List[Path] = []
1643         expected = [
1644             Path(path / "b/exclude/a.pie"),
1645             Path(path / "b/exclude/a.py"),
1646             Path(path / "b/exclude/a.pyi"),
1647             Path(path / "b/dont_exclude/a.pie"),
1648             Path(path / "b/dont_exclude/a.py"),
1649             Path(path / "b/dont_exclude/a.pyi"),
1650             Path(path / "b/.definitely_exclude/a.pie"),
1651             Path(path / "b/.definitely_exclude/a.py"),
1652             Path(path / "b/.definitely_exclude/a.pyi"),
1653         ]
1654         this_abs = THIS_DIR.resolve()
1655         sources.extend(
1656             black.gen_python_files(
1657                 path.iterdir(),
1658                 this_abs,
1659                 empty,
1660                 re.compile(black.DEFAULT_EXCLUDES),
1661                 None,
1662                 None,
1663                 report,
1664                 gitignore,
1665             )
1666         )
1667         self.assertEqual(sorted(expected), sorted(sources))
1668
1669     def test_extend_exclude(self) -> None:
1670         path = THIS_DIR / "data" / "include_exclude_tests"
1671         report = black.Report()
1672         gitignore = PathSpec.from_lines("gitwildmatch", [])
1673         sources: List[Path] = []
1674         expected = [
1675             Path(path / "b/exclude/a.py"),
1676             Path(path / "b/dont_exclude/a.py"),
1677         ]
1678         this_abs = THIS_DIR.resolve()
1679         sources.extend(
1680             black.gen_python_files(
1681                 path.iterdir(),
1682                 this_abs,
1683                 re.compile(black.DEFAULT_INCLUDES),
1684                 re.compile(r"\.pyi$"),
1685                 re.compile(r"\.definitely_exclude"),
1686                 None,
1687                 report,
1688                 gitignore,
1689             )
1690         )
1691         self.assertEqual(sorted(expected), sorted(sources))
1692
1693     def test_invalid_cli_regex(self) -> None:
1694         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1695             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1696
1697     def test_preserves_line_endings(self) -> None:
1698         with TemporaryDirectory() as workspace:
1699             test_file = Path(workspace) / "test.py"
1700             for nl in ["\n", "\r\n"]:
1701                 contents = nl.join(["def f(  ):", "    pass"])
1702                 test_file.write_bytes(contents.encode())
1703                 ff(test_file, write_back=black.WriteBack.YES)
1704                 updated_contents: bytes = test_file.read_bytes()
1705                 self.assertIn(nl.encode(), updated_contents)
1706                 if nl == "\n":
1707                     self.assertNotIn(b"\r\n", updated_contents)
1708
1709     def test_preserves_line_endings_via_stdin(self) -> None:
1710         for nl in ["\n", "\r\n"]:
1711             contents = nl.join(["def f(  ):", "    pass"])
1712             runner = BlackRunner()
1713             result = runner.invoke(
1714                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1715             )
1716             self.assertEqual(result.exit_code, 0)
1717             output = runner.stdout_bytes
1718             self.assertIn(nl.encode("utf8"), output)
1719             if nl == "\n":
1720                 self.assertNotIn(b"\r\n", output)
1721
1722     def test_assert_equivalent_different_asts(self) -> None:
1723         with self.assertRaises(AssertionError):
1724             black.assert_equivalent("{}", "None")
1725
1726     def test_symlink_out_of_root_directory(self) -> None:
1727         path = MagicMock()
1728         root = THIS_DIR.resolve()
1729         child = MagicMock()
1730         include = re.compile(black.DEFAULT_INCLUDES)
1731         exclude = re.compile(black.DEFAULT_EXCLUDES)
1732         report = black.Report()
1733         gitignore = PathSpec.from_lines("gitwildmatch", [])
1734         # `child` should behave like a symlink which resolved path is clearly
1735         # outside of the `root` directory.
1736         path.iterdir.return_value = [child]
1737         child.resolve.return_value = Path("/a/b/c")
1738         child.as_posix.return_value = "/a/b/c"
1739         child.is_symlink.return_value = True
1740         try:
1741             list(
1742                 black.gen_python_files(
1743                     path.iterdir(),
1744                     root,
1745                     include,
1746                     exclude,
1747                     None,
1748                     None,
1749                     report,
1750                     gitignore,
1751                 )
1752             )
1753         except ValueError as ve:
1754             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1755         path.iterdir.assert_called_once()
1756         child.resolve.assert_called_once()
1757         child.is_symlink.assert_called_once()
1758         # `child` should behave like a strange file which resolved path is clearly
1759         # outside of the `root` directory.
1760         child.is_symlink.return_value = False
1761         with self.assertRaises(ValueError):
1762             list(
1763                 black.gen_python_files(
1764                     path.iterdir(),
1765                     root,
1766                     include,
1767                     exclude,
1768                     None,
1769                     None,
1770                     report,
1771                     gitignore,
1772                 )
1773             )
1774         path.iterdir.assert_called()
1775         self.assertEqual(path.iterdir.call_count, 2)
1776         child.resolve.assert_called()
1777         self.assertEqual(child.resolve.call_count, 2)
1778         child.is_symlink.assert_called()
1779         self.assertEqual(child.is_symlink.call_count, 2)
1780
1781     def test_shhh_click(self) -> None:
1782         try:
1783             from click import _unicodefun  # type: ignore
1784         except ModuleNotFoundError:
1785             self.skipTest("Incompatible Click version")
1786         if not hasattr(_unicodefun, "_verify_python3_env"):
1787             self.skipTest("Incompatible Click version")
1788         # First, let's see if Click is crashing with a preferred ASCII charset.
1789         with patch("locale.getpreferredencoding") as gpe:
1790             gpe.return_value = "ASCII"
1791             with self.assertRaises(RuntimeError):
1792                 _unicodefun._verify_python3_env()
1793         # Now, let's silence Click...
1794         black.patch_click()
1795         # ...and confirm it's silent.
1796         with patch("locale.getpreferredencoding") as gpe:
1797             gpe.return_value = "ASCII"
1798             try:
1799                 _unicodefun._verify_python3_env()
1800             except RuntimeError as re:
1801                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1802
1803     def test_root_logger_not_used_directly(self) -> None:
1804         def fail(*args: Any, **kwargs: Any) -> None:
1805             self.fail("Record created with root logger")
1806
1807         with patch.multiple(
1808             logging.root,
1809             debug=fail,
1810             info=fail,
1811             warning=fail,
1812             error=fail,
1813             critical=fail,
1814             log=fail,
1815         ):
1816             ff(THIS_FILE)
1817
1818     def test_invalid_config_return_code(self) -> None:
1819         tmp_file = Path(black.dump_to_file())
1820         try:
1821             tmp_config = Path(black.dump_to_file())
1822             tmp_config.unlink()
1823             args = ["--config", str(tmp_config), str(tmp_file)]
1824             self.invokeBlack(args, exit_code=2, ignore_config=False)
1825         finally:
1826             tmp_file.unlink()
1827
1828     def test_parse_pyproject_toml(self) -> None:
1829         test_toml_file = THIS_DIR / "test.toml"
1830         config = black.parse_pyproject_toml(str(test_toml_file))
1831         self.assertEqual(config["verbose"], 1)
1832         self.assertEqual(config["check"], "no")
1833         self.assertEqual(config["diff"], "y")
1834         self.assertEqual(config["color"], True)
1835         self.assertEqual(config["line_length"], 79)
1836         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1837         self.assertEqual(config["exclude"], r"\.pyi?$")
1838         self.assertEqual(config["include"], r"\.py?$")
1839
1840     def test_read_pyproject_toml(self) -> None:
1841         test_toml_file = THIS_DIR / "test.toml"
1842         fake_ctx = FakeContext()
1843         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1844         config = fake_ctx.default_map
1845         self.assertEqual(config["verbose"], "1")
1846         self.assertEqual(config["check"], "no")
1847         self.assertEqual(config["diff"], "y")
1848         self.assertEqual(config["color"], "True")
1849         self.assertEqual(config["line_length"], "79")
1850         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1851         self.assertEqual(config["exclude"], r"\.pyi?$")
1852         self.assertEqual(config["include"], r"\.py?$")
1853
1854     def test_find_project_root(self) -> None:
1855         with TemporaryDirectory() as workspace:
1856             root = Path(workspace)
1857             test_dir = root / "test"
1858             test_dir.mkdir()
1859
1860             src_dir = root / "src"
1861             src_dir.mkdir()
1862
1863             root_pyproject = root / "pyproject.toml"
1864             root_pyproject.touch()
1865             src_pyproject = src_dir / "pyproject.toml"
1866             src_pyproject.touch()
1867             src_python = src_dir / "foo.py"
1868             src_python.touch()
1869
1870             self.assertEqual(
1871                 black.find_project_root((src_dir, test_dir)), root.resolve()
1872             )
1873             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
1874             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
1875
1876     @patch("black.find_user_pyproject_toml", black.find_user_pyproject_toml.__wrapped__)
1877     def test_find_user_pyproject_toml_linux(self) -> None:
1878         if system() == "Windows":
1879             return
1880
1881         # Test if XDG_CONFIG_HOME is checked
1882         with TemporaryDirectory() as workspace:
1883             tmp_user_config = Path(workspace) / "black"
1884             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1885                 self.assertEqual(
1886                     black.find_user_pyproject_toml(), tmp_user_config.resolve()
1887                 )
1888
1889         # Test fallback for XDG_CONFIG_HOME
1890         with patch.dict("os.environ"):
1891             os.environ.pop("XDG_CONFIG_HOME", None)
1892             fallback_user_config = Path("~/.config").expanduser() / "black"
1893             self.assertEqual(
1894                 black.find_user_pyproject_toml(), fallback_user_config.resolve()
1895             )
1896
1897     def test_find_user_pyproject_toml_windows(self) -> None:
1898         if system() != "Windows":
1899             return
1900
1901         user_config_path = Path.home() / ".black"
1902         self.assertEqual(black.find_user_pyproject_toml(), user_config_path.resolve())
1903
1904     def test_bpo_33660_workaround(self) -> None:
1905         if system() == "Windows":
1906             return
1907
1908         # https://bugs.python.org/issue33660
1909
1910         old_cwd = Path.cwd()
1911         try:
1912             root = Path("/")
1913             os.chdir(str(root))
1914             path = Path("workspace") / "project"
1915             report = black.Report(verbose=True)
1916             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
1917             self.assertEqual(normalized_path, "workspace/project")
1918         finally:
1919             os.chdir(str(old_cwd))
1920
1921     def test_newline_comment_interaction(self) -> None:
1922         source = "class A:\\\r\n# type: ignore\n pass\n"
1923         output = black.format_str(source, mode=DEFAULT_MODE)
1924         black.assert_stable(source, output, mode=DEFAULT_MODE)
1925
1926     def test_bpo_2142_workaround(self) -> None:
1927
1928         # https://bugs.python.org/issue2142
1929
1930         source, _ = read_data("missing_final_newline.py")
1931         # read_data adds a trailing newline
1932         source = source.rstrip()
1933         expected, _ = read_data("missing_final_newline.diff")
1934         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
1935         diff_header = re.compile(
1936             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
1937             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
1938         )
1939         try:
1940             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
1941             self.assertEqual(result.exit_code, 0)
1942         finally:
1943             os.unlink(tmp_file)
1944         actual = result.output
1945         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
1946         self.assertEqual(actual, expected)
1947
1948     def test_docstring_reformat_for_py27(self) -> None:
1949         """
1950         Check that stripping trailing whitespace from Python 2 docstrings
1951         doesn't trigger a "not equivalent to source" error
1952         """
1953         source = (
1954             b'def foo():\r\n    """Testing\r\n    Testing """\r\n    print "Foo"\r\n'
1955         )
1956         expected = 'def foo():\n    """Testing\n    Testing"""\n    print "Foo"\n'
1957
1958         result = CliRunner().invoke(
1959             black.main,
1960             ["-", "-q", "--target-version=py27"],
1961             input=BytesIO(source),
1962         )
1963
1964         self.assertEqual(result.exit_code, 0)
1965         actual = result.output
1966         self.assertFormatEqual(actual, expected)
1967
1968
# Read black's own source once at import time; `tracefunc` looks call sites up
# in this list by line number.
with open(black.__file__, encoding="utf-8") as _bf:
    black_source_lines = list(_bf)
1971
1972
def tracefunc(frame: types.FrameType, event: str, arg: Any) -> Callable:
    """Show function calls `from black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    Only "call" events whose frame's filename contains `black/__init__.py`
    are printed, as `<lineno>:<source line>` indented according to the
    current stack depth.  The function always returns itself.
    """
    if event != "call":
        return tracefunc

    filename = frame.f_code.co_filename
    if "black/__init__.py" not in filename:
        # Frames from other files must never be looked up in
        # `black_source_lines`: their line numbers can exceed its length.
        # Returning early also avoids the costly `inspect.stack()` call.
        return tracefunc

    # NOTE(review): 19 looks like a baseline stack depth under the test
    # runner used to keep the indentation small - confirm.
    stack = (len(inspect.stack()) - 19) * 2
    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1
    funcname = black_source_lines[func_sig_lineno].strip()
    # Skip decorator lines so the line that follows them is shown, without
    # ever running past the end of the source.
    last_index = len(black_source_lines) - 1
    while funcname.startswith("@") and func_sig_lineno < last_index:
        func_sig_lineno += 1
        funcname = black_source_lines[func_sig_lineno].strip()
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc
1993
1994
# Allow running this file directly as a script: hand control to unittest,
# which is told to load the tests from the module named "test_black".
if __name__ == "__main__":
    unittest.main(module="test_black")