git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Docker CI: Add missed Checkout step (#2128)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import multiprocessing
3 import asyncio
4 import logging
5 from concurrent.futures import ThreadPoolExecutor
6 from contextlib import contextmanager
7 from dataclasses import replace
8 import inspect
9 from io import BytesIO, TextIOWrapper
10 import os
11 from pathlib import Path
12 from platform import system
13 import regex as re
14 import sys
15 from tempfile import TemporaryDirectory
16 import types
17 from typing import (
18     Any,
19     BinaryIO,
20     Callable,
21     Dict,
22     Generator,
23     List,
24     Iterator,
25     TypeVar,
26 )
27 import unittest
28 from unittest.mock import patch, MagicMock
29
30 import click
31 from click import unstyle
32 from click.testing import CliRunner
33
34 import black
35 from black import Feature, TargetVersion
36
37 from pathspec import PathSpec
38
39 # Import other test classes
40 from tests.util import (
41     THIS_DIR,
42     read_data,
43     DETERMINISTIC_HEADER,
44     BlackBaseTestCase,
45     DEFAULT_MODE,
46     fs,
47     ff,
48     dump_to_stderr,
49 )
50 from .test_primer import PrimerCLITests  # noqa: F401
51
52
# Location of this test module on disk.
THIS_FILE = Path(__file__)
# Target versions PY36 through PY39.
PY36_VERSIONS = {
    TargetVersion.PY36,
    TargetVersion.PY37,
    TargetVersion.PY38,
    TargetVersion.PY39,
}
# One `--target-version=<name>` CLI flag per entry in PY36_VERSIONS; built by
# iterating a set, so the order of the flags is not fixed.
PY36_ARGS = [f"--target-version={version.name.lower()}" for version in PY36_VERSIONS]
# Generic type variables.
T = TypeVar("T")
R = TypeVar("R")
63
64
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch `black.CACHE_DIR` to a location inside a fresh temporary directory.

    The patched path is yielded. When `exists` is false, it is a "new" child
    of the temporary directory that this helper does not create.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
73
74
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a newly created asyncio event loop as current; close it on exit."""
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        fresh_loop.close()
85
86
class FakeContext(click.Context):
    """Minimal stand-in for a click Context, for functions that require one.

    The parent initializer is not invoked; only an empty `default_map` is set.
    """

    def __init__(self) -> None:
        self.default_map: Dict[str, Any] = dict()
92
93
class FakeParameter(click.Parameter):
    """Minimal stand-in for a click Parameter, for functions that require one.

    Construction takes no arguments and sets no attributes.
    """

    def __init__(self) -> None:
        return None
99
100
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    This is a hack that can be removed once we depend on Click 7.x

    After an invocation the raw captured bytes are available as
    `stdout_bytes` and `stderr_bytes`.
    """

    def __init__(self) -> None:
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Run the parent isolation with `sys.stderr` redirected to `stderrbuf`.

        Yields whatever the parent context manager yields. On exit, the bytes
        behind `sys.stdout` and `sys.stderr` are stored in `stdout_bytes` and
        `stderr_bytes`, and the previous `sys.stderr` is put back.
        """
        with super().isolation(*args, **kwargs) as output:
            hold_stderr = sys.stderr
            try:
                sys.stderr = TextIOWrapper(self.stderrbuf, encoding=self.charset)
                yield output
            finally:
                try:
                    # A text wrapper may still hold written text that has not
                    # reached its underlying byte buffer; flush both streams
                    # so the snapshots below are complete.
                    sys.stdout.flush()
                    sys.stderr.flush()
                    self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                    self.stderr_bytes = sys.stderr.buffer.getvalue()  # type: ignore
                finally:
                    # Restore stderr even if reading the buffers failed.
                    sys.stderr = hold_stderr
124
125
126 class BlackTestCase(BlackBaseTestCase):
127     def invokeBlack(
128         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
129     ) -> None:
130         runner = BlackRunner()
131         if ignore_config:
132             args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
133         result = runner.invoke(black.main, args)
134         self.assertEqual(
135             result.exit_code,
136             exit_code,
137             msg=(
138                 f"Failed with args: {args}\n"
139                 f"stdout: {runner.stdout_bytes.decode()!r}\n"
140                 f"stderr: {runner.stderr_bytes.decode()!r}\n"
141                 f"exception: {result.exception}"
142             ),
143         )
144
145     @patch("black.dump_to_file", dump_to_stderr)
146     def test_empty(self) -> None:
147         source = expected = ""
148         actual = fs(source)
149         self.assertFormatEqual(expected, actual)
150         black.assert_equivalent(source, actual)
151         black.assert_stable(source, actual, DEFAULT_MODE)
152
153     def test_empty_ff(self) -> None:
154         expected = ""
155         tmp_file = Path(black.dump_to_file())
156         try:
157             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
158             with open(tmp_file, encoding="utf8") as f:
159                 actual = f.read()
160         finally:
161             os.unlink(tmp_file)
162         self.assertFormatEqual(expected, actual)
163
164     def test_piping(self) -> None:
165         source, expected = read_data("src/black/__init__", data=False)
166         result = BlackRunner().invoke(
167             black.main,
168             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
169             input=BytesIO(source.encode("utf8")),
170         )
171         self.assertEqual(result.exit_code, 0)
172         self.assertFormatEqual(expected, result.output)
173         black.assert_equivalent(source, result.output)
174         black.assert_stable(source, result.output, DEFAULT_MODE)
175
176     def test_piping_diff(self) -> None:
177         diff_header = re.compile(
178             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
179             r"\+\d\d\d\d"
180         )
181         source, _ = read_data("expression.py")
182         expected, _ = read_data("expression.diff")
183         config = THIS_DIR / "data" / "empty_pyproject.toml"
184         args = [
185             "-",
186             "--fast",
187             f"--line-length={black.DEFAULT_LINE_LENGTH}",
188             "--diff",
189             f"--config={config}",
190         ]
191         result = BlackRunner().invoke(
192             black.main, args, input=BytesIO(source.encode("utf8"))
193         )
194         self.assertEqual(result.exit_code, 0)
195         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
196         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
197         self.assertEqual(expected, actual)
198
199     def test_piping_diff_with_color(self) -> None:
200         source, _ = read_data("expression.py")
201         config = THIS_DIR / "data" / "empty_pyproject.toml"
202         args = [
203             "-",
204             "--fast",
205             f"--line-length={black.DEFAULT_LINE_LENGTH}",
206             "--diff",
207             "--color",
208             f"--config={config}",
209         ]
210         result = BlackRunner().invoke(
211             black.main, args, input=BytesIO(source.encode("utf8"))
212         )
213         actual = result.output
214         # Again, the contents are checked in a different test, so only look for colors.
215         self.assertIn("\033[1;37m", actual)
216         self.assertIn("\033[36m", actual)
217         self.assertIn("\033[32m", actual)
218         self.assertIn("\033[31m", actual)
219         self.assertIn("\033[0m", actual)
220
221     @patch("black.dump_to_file", dump_to_stderr)
222     def _test_wip(self) -> None:
223         source, expected = read_data("wip")
224         sys.settrace(tracefunc)
225         mode = replace(
226             DEFAULT_MODE,
227             experimental_string_processing=False,
228             target_versions={black.TargetVersion.PY38},
229         )
230         actual = fs(source, mode=mode)
231         sys.settrace(None)
232         self.assertFormatEqual(expected, actual)
233         black.assert_equivalent(source, actual)
234         black.assert_stable(source, actual, black.FileMode())
235
236     @unittest.expectedFailure
237     @patch("black.dump_to_file", dump_to_stderr)
238     def test_trailing_comma_optional_parens_stability1(self) -> None:
239         source, _expected = read_data("trailing_comma_optional_parens1")
240         actual = fs(source)
241         black.assert_stable(source, actual, DEFAULT_MODE)
242
243     @unittest.expectedFailure
244     @patch("black.dump_to_file", dump_to_stderr)
245     def test_trailing_comma_optional_parens_stability2(self) -> None:
246         source, _expected = read_data("trailing_comma_optional_parens2")
247         actual = fs(source)
248         black.assert_stable(source, actual, DEFAULT_MODE)
249
250     @unittest.expectedFailure
251     @patch("black.dump_to_file", dump_to_stderr)
252     def test_trailing_comma_optional_parens_stability3(self) -> None:
253         source, _expected = read_data("trailing_comma_optional_parens3")
254         actual = fs(source)
255         black.assert_stable(source, actual, DEFAULT_MODE)
256
257     @patch("black.dump_to_file", dump_to_stderr)
258     def test_trailing_comma_optional_parens_stability1_pass2(self) -> None:
259         source, _expected = read_data("trailing_comma_optional_parens1")
260         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
261         black.assert_stable(source, actual, DEFAULT_MODE)
262
263     @patch("black.dump_to_file", dump_to_stderr)
264     def test_trailing_comma_optional_parens_stability2_pass2(self) -> None:
265         source, _expected = read_data("trailing_comma_optional_parens2")
266         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
267         black.assert_stable(source, actual, DEFAULT_MODE)
268
269     @patch("black.dump_to_file", dump_to_stderr)
270     def test_trailing_comma_optional_parens_stability3_pass2(self) -> None:
271         source, _expected = read_data("trailing_comma_optional_parens3")
272         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
273         black.assert_stable(source, actual, DEFAULT_MODE)
274
275     @patch("black.dump_to_file", dump_to_stderr)
276     def test_pep_572(self) -> None:
277         source, expected = read_data("pep_572")
278         actual = fs(source)
279         self.assertFormatEqual(expected, actual)
280         black.assert_stable(source, actual, DEFAULT_MODE)
281         if sys.version_info >= (3, 8):
282             black.assert_equivalent(source, actual)
283
284     @patch("black.dump_to_file", dump_to_stderr)
285     def test_pep_572_remove_parens(self) -> None:
286         source, expected = read_data("pep_572_remove_parens")
287         actual = fs(source)
288         self.assertFormatEqual(expected, actual)
289         black.assert_stable(source, actual, DEFAULT_MODE)
290         if sys.version_info >= (3, 8):
291             black.assert_equivalent(source, actual)
292
293     def test_pep_572_version_detection(self) -> None:
294         source, _ = read_data("pep_572")
295         root = black.lib2to3_parse(source)
296         features = black.get_features_used(root)
297         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
298         versions = black.detect_target_versions(root)
299         self.assertIn(black.TargetVersion.PY38, versions)
300
301     def test_expression_ff(self) -> None:
302         source, expected = read_data("expression")
303         tmp_file = Path(black.dump_to_file(source))
304         try:
305             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
306             with open(tmp_file, encoding="utf8") as f:
307                 actual = f.read()
308         finally:
309             os.unlink(tmp_file)
310         self.assertFormatEqual(expected, actual)
311         with patch("black.dump_to_file", dump_to_stderr):
312             black.assert_equivalent(source, actual)
313             black.assert_stable(source, actual, DEFAULT_MODE)
314
315     def test_expression_diff(self) -> None:
316         source, _ = read_data("expression.py")
317         config = THIS_DIR / "data" / "empty_pyproject.toml"
318         expected, _ = read_data("expression.diff")
319         tmp_file = Path(black.dump_to_file(source))
320         diff_header = re.compile(
321             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
322             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
323         )
324         try:
325             result = BlackRunner().invoke(
326                 black.main, ["--diff", str(tmp_file), f"--config={config}"]
327             )
328             self.assertEqual(result.exit_code, 0)
329         finally:
330             os.unlink(tmp_file)
331         actual = result.output
332         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
333         if expected != actual:
334             dump = black.dump_to_file(actual)
335             msg = (
336                 "Expected diff isn't equal to the actual. If you made changes to"
337                 " expression.py and this is an anticipated difference, overwrite"
338                 f" tests/data/expression.diff with {dump}"
339             )
340             self.assertEqual(expected, actual, msg)
341
342     def test_expression_diff_with_color(self) -> None:
343         source, _ = read_data("expression.py")
344         config = THIS_DIR / "data" / "empty_pyproject.toml"
345         expected, _ = read_data("expression.diff")
346         tmp_file = Path(black.dump_to_file(source))
347         try:
348             result = BlackRunner().invoke(
349                 black.main, ["--diff", "--color", str(tmp_file), f"--config={config}"]
350             )
351         finally:
352             os.unlink(tmp_file)
353         actual = result.output
354         # We check the contents of the diff in `test_expression_diff`. All
355         # we need to check here is that color codes exist in the result.
356         self.assertIn("\033[1;37m", actual)
357         self.assertIn("\033[36m", actual)
358         self.assertIn("\033[32m", actual)
359         self.assertIn("\033[31m", actual)
360         self.assertIn("\033[0m", actual)
361
362     @patch("black.dump_to_file", dump_to_stderr)
363     def test_pep_570(self) -> None:
364         source, expected = read_data("pep_570")
365         actual = fs(source)
366         self.assertFormatEqual(expected, actual)
367         black.assert_stable(source, actual, DEFAULT_MODE)
368         if sys.version_info >= (3, 8):
369             black.assert_equivalent(source, actual)
370
371     def test_detect_pos_only_arguments(self) -> None:
372         source, _ = read_data("pep_570")
373         root = black.lib2to3_parse(source)
374         features = black.get_features_used(root)
375         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
376         versions = black.detect_target_versions(root)
377         self.assertIn(black.TargetVersion.PY38, versions)
378
379     @patch("black.dump_to_file", dump_to_stderr)
380     def test_string_quotes(self) -> None:
381         source, expected = read_data("string_quotes")
382         actual = fs(source)
383         self.assertFormatEqual(expected, actual)
384         black.assert_equivalent(source, actual)
385         black.assert_stable(source, actual, DEFAULT_MODE)
386         mode = replace(DEFAULT_MODE, string_normalization=False)
387         not_normalized = fs(source, mode=mode)
388         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
389         black.assert_equivalent(source, not_normalized)
390         black.assert_stable(source, not_normalized, mode=mode)
391
392     @patch("black.dump_to_file", dump_to_stderr)
393     def test_docstring_no_string_normalization(self) -> None:
394         """Like test_docstring but with string normalization off."""
395         source, expected = read_data("docstring_no_string_normalization")
396         mode = replace(DEFAULT_MODE, string_normalization=False)
397         actual = fs(source, mode=mode)
398         self.assertFormatEqual(expected, actual)
399         black.assert_equivalent(source, actual)
400         black.assert_stable(source, actual, mode)
401
402     def test_long_strings_flag_disabled(self) -> None:
403         """Tests for turning off the string processing logic."""
404         source, expected = read_data("long_strings_flag_disabled")
405         mode = replace(DEFAULT_MODE, experimental_string_processing=False)
406         actual = fs(source, mode=mode)
407         self.assertFormatEqual(expected, actual)
408         black.assert_stable(expected, actual, mode)
409
410     @patch("black.dump_to_file", dump_to_stderr)
411     def test_numeric_literals(self) -> None:
412         source, expected = read_data("numeric_literals")
413         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
414         actual = fs(source, mode=mode)
415         self.assertFormatEqual(expected, actual)
416         black.assert_equivalent(source, actual)
417         black.assert_stable(source, actual, mode)
418
419     @patch("black.dump_to_file", dump_to_stderr)
420     def test_numeric_literals_ignoring_underscores(self) -> None:
421         source, expected = read_data("numeric_literals_skip_underscores")
422         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
423         actual = fs(source, mode=mode)
424         self.assertFormatEqual(expected, actual)
425         black.assert_equivalent(source, actual)
426         black.assert_stable(source, actual, mode)
427
428     def test_skip_magic_trailing_comma(self) -> None:
429         source, _ = read_data("expression.py")
430         expected, _ = read_data("expression_skip_magic_trailing_comma.diff")
431         tmp_file = Path(black.dump_to_file(source))
432         diff_header = re.compile(
433             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
434             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
435         )
436         try:
437             result = BlackRunner().invoke(black.main, ["-C", "--diff", str(tmp_file)])
438             self.assertEqual(result.exit_code, 0)
439         finally:
440             os.unlink(tmp_file)
441         actual = result.output
442         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
443         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
444         if expected != actual:
445             dump = black.dump_to_file(actual)
446             msg = (
447                 "Expected diff isn't equal to the actual. If you made changes to"
448                 " expression.py and this is an anticipated difference, overwrite"
449                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
450             )
451             self.assertEqual(expected, actual, msg)
452
453     @patch("black.dump_to_file", dump_to_stderr)
454     def test_python2_print_function(self) -> None:
455         source, expected = read_data("python2_print_function")
456         mode = replace(DEFAULT_MODE, target_versions={TargetVersion.PY27})
457         actual = fs(source, mode=mode)
458         self.assertFormatEqual(expected, actual)
459         black.assert_equivalent(source, actual)
460         black.assert_stable(source, actual, mode)
461
462     @patch("black.dump_to_file", dump_to_stderr)
463     def test_stub(self) -> None:
464         mode = replace(DEFAULT_MODE, is_pyi=True)
465         source, expected = read_data("stub.pyi")
466         actual = fs(source, mode=mode)
467         self.assertFormatEqual(expected, actual)
468         black.assert_stable(source, actual, mode)
469
470     @patch("black.dump_to_file", dump_to_stderr)
471     def test_async_as_identifier(self) -> None:
472         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
473         source, expected = read_data("async_as_identifier")
474         actual = fs(source)
475         self.assertFormatEqual(expected, actual)
476         major, minor = sys.version_info[:2]
477         if major < 3 or (major <= 3 and minor < 7):
478             black.assert_equivalent(source, actual)
479         black.assert_stable(source, actual, DEFAULT_MODE)
480         # ensure black can parse this when the target is 3.6
481         self.invokeBlack([str(source_path), "--target-version", "py36"])
482         # but not on 3.7, because async/await is no longer an identifier
483         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
484
485     @patch("black.dump_to_file", dump_to_stderr)
486     def test_python37(self) -> None:
487         source_path = (THIS_DIR / "data" / "python37.py").resolve()
488         source, expected = read_data("python37")
489         actual = fs(source)
490         self.assertFormatEqual(expected, actual)
491         major, minor = sys.version_info[:2]
492         if major > 3 or (major == 3 and minor >= 7):
493             black.assert_equivalent(source, actual)
494         black.assert_stable(source, actual, DEFAULT_MODE)
495         # ensure black can parse this when the target is 3.7
496         self.invokeBlack([str(source_path), "--target-version", "py37"])
497         # but not on 3.6, because we use async as a reserved keyword
498         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
499
500     @patch("black.dump_to_file", dump_to_stderr)
501     def test_python38(self) -> None:
502         source, expected = read_data("python38")
503         actual = fs(source)
504         self.assertFormatEqual(expected, actual)
505         major, minor = sys.version_info[:2]
506         if major > 3 or (major == 3 and minor >= 8):
507             black.assert_equivalent(source, actual)
508         black.assert_stable(source, actual, DEFAULT_MODE)
509
510     @patch("black.dump_to_file", dump_to_stderr)
511     def test_python39(self) -> None:
512         source, expected = read_data("python39")
513         actual = fs(source)
514         self.assertFormatEqual(expected, actual)
515         major, minor = sys.version_info[:2]
516         if major > 3 or (major == 3 and minor >= 9):
517             black.assert_equivalent(source, actual)
518         black.assert_stable(source, actual, DEFAULT_MODE)
519
520     def test_tab_comment_indentation(self) -> None:
521         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
522         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
523         self.assertFormatEqual(contents_spc, fs(contents_spc))
524         self.assertFormatEqual(contents_spc, fs(contents_tab))
525
526         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
527         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
528         self.assertFormatEqual(contents_spc, fs(contents_spc))
529         self.assertFormatEqual(contents_spc, fs(contents_tab))
530
531         # mixed tabs and spaces (valid Python 2 code)
532         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
533         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
534         self.assertFormatEqual(contents_spc, fs(contents_spc))
535         self.assertFormatEqual(contents_spc, fs(contents_tab))
536
537         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
538         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
539         self.assertFormatEqual(contents_spc, fs(contents_spc))
540         self.assertFormatEqual(contents_spc, fs(contents_tab))
541
    def test_report_verbose(self) -> None:
        """Walk a `black.Report(verbose=True)` through a series of events.

        `black.out` and `black.err` are patched with collectors so that, after
        every event, the test can assert the number of messages emitted, the
        text of the latest message, the report summary string and the return
        code.
        """
        report = black.Report(verbose=True)
        out_lines = []
        err_lines = []

        # Replacements for `black.out` / `black.err` that record each message.
        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            # An unchanged file.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # A reformatted file.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file is counted among the unchanged ones.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With `check` enabled the return code is asserted to be 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure goes to `err` and the return code becomes 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # An ignored path emits a message but leaves the summary as it was.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # With `check` or `diff` set, the summary uses "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
643
    def test_report_quiet(self) -> None:
        """Walk a `black.Report(quiet=True)` through a series of events.

        `black.out` and `black.err` are patched with collectors. The test
        asserts that nothing is ever sent to `out`, that failures are still
        sent to `err`, and checks the summary string and return code after
        each event.
        """
        report = black.Report(quiet=True)
        out_lines = []
        err_lines = []

        # Replacements for `black.out` / `black.err` that record each message.
        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            # An unchanged file.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # A reformatted file.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file is counted among the unchanged ones.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With `check` enabled the return code is asserted to be 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure goes to `err` and the return code becomes 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # An ignored path leaves both the messages and the summary as they were.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # With `check` or `diff` set, the summary uses "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
737
738     def test_report_normal(self) -> None:
739         report = black.Report()
740         out_lines = []
741         err_lines = []
742
743         def out(msg: str, **kwargs: Any) -> None:
744             out_lines.append(msg)
745
746         def err(msg: str, **kwargs: Any) -> None:
747             err_lines.append(msg)
748
749         with patch("black.out", out), patch("black.err", err):
750             report.done(Path("f1"), black.Changed.NO)
751             self.assertEqual(len(out_lines), 0)
752             self.assertEqual(len(err_lines), 0)
753             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
754             self.assertEqual(report.return_code, 0)
755             report.done(Path("f2"), black.Changed.YES)
756             self.assertEqual(len(out_lines), 1)
757             self.assertEqual(len(err_lines), 0)
758             self.assertEqual(out_lines[-1], "reformatted f2")
759             self.assertEqual(
760                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
761             )
762             report.done(Path("f3"), black.Changed.CACHED)
763             self.assertEqual(len(out_lines), 1)
764             self.assertEqual(len(err_lines), 0)
765             self.assertEqual(out_lines[-1], "reformatted f2")
766             self.assertEqual(
767                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
768             )
769             self.assertEqual(report.return_code, 0)
770             report.check = True
771             self.assertEqual(report.return_code, 1)
772             report.check = False
773             report.failed(Path("e1"), "boom")
774             self.assertEqual(len(out_lines), 1)
775             self.assertEqual(len(err_lines), 1)
776             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
777             self.assertEqual(
778                 unstyle(str(report)),
779                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
780                 " reformat.",
781             )
782             self.assertEqual(report.return_code, 123)
783             report.done(Path("f3"), black.Changed.YES)
784             self.assertEqual(len(out_lines), 2)
785             self.assertEqual(len(err_lines), 1)
786             self.assertEqual(out_lines[-1], "reformatted f3")
787             self.assertEqual(
788                 unstyle(str(report)),
789                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
790                 " reformat.",
791             )
792             self.assertEqual(report.return_code, 123)
793             report.failed(Path("e2"), "boom")
794             self.assertEqual(len(out_lines), 2)
795             self.assertEqual(len(err_lines), 2)
796             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
797             self.assertEqual(
798                 unstyle(str(report)),
799                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
800                 " reformat.",
801             )
802             self.assertEqual(report.return_code, 123)
803             report.path_ignored(Path("wat"), "no match")
804             self.assertEqual(len(out_lines), 2)
805             self.assertEqual(len(err_lines), 2)
806             self.assertEqual(
807                 unstyle(str(report)),
808                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
809                 " reformat.",
810             )
811             self.assertEqual(report.return_code, 123)
812             report.done(Path("f4"), black.Changed.NO)
813             self.assertEqual(len(out_lines), 2)
814             self.assertEqual(len(err_lines), 2)
815             self.assertEqual(
816                 unstyle(str(report)),
817                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
818                 " reformat.",
819             )
820             self.assertEqual(report.return_code, 123)
821             report.check = True
822             self.assertEqual(
823                 unstyle(str(report)),
824                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
825                 " would fail to reformat.",
826             )
827             report.check = False
828             report.diff = True
829             self.assertEqual(
830                 unstyle(str(report)),
831                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
832                 " would fail to reformat.",
833             )
834
835     def test_lib2to3_parse(self) -> None:
836         with self.assertRaises(black.InvalidInput):
837             black.lib2to3_parse("invalid syntax")
838
839         straddling = "x + y"
840         black.lib2to3_parse(straddling)
841         black.lib2to3_parse(straddling, {TargetVersion.PY27})
842         black.lib2to3_parse(straddling, {TargetVersion.PY36})
843         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
844
845         py2_only = "print x"
846         black.lib2to3_parse(py2_only)
847         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
848         with self.assertRaises(black.InvalidInput):
849             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
850         with self.assertRaises(black.InvalidInput):
851             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
852
853         py3_only = "exec(x, end=y)"
854         black.lib2to3_parse(py3_only)
855         with self.assertRaises(black.InvalidInput):
856             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
857         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
858         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
859
860     def test_get_features_used_decorator(self) -> None:
861         # Test the feature detection of new decorator syntax
862         # since this makes some test cases of test_get_features_used()
863         # fails if it fails, this is tested first so that a useful case
864         # is identified
865         simples, relaxed = read_data("decorators")
866         # skip explanation comments at the top of the file
867         for simple_test in simples.split("##")[1:]:
868             node = black.lib2to3_parse(simple_test)
869             decorator = str(node.children[0].children[0]).strip()
870             self.assertNotIn(
871                 Feature.RELAXED_DECORATORS,
872                 black.get_features_used(node),
873                 msg=(
874                     f"decorator '{decorator}' follows python<=3.8 syntax"
875                     "but is detected as 3.9+"
876                     # f"The full node is\n{node!r}"
877                 ),
878             )
879         # skip the '# output' comment at the top of the output part
880         for relaxed_test in relaxed.split("##")[1:]:
881             node = black.lib2to3_parse(relaxed_test)
882             decorator = str(node.children[0].children[0]).strip()
883             self.assertIn(
884                 Feature.RELAXED_DECORATORS,
885                 black.get_features_used(node),
886                 msg=(
887                     f"decorator '{decorator}' uses python3.9+ syntax"
888                     "but is detected as python<=3.8"
889                     # f"The full node is\n{node!r}"
890                 ),
891             )
892
893     def test_get_features_used(self) -> None:
894         node = black.lib2to3_parse("def f(*, arg): ...\n")
895         self.assertEqual(black.get_features_used(node), set())
896         node = black.lib2to3_parse("def f(*, arg,): ...\n")
897         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
898         node = black.lib2to3_parse("f(*arg,)\n")
899         self.assertEqual(
900             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
901         )
902         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
903         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
904         node = black.lib2to3_parse("123_456\n")
905         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
906         node = black.lib2to3_parse("123456\n")
907         self.assertEqual(black.get_features_used(node), set())
908         source, expected = read_data("function")
909         node = black.lib2to3_parse(source)
910         expected_features = {
911             Feature.TRAILING_COMMA_IN_CALL,
912             Feature.TRAILING_COMMA_IN_DEF,
913             Feature.F_STRINGS,
914         }
915         self.assertEqual(black.get_features_used(node), expected_features)
916         node = black.lib2to3_parse(expected)
917         self.assertEqual(black.get_features_used(node), expected_features)
918         source, expected = read_data("expression")
919         node = black.lib2to3_parse(source)
920         self.assertEqual(black.get_features_used(node), set())
921         node = black.lib2to3_parse(expected)
922         self.assertEqual(black.get_features_used(node), set())
923
924     def test_get_future_imports(self) -> None:
925         node = black.lib2to3_parse("\n")
926         self.assertEqual(set(), black.get_future_imports(node))
927         node = black.lib2to3_parse("from __future__ import black\n")
928         self.assertEqual({"black"}, black.get_future_imports(node))
929         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
930         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
931         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
932         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
933         node = black.lib2to3_parse(
934             "from __future__ import multiple\nfrom __future__ import imports\n"
935         )
936         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
937         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
938         self.assertEqual({"black"}, black.get_future_imports(node))
939         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
940         self.assertEqual({"black"}, black.get_future_imports(node))
941         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
942         self.assertEqual(set(), black.get_future_imports(node))
943         node = black.lib2to3_parse("from some.module import black\n")
944         self.assertEqual(set(), black.get_future_imports(node))
945         node = black.lib2to3_parse(
946             "from __future__ import unicode_literals as _unicode_literals"
947         )
948         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
949         node = black.lib2to3_parse(
950             "from __future__ import unicode_literals as _lol, print"
951         )
952         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
953
954     def test_debug_visitor(self) -> None:
955         source, _ = read_data("debug_visitor.py")
956         expected, _ = read_data("debug_visitor.out")
957         out_lines = []
958         err_lines = []
959
960         def out(msg: str, **kwargs: Any) -> None:
961             out_lines.append(msg)
962
963         def err(msg: str, **kwargs: Any) -> None:
964             err_lines.append(msg)
965
966         with patch("black.out", out), patch("black.err", err):
967             black.DebugVisitor.show(source)
968         actual = "\n".join(out_lines) + "\n"
969         log_name = ""
970         if expected != actual:
971             log_name = black.dump_to_file(*out_lines)
972         self.assertEqual(
973             expected,
974             actual,
975             f"AST print out is different. Actual version dumped to {log_name}",
976         )
977
978     def test_format_file_contents(self) -> None:
979         empty = ""
980         mode = DEFAULT_MODE
981         with self.assertRaises(black.NothingChanged):
982             black.format_file_contents(empty, mode=mode, fast=False)
983         just_nl = "\n"
984         with self.assertRaises(black.NothingChanged):
985             black.format_file_contents(just_nl, mode=mode, fast=False)
986         same = "j = [1, 2, 3]\n"
987         with self.assertRaises(black.NothingChanged):
988             black.format_file_contents(same, mode=mode, fast=False)
989         different = "j = [1,2,3]"
990         expected = same
991         actual = black.format_file_contents(different, mode=mode, fast=False)
992         self.assertEqual(expected, actual)
993         invalid = "return if you can"
994         with self.assertRaises(black.InvalidInput) as e:
995             black.format_file_contents(invalid, mode=mode, fast=False)
996         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
997
998     def test_endmarker(self) -> None:
999         n = black.lib2to3_parse("\n")
1000         self.assertEqual(n.type, black.syms.file_input)
1001         self.assertEqual(len(n.children), 1)
1002         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1003
1004     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1005     def test_assertFormatEqual(self) -> None:
1006         out_lines = []
1007         err_lines = []
1008
1009         def out(msg: str, **kwargs: Any) -> None:
1010             out_lines.append(msg)
1011
1012         def err(msg: str, **kwargs: Any) -> None:
1013             err_lines.append(msg)
1014
1015         with patch("black.out", out), patch("black.err", err):
1016             with self.assertRaises(AssertionError):
1017                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1018
1019         out_str = "".join(out_lines)
1020         self.assertTrue("Expected tree:" in out_str)
1021         self.assertTrue("Actual tree:" in out_str)
1022         self.assertEqual("".join(err_lines), "")
1023
1024     def test_cache_broken_file(self) -> None:
1025         mode = DEFAULT_MODE
1026         with cache_dir() as workspace:
1027             cache_file = black.get_cache_file(mode)
1028             with cache_file.open("w") as fobj:
1029                 fobj.write("this is not a pickle")
1030             self.assertEqual(black.read_cache(mode), {})
1031             src = (workspace / "test.py").resolve()
1032             with src.open("w") as fobj:
1033                 fobj.write("print('hello')")
1034             self.invokeBlack([str(src)])
1035             cache = black.read_cache(mode)
1036             self.assertIn(str(src), cache)
1037
1038     def test_cache_single_file_already_cached(self) -> None:
1039         mode = DEFAULT_MODE
1040         with cache_dir() as workspace:
1041             src = (workspace / "test.py").resolve()
1042             with src.open("w") as fobj:
1043                 fobj.write("print('hello')")
1044             black.write_cache({}, [src], mode)
1045             self.invokeBlack([str(src)])
1046             with src.open("r") as fobj:
1047                 self.assertEqual(fobj.read(), "print('hello')")
1048
1049     @event_loop()
1050     def test_cache_multiple_files(self) -> None:
1051         mode = DEFAULT_MODE
1052         with cache_dir() as workspace, patch(
1053             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1054         ):
1055             one = (workspace / "one.py").resolve()
1056             with one.open("w") as fobj:
1057                 fobj.write("print('hello')")
1058             two = (workspace / "two.py").resolve()
1059             with two.open("w") as fobj:
1060                 fobj.write("print('hello')")
1061             black.write_cache({}, [one], mode)
1062             self.invokeBlack([str(workspace)])
1063             with one.open("r") as fobj:
1064                 self.assertEqual(fobj.read(), "print('hello')")
1065             with two.open("r") as fobj:
1066                 self.assertEqual(fobj.read(), 'print("hello")\n')
1067             cache = black.read_cache(mode)
1068             self.assertIn(str(one), cache)
1069             self.assertIn(str(two), cache)
1070
1071     def test_no_cache_when_writeback_diff(self) -> None:
1072         mode = DEFAULT_MODE
1073         with cache_dir() as workspace:
1074             src = (workspace / "test.py").resolve()
1075             with src.open("w") as fobj:
1076                 fobj.write("print('hello')")
1077             with patch("black.read_cache") as read_cache, patch(
1078                 "black.write_cache"
1079             ) as write_cache:
1080                 self.invokeBlack([str(src), "--diff"])
1081                 cache_file = black.get_cache_file(mode)
1082                 self.assertFalse(cache_file.exists())
1083                 write_cache.assert_not_called()
1084                 read_cache.assert_not_called()
1085
1086     def test_no_cache_when_writeback_color_diff(self) -> None:
1087         mode = DEFAULT_MODE
1088         with cache_dir() as workspace:
1089             src = (workspace / "test.py").resolve()
1090             with src.open("w") as fobj:
1091                 fobj.write("print('hello')")
1092             with patch("black.read_cache") as read_cache, patch(
1093                 "black.write_cache"
1094             ) as write_cache:
1095                 self.invokeBlack([str(src), "--diff", "--color"])
1096                 cache_file = black.get_cache_file(mode)
1097                 self.assertFalse(cache_file.exists())
1098                 write_cache.assert_not_called()
1099                 read_cache.assert_not_called()
1100
1101     @event_loop()
1102     def test_output_locking_when_writeback_diff(self) -> None:
1103         with cache_dir() as workspace:
1104             for tag in range(0, 4):
1105                 src = (workspace / f"test{tag}.py").resolve()
1106                 with src.open("w") as fobj:
1107                     fobj.write("print('hello')")
1108             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1109                 self.invokeBlack(["--diff", str(workspace)], exit_code=0)
1110                 # this isn't quite doing what we want, but if it _isn't_
1111                 # called then we cannot be using the lock it provides
1112                 mgr.assert_called()
1113
1114     @event_loop()
1115     def test_output_locking_when_writeback_color_diff(self) -> None:
1116         with cache_dir() as workspace:
1117             for tag in range(0, 4):
1118                 src = (workspace / f"test{tag}.py").resolve()
1119                 with src.open("w") as fobj:
1120                     fobj.write("print('hello')")
1121             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1122                 self.invokeBlack(["--diff", "--color", str(workspace)], exit_code=0)
1123                 # this isn't quite doing what we want, but if it _isn't_
1124                 # called then we cannot be using the lock it provides
1125                 mgr.assert_called()
1126
1127     def test_no_cache_when_stdin(self) -> None:
1128         mode = DEFAULT_MODE
1129         with cache_dir():
1130             result = CliRunner().invoke(
1131                 black.main, ["-"], input=BytesIO(b"print('hello')")
1132             )
1133             self.assertEqual(result.exit_code, 0)
1134             cache_file = black.get_cache_file(mode)
1135             self.assertFalse(cache_file.exists())
1136
1137     def test_read_cache_no_cachefile(self) -> None:
1138         mode = DEFAULT_MODE
1139         with cache_dir():
1140             self.assertEqual(black.read_cache(mode), {})
1141
1142     def test_write_cache_read_cache(self) -> None:
1143         mode = DEFAULT_MODE
1144         with cache_dir() as workspace:
1145             src = (workspace / "test.py").resolve()
1146             src.touch()
1147             black.write_cache({}, [src], mode)
1148             cache = black.read_cache(mode)
1149             self.assertIn(str(src), cache)
1150             self.assertEqual(cache[str(src)], black.get_cache_info(src))
1151
1152     def test_filter_cached(self) -> None:
1153         with TemporaryDirectory() as workspace:
1154             path = Path(workspace)
1155             uncached = (path / "uncached").resolve()
1156             cached = (path / "cached").resolve()
1157             cached_but_changed = (path / "changed").resolve()
1158             uncached.touch()
1159             cached.touch()
1160             cached_but_changed.touch()
1161             cache = {
1162                 str(cached): black.get_cache_info(cached),
1163                 str(cached_but_changed): (0.0, 0),
1164             }
1165             todo, done = black.filter_cached(
1166                 cache, {uncached, cached, cached_but_changed}
1167             )
1168             self.assertEqual(todo, {uncached, cached_but_changed})
1169             self.assertEqual(done, {cached})
1170
1171     def test_write_cache_creates_directory_if_needed(self) -> None:
1172         mode = DEFAULT_MODE
1173         with cache_dir(exists=False) as workspace:
1174             self.assertFalse(workspace.exists())
1175             black.write_cache({}, [], mode)
1176             self.assertTrue(workspace.exists())
1177
1178     @event_loop()
1179     def test_failed_formatting_does_not_get_cached(self) -> None:
1180         mode = DEFAULT_MODE
1181         with cache_dir() as workspace, patch(
1182             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1183         ):
1184             failing = (workspace / "failing.py").resolve()
1185             with failing.open("w") as fobj:
1186                 fobj.write("not actually python")
1187             clean = (workspace / "clean.py").resolve()
1188             with clean.open("w") as fobj:
1189                 fobj.write('print("hello")\n')
1190             self.invokeBlack([str(workspace)], exit_code=123)
1191             cache = black.read_cache(mode)
1192             self.assertNotIn(str(failing), cache)
1193             self.assertIn(str(clean), cache)
1194
1195     def test_write_cache_write_fail(self) -> None:
1196         mode = DEFAULT_MODE
1197         with cache_dir(), patch.object(Path, "open") as mock:
1198             mock.side_effect = OSError
1199             black.write_cache({}, [], mode)
1200
1201     @event_loop()
1202     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
1203     def test_works_in_mono_process_only_environment(self) -> None:
1204         with cache_dir() as workspace:
1205             for f in [
1206                 (workspace / "one.py").resolve(),
1207                 (workspace / "two.py").resolve(),
1208             ]:
1209                 f.write_text('print("hello")\n')
1210             self.invokeBlack([str(workspace)])
1211
1212     @event_loop()
1213     def test_check_diff_use_together(self) -> None:
1214         with cache_dir():
1215             # Files which will be reformatted.
1216             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1217             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1218             # Files which will not be reformatted.
1219             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1220             self.invokeBlack([str(src2), "--diff", "--check"])
1221             # Multi file command.
1222             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1223
1224     def test_no_files(self) -> None:
1225         with cache_dir():
1226             # Without an argument, black exits with error code 0.
1227             self.invokeBlack([])
1228
1229     def test_broken_symlink(self) -> None:
1230         with cache_dir() as workspace:
1231             symlink = workspace / "broken_link.py"
1232             try:
1233                 symlink.symlink_to("nonexistent.py")
1234             except OSError as e:
1235                 self.skipTest(f"Can't create symlinks: {e}")
1236             self.invokeBlack([str(workspace.resolve())])
1237
1238     def test_read_cache_line_lengths(self) -> None:
1239         mode = DEFAULT_MODE
1240         short_mode = replace(DEFAULT_MODE, line_length=1)
1241         with cache_dir() as workspace:
1242             path = (workspace / "file.py").resolve()
1243             path.touch()
1244             black.write_cache({}, [path], mode)
1245             one = black.read_cache(mode)
1246             self.assertIn(str(path), one)
1247             two = black.read_cache(short_mode)
1248             self.assertNotIn(str(path), two)
1249
1250     def test_single_file_force_pyi(self) -> None:
1251         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1252         contents, expected = read_data("force_pyi")
1253         with cache_dir() as workspace:
1254             path = (workspace / "file.py").resolve()
1255             with open(path, "w") as fh:
1256                 fh.write(contents)
1257             self.invokeBlack([str(path), "--pyi"])
1258             with open(path, "r") as fh:
1259                 actual = fh.read()
1260             # verify cache with --pyi is separate
1261             pyi_cache = black.read_cache(pyi_mode)
1262             self.assertIn(str(path), pyi_cache)
1263             normal_cache = black.read_cache(DEFAULT_MODE)
1264             self.assertNotIn(str(path), normal_cache)
1265         self.assertFormatEqual(expected, actual)
1266         black.assert_equivalent(contents, actual)
1267         black.assert_stable(contents, actual, pyi_mode)
1268
1269     @event_loop()
1270     def test_multi_file_force_pyi(self) -> None:
1271         reg_mode = DEFAULT_MODE
1272         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1273         contents, expected = read_data("force_pyi")
1274         with cache_dir() as workspace:
1275             paths = [
1276                 (workspace / "file1.py").resolve(),
1277                 (workspace / "file2.py").resolve(),
1278             ]
1279             for path in paths:
1280                 with open(path, "w") as fh:
1281                     fh.write(contents)
1282             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1283             for path in paths:
1284                 with open(path, "r") as fh:
1285                     actual = fh.read()
1286                 self.assertEqual(actual, expected)
1287             # verify cache with --pyi is separate
1288             pyi_cache = black.read_cache(pyi_mode)
1289             normal_cache = black.read_cache(reg_mode)
1290             for path in paths:
1291                 self.assertIn(str(path), pyi_cache)
1292                 self.assertNotIn(str(path), normal_cache)
1293
1294     def test_pipe_force_pyi(self) -> None:
1295         source, expected = read_data("force_pyi")
1296         result = CliRunner().invoke(
1297             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1298         )
1299         self.assertEqual(result.exit_code, 0)
1300         actual = result.output
1301         self.assertFormatEqual(actual, expected)
1302
1303     def test_single_file_force_py36(self) -> None:
1304         reg_mode = DEFAULT_MODE
1305         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1306         source, expected = read_data("force_py36")
1307         with cache_dir() as workspace:
1308             path = (workspace / "file.py").resolve()
1309             with open(path, "w") as fh:
1310                 fh.write(source)
1311             self.invokeBlack([str(path), *PY36_ARGS])
1312             with open(path, "r") as fh:
1313                 actual = fh.read()
1314             # verify cache with --target-version is separate
1315             py36_cache = black.read_cache(py36_mode)
1316             self.assertIn(str(path), py36_cache)
1317             normal_cache = black.read_cache(reg_mode)
1318             self.assertNotIn(str(path), normal_cache)
1319         self.assertEqual(actual, expected)
1320
1321     @event_loop()
1322     def test_multi_file_force_py36(self) -> None:
1323         reg_mode = DEFAULT_MODE
1324         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1325         source, expected = read_data("force_py36")
1326         with cache_dir() as workspace:
1327             paths = [
1328                 (workspace / "file1.py").resolve(),
1329                 (workspace / "file2.py").resolve(),
1330             ]
1331             for path in paths:
1332                 with open(path, "w") as fh:
1333                     fh.write(source)
1334             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1335             for path in paths:
1336                 with open(path, "r") as fh:
1337                     actual = fh.read()
1338                 self.assertEqual(actual, expected)
1339             # verify cache with --target-version is separate
1340             pyi_cache = black.read_cache(py36_mode)
1341             normal_cache = black.read_cache(reg_mode)
1342             for path in paths:
1343                 self.assertIn(str(path), pyi_cache)
1344                 self.assertNotIn(str(path), normal_cache)
1345
1346     def test_pipe_force_py36(self) -> None:
1347         source, expected = read_data("force_py36")
1348         result = CliRunner().invoke(
1349             black.main,
1350             ["-", "-q", "--target-version=py36"],
1351             input=BytesIO(source.encode("utf8")),
1352         )
1353         self.assertEqual(result.exit_code, 0)
1354         actual = result.output
1355         self.assertFormatEqual(actual, expected)
1356
1357     def test_include_exclude(self) -> None:
1358         path = THIS_DIR / "data" / "include_exclude_tests"
1359         include = re.compile(r"\.pyi?$")
1360         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1361         report = black.Report()
1362         gitignore = PathSpec.from_lines("gitwildmatch", [])
1363         sources: List[Path] = []
1364         expected = [
1365             Path(path / "b/dont_exclude/a.py"),
1366             Path(path / "b/dont_exclude/a.pyi"),
1367         ]
1368         this_abs = THIS_DIR.resolve()
1369         sources.extend(
1370             black.gen_python_files(
1371                 path.iterdir(),
1372                 this_abs,
1373                 include,
1374                 exclude,
1375                 None,
1376                 None,
1377                 report,
1378                 gitignore,
1379             )
1380         )
1381         self.assertEqual(sorted(expected), sorted(sources))
1382
1383     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1384     def test_exclude_for_issue_1572(self) -> None:
1385         # Exclude shouldn't touch files that were explicitly given to Black through the
1386         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
1387         # https://github.com/psf/black/issues/1572
1388         path = THIS_DIR / "data" / "include_exclude_tests"
1389         include = ""
1390         exclude = r"/exclude/|a\.py"
1391         src = str(path / "b/exclude/a.py")
1392         report = black.Report()
1393         expected = [Path(path / "b/exclude/a.py")]
1394         sources = list(
1395             black.get_sources(
1396                 ctx=FakeContext(),
1397                 src=(src,),
1398                 quiet=True,
1399                 verbose=False,
1400                 include=re.compile(include),
1401                 exclude=re.compile(exclude),
1402                 extend_exclude=None,
1403                 force_exclude=None,
1404                 report=report,
1405                 stdin_filename=None,
1406             )
1407         )
1408         self.assertEqual(sorted(expected), sorted(sources))
1409
1410     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1411     def test_get_sources_with_stdin(self) -> None:
1412         include = ""
1413         exclude = r"/exclude/|a\.py"
1414         src = "-"
1415         report = black.Report()
1416         expected = [Path("-")]
1417         sources = list(
1418             black.get_sources(
1419                 ctx=FakeContext(),
1420                 src=(src,),
1421                 quiet=True,
1422                 verbose=False,
1423                 include=re.compile(include),
1424                 exclude=re.compile(exclude),
1425                 extend_exclude=None,
1426                 force_exclude=None,
1427                 report=report,
1428                 stdin_filename=None,
1429             )
1430         )
1431         self.assertEqual(sorted(expected), sorted(sources))
1432
1433     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1434     def test_get_sources_with_stdin_filename(self) -> None:
1435         include = ""
1436         exclude = r"/exclude/|a\.py"
1437         src = "-"
1438         report = black.Report()
1439         stdin_filename = str(THIS_DIR / "data/collections.py")
1440         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1441         sources = list(
1442             black.get_sources(
1443                 ctx=FakeContext(),
1444                 src=(src,),
1445                 quiet=True,
1446                 verbose=False,
1447                 include=re.compile(include),
1448                 exclude=re.compile(exclude),
1449                 extend_exclude=None,
1450                 force_exclude=None,
1451                 report=report,
1452                 stdin_filename=stdin_filename,
1453             )
1454         )
1455         self.assertEqual(sorted(expected), sorted(sources))
1456
1457     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1458     def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
1459         # Exclude shouldn't exclude stdin_filename since it is mimicing the
1460         # file being passed directly. This is the same as
1461         # test_exclude_for_issue_1572
1462         path = THIS_DIR / "data" / "include_exclude_tests"
1463         include = ""
1464         exclude = r"/exclude/|a\.py"
1465         src = "-"
1466         report = black.Report()
1467         stdin_filename = str(path / "b/exclude/a.py")
1468         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1469         sources = list(
1470             black.get_sources(
1471                 ctx=FakeContext(),
1472                 src=(src,),
1473                 quiet=True,
1474                 verbose=False,
1475                 include=re.compile(include),
1476                 exclude=re.compile(exclude),
1477                 extend_exclude=None,
1478                 force_exclude=None,
1479                 report=report,
1480                 stdin_filename=stdin_filename,
1481             )
1482         )
1483         self.assertEqual(sorted(expected), sorted(sources))
1484
1485     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1486     def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
1487         # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
1488         # file being passed directly. This is the same as
1489         # test_exclude_for_issue_1572
1490         path = THIS_DIR / "data" / "include_exclude_tests"
1491         include = ""
1492         extend_exclude = r"/exclude/|a\.py"
1493         src = "-"
1494         report = black.Report()
1495         stdin_filename = str(path / "b/exclude/a.py")
1496         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1497         sources = list(
1498             black.get_sources(
1499                 ctx=FakeContext(),
1500                 src=(src,),
1501                 quiet=True,
1502                 verbose=False,
1503                 include=re.compile(include),
1504                 exclude=re.compile(""),
1505                 extend_exclude=re.compile(extend_exclude),
1506                 force_exclude=None,
1507                 report=report,
1508                 stdin_filename=stdin_filename,
1509             )
1510         )
1511         self.assertEqual(sorted(expected), sorted(sources))
1512
1513     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1514     def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
1515         # Force exclude should exclude the file when passing it through
1516         # stdin_filename
1517         path = THIS_DIR / "data" / "include_exclude_tests"
1518         include = ""
1519         force_exclude = r"/exclude/|a\.py"
1520         src = "-"
1521         report = black.Report()
1522         stdin_filename = str(path / "b/exclude/a.py")
1523         sources = list(
1524             black.get_sources(
1525                 ctx=FakeContext(),
1526                 src=(src,),
1527                 quiet=True,
1528                 verbose=False,
1529                 include=re.compile(include),
1530                 exclude=re.compile(""),
1531                 extend_exclude=None,
1532                 force_exclude=re.compile(force_exclude),
1533                 report=report,
1534                 stdin_filename=stdin_filename,
1535             )
1536         )
1537         self.assertEqual([], sorted(sources))
1538
1539     def test_reformat_one_with_stdin(self) -> None:
1540         with patch(
1541             "black.format_stdin_to_stdout",
1542             return_value=lambda *args, **kwargs: black.Changed.YES,
1543         ) as fsts:
1544             report = MagicMock()
1545             path = Path("-")
1546             black.reformat_one(
1547                 path,
1548                 fast=True,
1549                 write_back=black.WriteBack.YES,
1550                 mode=DEFAULT_MODE,
1551                 report=report,
1552             )
1553             fsts.assert_called_once()
1554             report.done.assert_called_with(path, black.Changed.YES)
1555
1556     def test_reformat_one_with_stdin_filename(self) -> None:
1557         with patch(
1558             "black.format_stdin_to_stdout",
1559             return_value=lambda *args, **kwargs: black.Changed.YES,
1560         ) as fsts:
1561             report = MagicMock()
1562             p = "foo.py"
1563             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1564             expected = Path(p)
1565             black.reformat_one(
1566                 path,
1567                 fast=True,
1568                 write_back=black.WriteBack.YES,
1569                 mode=DEFAULT_MODE,
1570                 report=report,
1571             )
1572             fsts.assert_called_once()
1573             # __BLACK_STDIN_FILENAME__ should have been striped
1574             report.done.assert_called_with(expected, black.Changed.YES)
1575
1576     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1577         with patch(
1578             "black.format_stdin_to_stdout",
1579             return_value=lambda *args, **kwargs: black.Changed.YES,
1580         ) as fsts:
1581             report = MagicMock()
1582             # Even with an existing file, since we are forcing stdin, black
1583             # should output to stdout and not modify the file inplace
1584             p = Path(str(THIS_DIR / "data/collections.py"))
1585             # Make sure is_file actually returns True
1586             self.assertTrue(p.is_file())
1587             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1588             expected = Path(p)
1589             black.reformat_one(
1590                 path,
1591                 fast=True,
1592                 write_back=black.WriteBack.YES,
1593                 mode=DEFAULT_MODE,
1594                 report=report,
1595             )
1596             fsts.assert_called_once()
1597             # __BLACK_STDIN_FILENAME__ should have been striped
1598             report.done.assert_called_with(expected, black.Changed.YES)
1599
1600     def test_gitignore_exclude(self) -> None:
1601         path = THIS_DIR / "data" / "include_exclude_tests"
1602         include = re.compile(r"\.pyi?$")
1603         exclude = re.compile(r"")
1604         report = black.Report()
1605         gitignore = PathSpec.from_lines(
1606             "gitwildmatch", ["exclude/", ".definitely_exclude"]
1607         )
1608         sources: List[Path] = []
1609         expected = [
1610             Path(path / "b/dont_exclude/a.py"),
1611             Path(path / "b/dont_exclude/a.pyi"),
1612         ]
1613         this_abs = THIS_DIR.resolve()
1614         sources.extend(
1615             black.gen_python_files(
1616                 path.iterdir(),
1617                 this_abs,
1618                 include,
1619                 exclude,
1620                 None,
1621                 None,
1622                 report,
1623                 gitignore,
1624             )
1625         )
1626         self.assertEqual(sorted(expected), sorted(sources))
1627
1628     def test_empty_include(self) -> None:
1629         path = THIS_DIR / "data" / "include_exclude_tests"
1630         report = black.Report()
1631         gitignore = PathSpec.from_lines("gitwildmatch", [])
1632         empty = re.compile(r"")
1633         sources: List[Path] = []
1634         expected = [
1635             Path(path / "b/exclude/a.pie"),
1636             Path(path / "b/exclude/a.py"),
1637             Path(path / "b/exclude/a.pyi"),
1638             Path(path / "b/dont_exclude/a.pie"),
1639             Path(path / "b/dont_exclude/a.py"),
1640             Path(path / "b/dont_exclude/a.pyi"),
1641             Path(path / "b/.definitely_exclude/a.pie"),
1642             Path(path / "b/.definitely_exclude/a.py"),
1643             Path(path / "b/.definitely_exclude/a.pyi"),
1644         ]
1645         this_abs = THIS_DIR.resolve()
1646         sources.extend(
1647             black.gen_python_files(
1648                 path.iterdir(),
1649                 this_abs,
1650                 empty,
1651                 re.compile(black.DEFAULT_EXCLUDES),
1652                 None,
1653                 None,
1654                 report,
1655                 gitignore,
1656             )
1657         )
1658         self.assertEqual(sorted(expected), sorted(sources))
1659
1660     def test_extend_exclude(self) -> None:
1661         path = THIS_DIR / "data" / "include_exclude_tests"
1662         report = black.Report()
1663         gitignore = PathSpec.from_lines("gitwildmatch", [])
1664         sources: List[Path] = []
1665         expected = [
1666             Path(path / "b/exclude/a.py"),
1667             Path(path / "b/dont_exclude/a.py"),
1668         ]
1669         this_abs = THIS_DIR.resolve()
1670         sources.extend(
1671             black.gen_python_files(
1672                 path.iterdir(),
1673                 this_abs,
1674                 re.compile(black.DEFAULT_INCLUDES),
1675                 re.compile(r"\.pyi$"),
1676                 re.compile(r"\.definitely_exclude"),
1677                 None,
1678                 report,
1679                 gitignore,
1680             )
1681         )
1682         self.assertEqual(sorted(expected), sorted(sources))
1683
1684     def test_invalid_cli_regex(self) -> None:
1685         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1686             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1687
1688     def test_preserves_line_endings(self) -> None:
1689         with TemporaryDirectory() as workspace:
1690             test_file = Path(workspace) / "test.py"
1691             for nl in ["\n", "\r\n"]:
1692                 contents = nl.join(["def f(  ):", "    pass"])
1693                 test_file.write_bytes(contents.encode())
1694                 ff(test_file, write_back=black.WriteBack.YES)
1695                 updated_contents: bytes = test_file.read_bytes()
1696                 self.assertIn(nl.encode(), updated_contents)
1697                 if nl == "\n":
1698                     self.assertNotIn(b"\r\n", updated_contents)
1699
1700     def test_preserves_line_endings_via_stdin(self) -> None:
1701         for nl in ["\n", "\r\n"]:
1702             contents = nl.join(["def f(  ):", "    pass"])
1703             runner = BlackRunner()
1704             result = runner.invoke(
1705                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1706             )
1707             self.assertEqual(result.exit_code, 0)
1708             output = runner.stdout_bytes
1709             self.assertIn(nl.encode("utf8"), output)
1710             if nl == "\n":
1711                 self.assertNotIn(b"\r\n", output)
1712
1713     def test_assert_equivalent_different_asts(self) -> None:
1714         with self.assertRaises(AssertionError):
1715             black.assert_equivalent("{}", "None")
1716
1717     def test_symlink_out_of_root_directory(self) -> None:
1718         path = MagicMock()
1719         root = THIS_DIR.resolve()
1720         child = MagicMock()
1721         include = re.compile(black.DEFAULT_INCLUDES)
1722         exclude = re.compile(black.DEFAULT_EXCLUDES)
1723         report = black.Report()
1724         gitignore = PathSpec.from_lines("gitwildmatch", [])
1725         # `child` should behave like a symlink which resolved path is clearly
1726         # outside of the `root` directory.
1727         path.iterdir.return_value = [child]
1728         child.resolve.return_value = Path("/a/b/c")
1729         child.as_posix.return_value = "/a/b/c"
1730         child.is_symlink.return_value = True
1731         try:
1732             list(
1733                 black.gen_python_files(
1734                     path.iterdir(),
1735                     root,
1736                     include,
1737                     exclude,
1738                     None,
1739                     None,
1740                     report,
1741                     gitignore,
1742                 )
1743             )
1744         except ValueError as ve:
1745             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1746         path.iterdir.assert_called_once()
1747         child.resolve.assert_called_once()
1748         child.is_symlink.assert_called_once()
1749         # `child` should behave like a strange file which resolved path is clearly
1750         # outside of the `root` directory.
1751         child.is_symlink.return_value = False
1752         with self.assertRaises(ValueError):
1753             list(
1754                 black.gen_python_files(
1755                     path.iterdir(),
1756                     root,
1757                     include,
1758                     exclude,
1759                     None,
1760                     None,
1761                     report,
1762                     gitignore,
1763                 )
1764             )
1765         path.iterdir.assert_called()
1766         self.assertEqual(path.iterdir.call_count, 2)
1767         child.resolve.assert_called()
1768         self.assertEqual(child.resolve.call_count, 2)
1769         child.is_symlink.assert_called()
1770         self.assertEqual(child.is_symlink.call_count, 2)
1771
1772     def test_shhh_click(self) -> None:
1773         try:
1774             from click import _unicodefun  # type: ignore
1775         except ModuleNotFoundError:
1776             self.skipTest("Incompatible Click version")
1777         if not hasattr(_unicodefun, "_verify_python3_env"):
1778             self.skipTest("Incompatible Click version")
1779         # First, let's see if Click is crashing with a preferred ASCII charset.
1780         with patch("locale.getpreferredencoding") as gpe:
1781             gpe.return_value = "ASCII"
1782             with self.assertRaises(RuntimeError):
1783                 _unicodefun._verify_python3_env()
1784         # Now, let's silence Click...
1785         black.patch_click()
1786         # ...and confirm it's silent.
1787         with patch("locale.getpreferredencoding") as gpe:
1788             gpe.return_value = "ASCII"
1789             try:
1790                 _unicodefun._verify_python3_env()
1791             except RuntimeError as re:
1792                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1793
1794     def test_root_logger_not_used_directly(self) -> None:
1795         def fail(*args: Any, **kwargs: Any) -> None:
1796             self.fail("Record created with root logger")
1797
1798         with patch.multiple(
1799             logging.root,
1800             debug=fail,
1801             info=fail,
1802             warning=fail,
1803             error=fail,
1804             critical=fail,
1805             log=fail,
1806         ):
1807             ff(THIS_FILE)
1808
1809     def test_invalid_config_return_code(self) -> None:
1810         tmp_file = Path(black.dump_to_file())
1811         try:
1812             tmp_config = Path(black.dump_to_file())
1813             tmp_config.unlink()
1814             args = ["--config", str(tmp_config), str(tmp_file)]
1815             self.invokeBlack(args, exit_code=2, ignore_config=False)
1816         finally:
1817             tmp_file.unlink()
1818
1819     def test_parse_pyproject_toml(self) -> None:
1820         test_toml_file = THIS_DIR / "test.toml"
1821         config = black.parse_pyproject_toml(str(test_toml_file))
1822         self.assertEqual(config["verbose"], 1)
1823         self.assertEqual(config["check"], "no")
1824         self.assertEqual(config["diff"], "y")
1825         self.assertEqual(config["color"], True)
1826         self.assertEqual(config["line_length"], 79)
1827         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1828         self.assertEqual(config["exclude"], r"\.pyi?$")
1829         self.assertEqual(config["include"], r"\.py?$")
1830
1831     def test_read_pyproject_toml(self) -> None:
1832         test_toml_file = THIS_DIR / "test.toml"
1833         fake_ctx = FakeContext()
1834         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1835         config = fake_ctx.default_map
1836         self.assertEqual(config["verbose"], "1")
1837         self.assertEqual(config["check"], "no")
1838         self.assertEqual(config["diff"], "y")
1839         self.assertEqual(config["color"], "True")
1840         self.assertEqual(config["line_length"], "79")
1841         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1842         self.assertEqual(config["exclude"], r"\.pyi?$")
1843         self.assertEqual(config["include"], r"\.py?$")
1844
1845     def test_find_project_root(self) -> None:
1846         with TemporaryDirectory() as workspace:
1847             root = Path(workspace)
1848             test_dir = root / "test"
1849             test_dir.mkdir()
1850
1851             src_dir = root / "src"
1852             src_dir.mkdir()
1853
1854             root_pyproject = root / "pyproject.toml"
1855             root_pyproject.touch()
1856             src_pyproject = src_dir / "pyproject.toml"
1857             src_pyproject.touch()
1858             src_python = src_dir / "foo.py"
1859             src_python.touch()
1860
1861             self.assertEqual(
1862                 black.find_project_root((src_dir, test_dir)), root.resolve()
1863             )
1864             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
1865             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
1866
1867     @patch("black.find_user_pyproject_toml", black.find_user_pyproject_toml.__wrapped__)
1868     def test_find_user_pyproject_toml_linux(self) -> None:
1869         if system() == "Windows":
1870             return
1871
1872         # Test if XDG_CONFIG_HOME is checked
1873         with TemporaryDirectory() as workspace:
1874             tmp_user_config = Path(workspace) / "black"
1875             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1876                 self.assertEqual(
1877                     black.find_user_pyproject_toml(), tmp_user_config.resolve()
1878                 )
1879
1880         # Test fallback for XDG_CONFIG_HOME
1881         with patch.dict("os.environ"):
1882             os.environ.pop("XDG_CONFIG_HOME", None)
1883             fallback_user_config = Path("~/.config").expanduser() / "black"
1884             self.assertEqual(
1885                 black.find_user_pyproject_toml(), fallback_user_config.resolve()
1886             )
1887
1888     def test_find_user_pyproject_toml_windows(self) -> None:
1889         if system() != "Windows":
1890             return
1891
1892         user_config_path = Path.home() / ".black"
1893         self.assertEqual(black.find_user_pyproject_toml(), user_config_path.resolve())
1894
1895     def test_bpo_33660_workaround(self) -> None:
1896         if system() == "Windows":
1897             return
1898
1899         # https://bugs.python.org/issue33660
1900
1901         old_cwd = Path.cwd()
1902         try:
1903             root = Path("/")
1904             os.chdir(str(root))
1905             path = Path("workspace") / "project"
1906             report = black.Report(verbose=True)
1907             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
1908             self.assertEqual(normalized_path, "workspace/project")
1909         finally:
1910             os.chdir(str(old_cwd))
1911
1912     def test_newline_comment_interaction(self) -> None:
1913         source = "class A:\\\r\n# type: ignore\n pass\n"
1914         output = black.format_str(source, mode=DEFAULT_MODE)
1915         black.assert_stable(source, output, mode=DEFAULT_MODE)
1916
1917     def test_bpo_2142_workaround(self) -> None:
1918
1919         # https://bugs.python.org/issue2142
1920
1921         source, _ = read_data("missing_final_newline.py")
1922         # read_data adds a trailing newline
1923         source = source.rstrip()
1924         expected, _ = read_data("missing_final_newline.diff")
1925         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
1926         diff_header = re.compile(
1927             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
1928             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
1929         )
1930         try:
1931             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
1932             self.assertEqual(result.exit_code, 0)
1933         finally:
1934             os.unlink(tmp_file)
1935         actual = result.output
1936         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
1937         self.assertEqual(actual, expected)
1938
1939     def test_docstring_reformat_for_py27(self) -> None:
1940         """
1941         Check that stripping trailing whitespace from Python 2 docstrings
1942         doesn't trigger a "not equivalent to source" error
1943         """
1944         source = (
1945             b'def foo():\r\n    """Testing\r\n    Testing """\r\n    print "Foo"\r\n'
1946         )
1947         expected = 'def foo():\n    """Testing\n    Testing"""\n    print "Foo"\n'
1948
1949         result = CliRunner().invoke(
1950             black.main,
1951             ["-", "-q", "--target-version=py27"],
1952             input=BytesIO(source),
1953         )
1954
1955         self.assertEqual(result.exit_code, 0)
1956         actual = result.output
1957         self.assertFormatEqual(actual, expected)
1958
1959
# Load Black's own source once at import time; `tracefunc` indexes into it.
with open(black.__file__, encoding="utf-8") as _bf:
    black_source_lines = list(_bf)
1962
1963
def tracefunc(frame: types.FrameType, event: str, arg: Any) -> Callable:
    """Show function calls `from black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    Only "call" events whose code lives in ``black/__init__.py`` are printed,
    as ``<line number>:<source line of the signature>`` indented by stack
    depth. The function always returns itself so tracing stays enabled.
    """
    if event != "call":
        return tracefunc

    filename = frame.f_code.co_filename
    if "black/__init__.py" not in filename:
        # Check the file first: line numbers from other files must not be used
        # to index `black_source_lines`, since they can exceed its length.
        return tracefunc

    # NOTE(review): 19 looks like the baseline stack depth of the test harness,
    # subtracted so output starts near column 0 -- confirm.
    stack = len(inspect.stack()) - 19
    stack *= 2
    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1
    funcname = black_source_lines[func_sig_lineno].strip()
    # Skip decorator lines to reach the actual signature line.
    while funcname.startswith("@"):
        func_sig_lineno += 1
        funcname = black_source_lines[func_sig_lineno].strip()
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc
1984
1985
# When executed as a script, run the tests from the module named "test_black".
if __name__ == "__main__":
    unittest.main(module="test_black")