git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

201edfa2de522b17b84a153333a969100c933e87
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import multiprocessing
3 import asyncio
4 import logging
5 from concurrent.futures import ThreadPoolExecutor
6 from contextlib import contextmanager
7 from dataclasses import replace
8 import inspect
9 from io import BytesIO, TextIOWrapper
10 import os
11 from pathlib import Path
12 from platform import system
13 import regex as re
14 import sys
15 from tempfile import TemporaryDirectory
16 import types
17 from typing import (
18     Any,
19     BinaryIO,
20     Callable,
21     Dict,
22     Generator,
23     List,
24     Iterator,
25     TypeVar,
26 )
27 import unittest
28 from unittest.mock import patch, MagicMock
29
30 import click
31 from click import unstyle
32 from click.testing import CliRunner
33
34 import black
35 from black import Feature, TargetVersion
36
37 from pathspec import PathSpec
38
39 # Import other test classes
40 from tests.util import (
41     THIS_DIR,
42     read_data,
43     DETERMINISTIC_HEADER,
44     BlackBaseTestCase,
45     DEFAULT_MODE,
46     fs,
47     ff,
48     dump_to_stderr,
49 )
50 from .test_primer import PrimerCLITests  # noqa: F401
51
52
# Path of this test module.
THIS_FILE = Path(__file__)
# The target versions PY36 through PY39.
PY36_VERSIONS = {
    TargetVersion.PY36,
    TargetVersion.PY37,
    TargetVersion.PY38,
    TargetVersion.PY39,
}
# One ``--target-version=<name>`` command-line flag per entry of PY36_VERSIONS.
PY36_ARGS = ["--target-version=" + target.name.lower() for target in PY36_VERSIONS]
# Generic type variables.
T = TypeVar("T")
R = TypeVar("R")
63
64
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch ``black.CACHE_DIR`` to a path inside a temporary directory.

    The patched path is yielded to the caller.  With ``exists=True`` it is the
    temporary directory itself; otherwise it is a ``new`` child of that
    directory, which this helper does not create.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
73
74
@contextmanager
def event_loop() -> Iterator[None]:
    """Run the enclosed block with a fresh asyncio event loop set as current.

    A new loop is created and installed with ``asyncio.set_event_loop``; it is
    closed when the block exits, whether normally or through an exception.  The
    closed loop is left installed as the current loop.
    """
    # asyncio.new_event_loop() is the direct equivalent of asking the event loop
    # policy for a new loop, without going through the policy API, which is
    # deprecated in recent Python versions.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        yield
    finally:
        loop.close()
85
86
class FakeContext(click.Context):
    """A fake click Context for when calling functions that need it.

    The parent initializer is not invoked; the only attribute set is an empty
    ``default_map``.
    """

    def __init__(self) -> None:
        self.default_map: Dict[str, Any] = dict()
92
93
class FakeParameter(click.Parameter):
    """A fake click Parameter for when calling functions that need it."""

    def __init__(self) -> None:
        """Do nothing; the parent initializer is not called."""
99
100
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    This is a hack that can be removed once we depend on Click 7.x

    After an invocation, the raw bytes written to each stream are available as
    ``stdout_bytes`` and ``stderr_bytes``.
    """

    def __init__(self) -> None:
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Wrap ``CliRunner.isolation`` and redirect ``sys.stderr`` to ``stderrbuf``.

        On exit the bytes captured on stdout and stderr are stored on the runner
        and the previous ``sys.stderr`` is restored.
        """
        with super().isolation(*args, **kwargs) as output:
            # Snapshot before the try block so the finally clause can always
            # restore it.
            hold_stderr = sys.stderr
            try:
                sys.stderr = TextIOWrapper(self.stderrbuf, encoding=self.charset)
                yield output
            finally:
                # A text wrapper may still hold written text that has not yet
                # reached the underlying bytes buffer; flush before reading it.
                sys.stdout.flush()
                sys.stderr.flush()
                self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                self.stderr_bytes = sys.stderr.buffer.getvalue()  # type: ignore
                sys.stderr = hold_stderr
124
125
126 class BlackTestCase(BlackBaseTestCase):
127     def invokeBlack(
128         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
129     ) -> None:
130         runner = BlackRunner()
131         if ignore_config:
132             args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
133         result = runner.invoke(black.main, args)
134         self.assertEqual(
135             result.exit_code,
136             exit_code,
137             msg=(
138                 f"Failed with args: {args}\n"
139                 f"stdout: {runner.stdout_bytes.decode()!r}\n"
140                 f"stderr: {runner.stderr_bytes.decode()!r}\n"
141                 f"exception: {result.exception}"
142             ),
143         )
144
145     @patch("black.dump_to_file", dump_to_stderr)
146     def test_empty(self) -> None:
147         source = expected = ""
148         actual = fs(source)
149         self.assertFormatEqual(expected, actual)
150         black.assert_equivalent(source, actual)
151         black.assert_stable(source, actual, DEFAULT_MODE)
152
153     def test_empty_ff(self) -> None:
154         expected = ""
155         tmp_file = Path(black.dump_to_file())
156         try:
157             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
158             with open(tmp_file, encoding="utf8") as f:
159                 actual = f.read()
160         finally:
161             os.unlink(tmp_file)
162         self.assertFormatEqual(expected, actual)
163
164     def test_piping(self) -> None:
165         source, expected = read_data("src/black/__init__", data=False)
166         result = BlackRunner().invoke(
167             black.main,
168             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
169             input=BytesIO(source.encode("utf8")),
170         )
171         self.assertEqual(result.exit_code, 0)
172         self.assertFormatEqual(expected, result.output)
173         black.assert_equivalent(source, result.output)
174         black.assert_stable(source, result.output, DEFAULT_MODE)
175
176     def test_piping_diff(self) -> None:
177         diff_header = re.compile(
178             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
179             r"\+\d\d\d\d"
180         )
181         source, _ = read_data("expression.py")
182         expected, _ = read_data("expression.diff")
183         config = THIS_DIR / "data" / "empty_pyproject.toml"
184         args = [
185             "-",
186             "--fast",
187             f"--line-length={black.DEFAULT_LINE_LENGTH}",
188             "--diff",
189             f"--config={config}",
190         ]
191         result = BlackRunner().invoke(
192             black.main, args, input=BytesIO(source.encode("utf8"))
193         )
194         self.assertEqual(result.exit_code, 0)
195         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
196         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
197         self.assertEqual(expected, actual)
198
199     def test_piping_diff_with_color(self) -> None:
200         source, _ = read_data("expression.py")
201         config = THIS_DIR / "data" / "empty_pyproject.toml"
202         args = [
203             "-",
204             "--fast",
205             f"--line-length={black.DEFAULT_LINE_LENGTH}",
206             "--diff",
207             "--color",
208             f"--config={config}",
209         ]
210         result = BlackRunner().invoke(
211             black.main, args, input=BytesIO(source.encode("utf8"))
212         )
213         actual = result.output
214         # Again, the contents are checked in a different test, so only look for colors.
215         self.assertIn("\033[1;37m", actual)
216         self.assertIn("\033[36m", actual)
217         self.assertIn("\033[32m", actual)
218         self.assertIn("\033[31m", actual)
219         self.assertIn("\033[0m", actual)
220
221     @patch("black.dump_to_file", dump_to_stderr)
222     def _test_wip(self) -> None:
223         source, expected = read_data("wip")
224         sys.settrace(tracefunc)
225         mode = replace(
226             DEFAULT_MODE,
227             experimental_string_processing=False,
228             target_versions={black.TargetVersion.PY38},
229         )
230         actual = fs(source, mode=mode)
231         sys.settrace(None)
232         self.assertFormatEqual(expected, actual)
233         black.assert_equivalent(source, actual)
234         black.assert_stable(source, actual, black.FileMode())
235
236     @unittest.expectedFailure
237     @patch("black.dump_to_file", dump_to_stderr)
238     def test_trailing_comma_optional_parens_stability1(self) -> None:
239         source, _expected = read_data("trailing_comma_optional_parens1")
240         actual = fs(source)
241         black.assert_stable(source, actual, DEFAULT_MODE)
242
243     @unittest.expectedFailure
244     @patch("black.dump_to_file", dump_to_stderr)
245     def test_trailing_comma_optional_parens_stability2(self) -> None:
246         source, _expected = read_data("trailing_comma_optional_parens2")
247         actual = fs(source)
248         black.assert_stable(source, actual, DEFAULT_MODE)
249
250     @unittest.expectedFailure
251     @patch("black.dump_to_file", dump_to_stderr)
252     def test_trailing_comma_optional_parens_stability3(self) -> None:
253         source, _expected = read_data("trailing_comma_optional_parens3")
254         actual = fs(source)
255         black.assert_stable(source, actual, DEFAULT_MODE)
256
257     @patch("black.dump_to_file", dump_to_stderr)
258     def test_pep_572(self) -> None:
259         source, expected = read_data("pep_572")
260         actual = fs(source)
261         self.assertFormatEqual(expected, actual)
262         black.assert_stable(source, actual, DEFAULT_MODE)
263         if sys.version_info >= (3, 8):
264             black.assert_equivalent(source, actual)
265
266     @patch("black.dump_to_file", dump_to_stderr)
267     def test_pep_572_remove_parens(self) -> None:
268         source, expected = read_data("pep_572_remove_parens")
269         actual = fs(source)
270         self.assertFormatEqual(expected, actual)
271         black.assert_stable(source, actual, DEFAULT_MODE)
272         if sys.version_info >= (3, 8):
273             black.assert_equivalent(source, actual)
274
275     def test_pep_572_version_detection(self) -> None:
276         source, _ = read_data("pep_572")
277         root = black.lib2to3_parse(source)
278         features = black.get_features_used(root)
279         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
280         versions = black.detect_target_versions(root)
281         self.assertIn(black.TargetVersion.PY38, versions)
282
283     def test_expression_ff(self) -> None:
284         source, expected = read_data("expression")
285         tmp_file = Path(black.dump_to_file(source))
286         try:
287             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
288             with open(tmp_file, encoding="utf8") as f:
289                 actual = f.read()
290         finally:
291             os.unlink(tmp_file)
292         self.assertFormatEqual(expected, actual)
293         with patch("black.dump_to_file", dump_to_stderr):
294             black.assert_equivalent(source, actual)
295             black.assert_stable(source, actual, DEFAULT_MODE)
296
297     def test_expression_diff(self) -> None:
298         source, _ = read_data("expression.py")
299         config = THIS_DIR / "data" / "empty_pyproject.toml"
300         expected, _ = read_data("expression.diff")
301         tmp_file = Path(black.dump_to_file(source))
302         diff_header = re.compile(
303             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
304             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
305         )
306         try:
307             result = BlackRunner().invoke(
308                 black.main, ["--diff", str(tmp_file), f"--config={config}"]
309             )
310             self.assertEqual(result.exit_code, 0)
311         finally:
312             os.unlink(tmp_file)
313         actual = result.output
314         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
315         if expected != actual:
316             dump = black.dump_to_file(actual)
317             msg = (
318                 "Expected diff isn't equal to the actual. If you made changes to"
319                 " expression.py and this is an anticipated difference, overwrite"
320                 f" tests/data/expression.diff with {dump}"
321             )
322             self.assertEqual(expected, actual, msg)
323
324     def test_expression_diff_with_color(self) -> None:
325         source, _ = read_data("expression.py")
326         config = THIS_DIR / "data" / "empty_pyproject.toml"
327         expected, _ = read_data("expression.diff")
328         tmp_file = Path(black.dump_to_file(source))
329         try:
330             result = BlackRunner().invoke(
331                 black.main, ["--diff", "--color", str(tmp_file), f"--config={config}"]
332             )
333         finally:
334             os.unlink(tmp_file)
335         actual = result.output
336         # We check the contents of the diff in `test_expression_diff`. All
337         # we need to check here is that color codes exist in the result.
338         self.assertIn("\033[1;37m", actual)
339         self.assertIn("\033[36m", actual)
340         self.assertIn("\033[32m", actual)
341         self.assertIn("\033[31m", actual)
342         self.assertIn("\033[0m", actual)
343
344     @patch("black.dump_to_file", dump_to_stderr)
345     def test_pep_570(self) -> None:
346         source, expected = read_data("pep_570")
347         actual = fs(source)
348         self.assertFormatEqual(expected, actual)
349         black.assert_stable(source, actual, DEFAULT_MODE)
350         if sys.version_info >= (3, 8):
351             black.assert_equivalent(source, actual)
352
353     def test_detect_pos_only_arguments(self) -> None:
354         source, _ = read_data("pep_570")
355         root = black.lib2to3_parse(source)
356         features = black.get_features_used(root)
357         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
358         versions = black.detect_target_versions(root)
359         self.assertIn(black.TargetVersion.PY38, versions)
360
361     @patch("black.dump_to_file", dump_to_stderr)
362     def test_string_quotes(self) -> None:
363         source, expected = read_data("string_quotes")
364         actual = fs(source)
365         self.assertFormatEqual(expected, actual)
366         black.assert_equivalent(source, actual)
367         black.assert_stable(source, actual, DEFAULT_MODE)
368         mode = replace(DEFAULT_MODE, string_normalization=False)
369         not_normalized = fs(source, mode=mode)
370         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
371         black.assert_equivalent(source, not_normalized)
372         black.assert_stable(source, not_normalized, mode=mode)
373
374     @patch("black.dump_to_file", dump_to_stderr)
375     def test_docstring_no_string_normalization(self) -> None:
376         """Like test_docstring but with string normalization off."""
377         source, expected = read_data("docstring_no_string_normalization")
378         mode = replace(DEFAULT_MODE, string_normalization=False)
379         actual = fs(source, mode=mode)
380         self.assertFormatEqual(expected, actual)
381         black.assert_equivalent(source, actual)
382         black.assert_stable(source, actual, mode)
383
384     def test_long_strings_flag_disabled(self) -> None:
385         """Tests for turning off the string processing logic."""
386         source, expected = read_data("long_strings_flag_disabled")
387         mode = replace(DEFAULT_MODE, experimental_string_processing=False)
388         actual = fs(source, mode=mode)
389         self.assertFormatEqual(expected, actual)
390         black.assert_stable(expected, actual, mode)
391
392     @patch("black.dump_to_file", dump_to_stderr)
393     def test_numeric_literals(self) -> None:
394         source, expected = read_data("numeric_literals")
395         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
396         actual = fs(source, mode=mode)
397         self.assertFormatEqual(expected, actual)
398         black.assert_equivalent(source, actual)
399         black.assert_stable(source, actual, mode)
400
401     @patch("black.dump_to_file", dump_to_stderr)
402     def test_numeric_literals_ignoring_underscores(self) -> None:
403         source, expected = read_data("numeric_literals_skip_underscores")
404         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
405         actual = fs(source, mode=mode)
406         self.assertFormatEqual(expected, actual)
407         black.assert_equivalent(source, actual)
408         black.assert_stable(source, actual, mode)
409
410     def test_skip_magic_trailing_comma(self) -> None:
411         source, _ = read_data("expression.py")
412         expected, _ = read_data("expression_skip_magic_trailing_comma.diff")
413         tmp_file = Path(black.dump_to_file(source))
414         diff_header = re.compile(
415             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
416             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
417         )
418         try:
419             result = BlackRunner().invoke(black.main, ["-C", "--diff", str(tmp_file)])
420             self.assertEqual(result.exit_code, 0)
421         finally:
422             os.unlink(tmp_file)
423         actual = result.output
424         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
425         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
426         if expected != actual:
427             dump = black.dump_to_file(actual)
428             msg = (
429                 "Expected diff isn't equal to the actual. If you made changes to"
430                 " expression.py and this is an anticipated difference, overwrite"
431                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
432             )
433             self.assertEqual(expected, actual, msg)
434
435     @patch("black.dump_to_file", dump_to_stderr)
436     def test_python2_print_function(self) -> None:
437         source, expected = read_data("python2_print_function")
438         mode = replace(DEFAULT_MODE, target_versions={TargetVersion.PY27})
439         actual = fs(source, mode=mode)
440         self.assertFormatEqual(expected, actual)
441         black.assert_equivalent(source, actual)
442         black.assert_stable(source, actual, mode)
443
444     @patch("black.dump_to_file", dump_to_stderr)
445     def test_stub(self) -> None:
446         mode = replace(DEFAULT_MODE, is_pyi=True)
447         source, expected = read_data("stub.pyi")
448         actual = fs(source, mode=mode)
449         self.assertFormatEqual(expected, actual)
450         black.assert_stable(source, actual, mode)
451
452     @patch("black.dump_to_file", dump_to_stderr)
453     def test_async_as_identifier(self) -> None:
454         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
455         source, expected = read_data("async_as_identifier")
456         actual = fs(source)
457         self.assertFormatEqual(expected, actual)
458         major, minor = sys.version_info[:2]
459         if major < 3 or (major <= 3 and minor < 7):
460             black.assert_equivalent(source, actual)
461         black.assert_stable(source, actual, DEFAULT_MODE)
462         # ensure black can parse this when the target is 3.6
463         self.invokeBlack([str(source_path), "--target-version", "py36"])
464         # but not on 3.7, because async/await is no longer an identifier
465         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
466
467     @patch("black.dump_to_file", dump_to_stderr)
468     def test_python37(self) -> None:
469         source_path = (THIS_DIR / "data" / "python37.py").resolve()
470         source, expected = read_data("python37")
471         actual = fs(source)
472         self.assertFormatEqual(expected, actual)
473         major, minor = sys.version_info[:2]
474         if major > 3 or (major == 3 and minor >= 7):
475             black.assert_equivalent(source, actual)
476         black.assert_stable(source, actual, DEFAULT_MODE)
477         # ensure black can parse this when the target is 3.7
478         self.invokeBlack([str(source_path), "--target-version", "py37"])
479         # but not on 3.6, because we use async as a reserved keyword
480         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
481
482     @patch("black.dump_to_file", dump_to_stderr)
483     def test_python38(self) -> None:
484         source, expected = read_data("python38")
485         actual = fs(source)
486         self.assertFormatEqual(expected, actual)
487         major, minor = sys.version_info[:2]
488         if major > 3 or (major == 3 and minor >= 8):
489             black.assert_equivalent(source, actual)
490         black.assert_stable(source, actual, DEFAULT_MODE)
491
492     @patch("black.dump_to_file", dump_to_stderr)
493     def test_python39(self) -> None:
494         source, expected = read_data("python39")
495         actual = fs(source)
496         self.assertFormatEqual(expected, actual)
497         major, minor = sys.version_info[:2]
498         if major > 3 or (major == 3 and minor >= 9):
499             black.assert_equivalent(source, actual)
500         black.assert_stable(source, actual, DEFAULT_MODE)
501
502     def test_tab_comment_indentation(self) -> None:
503         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
504         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
505         self.assertFormatEqual(contents_spc, fs(contents_spc))
506         self.assertFormatEqual(contents_spc, fs(contents_tab))
507
508         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
509         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
510         self.assertFormatEqual(contents_spc, fs(contents_spc))
511         self.assertFormatEqual(contents_spc, fs(contents_tab))
512
513         # mixed tabs and spaces (valid Python 2 code)
514         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
515         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
516         self.assertFormatEqual(contents_spc, fs(contents_spc))
517         self.assertFormatEqual(contents_spc, fs(contents_tab))
518
519         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
520         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
521         self.assertFormatEqual(contents_spc, fs(contents_spc))
522         self.assertFormatEqual(contents_spc, fs(contents_tab))
523
524     def test_report_verbose(self) -> None:
525         report = black.Report(verbose=True)
526         out_lines = []
527         err_lines = []
528
529         def out(msg: str, **kwargs: Any) -> None:
530             out_lines.append(msg)
531
532         def err(msg: str, **kwargs: Any) -> None:
533             err_lines.append(msg)
534
535         with patch("black.out", out), patch("black.err", err):
536             report.done(Path("f1"), black.Changed.NO)
537             self.assertEqual(len(out_lines), 1)
538             self.assertEqual(len(err_lines), 0)
539             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
540             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
541             self.assertEqual(report.return_code, 0)
542             report.done(Path("f2"), black.Changed.YES)
543             self.assertEqual(len(out_lines), 2)
544             self.assertEqual(len(err_lines), 0)
545             self.assertEqual(out_lines[-1], "reformatted f2")
546             self.assertEqual(
547                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
548             )
549             report.done(Path("f3"), black.Changed.CACHED)
550             self.assertEqual(len(out_lines), 3)
551             self.assertEqual(len(err_lines), 0)
552             self.assertEqual(
553                 out_lines[-1], "f3 wasn't modified on disk since last run."
554             )
555             self.assertEqual(
556                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
557             )
558             self.assertEqual(report.return_code, 0)
559             report.check = True
560             self.assertEqual(report.return_code, 1)
561             report.check = False
562             report.failed(Path("e1"), "boom")
563             self.assertEqual(len(out_lines), 3)
564             self.assertEqual(len(err_lines), 1)
565             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
566             self.assertEqual(
567                 unstyle(str(report)),
568                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
569                 " reformat.",
570             )
571             self.assertEqual(report.return_code, 123)
572             report.done(Path("f3"), black.Changed.YES)
573             self.assertEqual(len(out_lines), 4)
574             self.assertEqual(len(err_lines), 1)
575             self.assertEqual(out_lines[-1], "reformatted f3")
576             self.assertEqual(
577                 unstyle(str(report)),
578                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
579                 " reformat.",
580             )
581             self.assertEqual(report.return_code, 123)
582             report.failed(Path("e2"), "boom")
583             self.assertEqual(len(out_lines), 4)
584             self.assertEqual(len(err_lines), 2)
585             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
586             self.assertEqual(
587                 unstyle(str(report)),
588                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
589                 " reformat.",
590             )
591             self.assertEqual(report.return_code, 123)
592             report.path_ignored(Path("wat"), "no match")
593             self.assertEqual(len(out_lines), 5)
594             self.assertEqual(len(err_lines), 2)
595             self.assertEqual(out_lines[-1], "wat ignored: no match")
596             self.assertEqual(
597                 unstyle(str(report)),
598                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
599                 " reformat.",
600             )
601             self.assertEqual(report.return_code, 123)
602             report.done(Path("f4"), black.Changed.NO)
603             self.assertEqual(len(out_lines), 6)
604             self.assertEqual(len(err_lines), 2)
605             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
606             self.assertEqual(
607                 unstyle(str(report)),
608                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
609                 " reformat.",
610             )
611             self.assertEqual(report.return_code, 123)
612             report.check = True
613             self.assertEqual(
614                 unstyle(str(report)),
615                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
616                 " would fail to reformat.",
617             )
618             report.check = False
619             report.diff = True
620             self.assertEqual(
621                 unstyle(str(report)),
622                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
623                 " would fail to reformat.",
624             )
625
626     def test_report_quiet(self) -> None:
627         report = black.Report(quiet=True)
628         out_lines = []
629         err_lines = []
630
631         def out(msg: str, **kwargs: Any) -> None:
632             out_lines.append(msg)
633
634         def err(msg: str, **kwargs: Any) -> None:
635             err_lines.append(msg)
636
637         with patch("black.out", out), patch("black.err", err):
638             report.done(Path("f1"), black.Changed.NO)
639             self.assertEqual(len(out_lines), 0)
640             self.assertEqual(len(err_lines), 0)
641             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
642             self.assertEqual(report.return_code, 0)
643             report.done(Path("f2"), black.Changed.YES)
644             self.assertEqual(len(out_lines), 0)
645             self.assertEqual(len(err_lines), 0)
646             self.assertEqual(
647                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
648             )
649             report.done(Path("f3"), black.Changed.CACHED)
650             self.assertEqual(len(out_lines), 0)
651             self.assertEqual(len(err_lines), 0)
652             self.assertEqual(
653                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
654             )
655             self.assertEqual(report.return_code, 0)
656             report.check = True
657             self.assertEqual(report.return_code, 1)
658             report.check = False
659             report.failed(Path("e1"), "boom")
660             self.assertEqual(len(out_lines), 0)
661             self.assertEqual(len(err_lines), 1)
662             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
663             self.assertEqual(
664                 unstyle(str(report)),
665                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
666                 " reformat.",
667             )
668             self.assertEqual(report.return_code, 123)
669             report.done(Path("f3"), black.Changed.YES)
670             self.assertEqual(len(out_lines), 0)
671             self.assertEqual(len(err_lines), 1)
672             self.assertEqual(
673                 unstyle(str(report)),
674                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
675                 " reformat.",
676             )
677             self.assertEqual(report.return_code, 123)
678             report.failed(Path("e2"), "boom")
679             self.assertEqual(len(out_lines), 0)
680             self.assertEqual(len(err_lines), 2)
681             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
682             self.assertEqual(
683                 unstyle(str(report)),
684                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
685                 " reformat.",
686             )
687             self.assertEqual(report.return_code, 123)
688             report.path_ignored(Path("wat"), "no match")
689             self.assertEqual(len(out_lines), 0)
690             self.assertEqual(len(err_lines), 2)
691             self.assertEqual(
692                 unstyle(str(report)),
693                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
694                 " reformat.",
695             )
696             self.assertEqual(report.return_code, 123)
697             report.done(Path("f4"), black.Changed.NO)
698             self.assertEqual(len(out_lines), 0)
699             self.assertEqual(len(err_lines), 2)
700             self.assertEqual(
701                 unstyle(str(report)),
702                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
703                 " reformat.",
704             )
705             self.assertEqual(report.return_code, 123)
706             report.check = True
707             self.assertEqual(
708                 unstyle(str(report)),
709                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
710                 " would fail to reformat.",
711             )
712             report.check = False
713             report.diff = True
714             self.assertEqual(
715                 unstyle(str(report)),
716                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
717                 " would fail to reformat.",
718             )
719
720     def test_report_normal(self) -> None:
721         report = black.Report()
722         out_lines = []
723         err_lines = []
724
725         def out(msg: str, **kwargs: Any) -> None:
726             out_lines.append(msg)
727
728         def err(msg: str, **kwargs: Any) -> None:
729             err_lines.append(msg)
730
731         with patch("black.out", out), patch("black.err", err):
732             report.done(Path("f1"), black.Changed.NO)
733             self.assertEqual(len(out_lines), 0)
734             self.assertEqual(len(err_lines), 0)
735             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
736             self.assertEqual(report.return_code, 0)
737             report.done(Path("f2"), black.Changed.YES)
738             self.assertEqual(len(out_lines), 1)
739             self.assertEqual(len(err_lines), 0)
740             self.assertEqual(out_lines[-1], "reformatted f2")
741             self.assertEqual(
742                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
743             )
744             report.done(Path("f3"), black.Changed.CACHED)
745             self.assertEqual(len(out_lines), 1)
746             self.assertEqual(len(err_lines), 0)
747             self.assertEqual(out_lines[-1], "reformatted f2")
748             self.assertEqual(
749                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
750             )
751             self.assertEqual(report.return_code, 0)
752             report.check = True
753             self.assertEqual(report.return_code, 1)
754             report.check = False
755             report.failed(Path("e1"), "boom")
756             self.assertEqual(len(out_lines), 1)
757             self.assertEqual(len(err_lines), 1)
758             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
759             self.assertEqual(
760                 unstyle(str(report)),
761                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
762                 " reformat.",
763             )
764             self.assertEqual(report.return_code, 123)
765             report.done(Path("f3"), black.Changed.YES)
766             self.assertEqual(len(out_lines), 2)
767             self.assertEqual(len(err_lines), 1)
768             self.assertEqual(out_lines[-1], "reformatted f3")
769             self.assertEqual(
770                 unstyle(str(report)),
771                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
772                 " reformat.",
773             )
774             self.assertEqual(report.return_code, 123)
775             report.failed(Path("e2"), "boom")
776             self.assertEqual(len(out_lines), 2)
777             self.assertEqual(len(err_lines), 2)
778             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
779             self.assertEqual(
780                 unstyle(str(report)),
781                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
782                 " reformat.",
783             )
784             self.assertEqual(report.return_code, 123)
785             report.path_ignored(Path("wat"), "no match")
786             self.assertEqual(len(out_lines), 2)
787             self.assertEqual(len(err_lines), 2)
788             self.assertEqual(
789                 unstyle(str(report)),
790                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
791                 " reformat.",
792             )
793             self.assertEqual(report.return_code, 123)
794             report.done(Path("f4"), black.Changed.NO)
795             self.assertEqual(len(out_lines), 2)
796             self.assertEqual(len(err_lines), 2)
797             self.assertEqual(
798                 unstyle(str(report)),
799                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
800                 " reformat.",
801             )
802             self.assertEqual(report.return_code, 123)
803             report.check = True
804             self.assertEqual(
805                 unstyle(str(report)),
806                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
807                 " would fail to reformat.",
808             )
809             report.check = False
810             report.diff = True
811             self.assertEqual(
812                 unstyle(str(report)),
813                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
814                 " would fail to reformat.",
815             )
816
817     def test_lib2to3_parse(self) -> None:
818         with self.assertRaises(black.InvalidInput):
819             black.lib2to3_parse("invalid syntax")
820
821         straddling = "x + y"
822         black.lib2to3_parse(straddling)
823         black.lib2to3_parse(straddling, {TargetVersion.PY27})
824         black.lib2to3_parse(straddling, {TargetVersion.PY36})
825         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
826
827         py2_only = "print x"
828         black.lib2to3_parse(py2_only)
829         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
830         with self.assertRaises(black.InvalidInput):
831             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
832         with self.assertRaises(black.InvalidInput):
833             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
834
835         py3_only = "exec(x, end=y)"
836         black.lib2to3_parse(py3_only)
837         with self.assertRaises(black.InvalidInput):
838             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
839         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
840         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
841
842     def test_get_features_used_decorator(self) -> None:
843         # Test the feature detection of new decorator syntax
844         # since this makes some test cases of test_get_features_used()
845         # fails if it fails, this is tested first so that a useful case
846         # is identified
847         simples, relaxed = read_data("decorators")
848         # skip explanation comments at the top of the file
849         for simple_test in simples.split("##")[1:]:
850             node = black.lib2to3_parse(simple_test)
851             decorator = str(node.children[0].children[0]).strip()
852             self.assertNotIn(
853                 Feature.RELAXED_DECORATORS,
854                 black.get_features_used(node),
855                 msg=(
856                     f"decorator '{decorator}' follows python<=3.8 syntax"
857                     "but is detected as 3.9+"
858                     # f"The full node is\n{node!r}"
859                 ),
860             )
861         # skip the '# output' comment at the top of the output part
862         for relaxed_test in relaxed.split("##")[1:]:
863             node = black.lib2to3_parse(relaxed_test)
864             decorator = str(node.children[0].children[0]).strip()
865             self.assertIn(
866                 Feature.RELAXED_DECORATORS,
867                 black.get_features_used(node),
868                 msg=(
869                     f"decorator '{decorator}' uses python3.9+ syntax"
870                     "but is detected as python<=3.8"
871                     # f"The full node is\n{node!r}"
872                 ),
873             )
874
875     def test_get_features_used(self) -> None:
876         node = black.lib2to3_parse("def f(*, arg): ...\n")
877         self.assertEqual(black.get_features_used(node), set())
878         node = black.lib2to3_parse("def f(*, arg,): ...\n")
879         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
880         node = black.lib2to3_parse("f(*arg,)\n")
881         self.assertEqual(
882             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
883         )
884         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
885         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
886         node = black.lib2to3_parse("123_456\n")
887         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
888         node = black.lib2to3_parse("123456\n")
889         self.assertEqual(black.get_features_used(node), set())
890         source, expected = read_data("function")
891         node = black.lib2to3_parse(source)
892         expected_features = {
893             Feature.TRAILING_COMMA_IN_CALL,
894             Feature.TRAILING_COMMA_IN_DEF,
895             Feature.F_STRINGS,
896         }
897         self.assertEqual(black.get_features_used(node), expected_features)
898         node = black.lib2to3_parse(expected)
899         self.assertEqual(black.get_features_used(node), expected_features)
900         source, expected = read_data("expression")
901         node = black.lib2to3_parse(source)
902         self.assertEqual(black.get_features_used(node), set())
903         node = black.lib2to3_parse(expected)
904         self.assertEqual(black.get_features_used(node), set())
905
906     def test_get_future_imports(self) -> None:
907         node = black.lib2to3_parse("\n")
908         self.assertEqual(set(), black.get_future_imports(node))
909         node = black.lib2to3_parse("from __future__ import black\n")
910         self.assertEqual({"black"}, black.get_future_imports(node))
911         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
912         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
913         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
914         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
915         node = black.lib2to3_parse(
916             "from __future__ import multiple\nfrom __future__ import imports\n"
917         )
918         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
919         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
920         self.assertEqual({"black"}, black.get_future_imports(node))
921         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
922         self.assertEqual({"black"}, black.get_future_imports(node))
923         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
924         self.assertEqual(set(), black.get_future_imports(node))
925         node = black.lib2to3_parse("from some.module import black\n")
926         self.assertEqual(set(), black.get_future_imports(node))
927         node = black.lib2to3_parse(
928             "from __future__ import unicode_literals as _unicode_literals"
929         )
930         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
931         node = black.lib2to3_parse(
932             "from __future__ import unicode_literals as _lol, print"
933         )
934         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
935
936     def test_debug_visitor(self) -> None:
937         source, _ = read_data("debug_visitor.py")
938         expected, _ = read_data("debug_visitor.out")
939         out_lines = []
940         err_lines = []
941
942         def out(msg: str, **kwargs: Any) -> None:
943             out_lines.append(msg)
944
945         def err(msg: str, **kwargs: Any) -> None:
946             err_lines.append(msg)
947
948         with patch("black.out", out), patch("black.err", err):
949             black.DebugVisitor.show(source)
950         actual = "\n".join(out_lines) + "\n"
951         log_name = ""
952         if expected != actual:
953             log_name = black.dump_to_file(*out_lines)
954         self.assertEqual(
955             expected,
956             actual,
957             f"AST print out is different. Actual version dumped to {log_name}",
958         )
959
960     def test_format_file_contents(self) -> None:
961         empty = ""
962         mode = DEFAULT_MODE
963         with self.assertRaises(black.NothingChanged):
964             black.format_file_contents(empty, mode=mode, fast=False)
965         just_nl = "\n"
966         with self.assertRaises(black.NothingChanged):
967             black.format_file_contents(just_nl, mode=mode, fast=False)
968         same = "j = [1, 2, 3]\n"
969         with self.assertRaises(black.NothingChanged):
970             black.format_file_contents(same, mode=mode, fast=False)
971         different = "j = [1,2,3]"
972         expected = same
973         actual = black.format_file_contents(different, mode=mode, fast=False)
974         self.assertEqual(expected, actual)
975         invalid = "return if you can"
976         with self.assertRaises(black.InvalidInput) as e:
977             black.format_file_contents(invalid, mode=mode, fast=False)
978         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
979
980     def test_endmarker(self) -> None:
981         n = black.lib2to3_parse("\n")
982         self.assertEqual(n.type, black.syms.file_input)
983         self.assertEqual(len(n.children), 1)
984         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
985
986     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
987     def test_assertFormatEqual(self) -> None:
988         out_lines = []
989         err_lines = []
990
991         def out(msg: str, **kwargs: Any) -> None:
992             out_lines.append(msg)
993
994         def err(msg: str, **kwargs: Any) -> None:
995             err_lines.append(msg)
996
997         with patch("black.out", out), patch("black.err", err):
998             with self.assertRaises(AssertionError):
999                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1000
1001         out_str = "".join(out_lines)
1002         self.assertTrue("Expected tree:" in out_str)
1003         self.assertTrue("Actual tree:" in out_str)
1004         self.assertEqual("".join(err_lines), "")
1005
1006     def test_cache_broken_file(self) -> None:
1007         mode = DEFAULT_MODE
1008         with cache_dir() as workspace:
1009             cache_file = black.get_cache_file(mode)
1010             with cache_file.open("w") as fobj:
1011                 fobj.write("this is not a pickle")
1012             self.assertEqual(black.read_cache(mode), {})
1013             src = (workspace / "test.py").resolve()
1014             with src.open("w") as fobj:
1015                 fobj.write("print('hello')")
1016             self.invokeBlack([str(src)])
1017             cache = black.read_cache(mode)
1018             self.assertIn(str(src), cache)
1019
1020     def test_cache_single_file_already_cached(self) -> None:
1021         mode = DEFAULT_MODE
1022         with cache_dir() as workspace:
1023             src = (workspace / "test.py").resolve()
1024             with src.open("w") as fobj:
1025                 fobj.write("print('hello')")
1026             black.write_cache({}, [src], mode)
1027             self.invokeBlack([str(src)])
1028             with src.open("r") as fobj:
1029                 self.assertEqual(fobj.read(), "print('hello')")
1030
1031     @event_loop()
1032     def test_cache_multiple_files(self) -> None:
1033         mode = DEFAULT_MODE
1034         with cache_dir() as workspace, patch(
1035             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1036         ):
1037             one = (workspace / "one.py").resolve()
1038             with one.open("w") as fobj:
1039                 fobj.write("print('hello')")
1040             two = (workspace / "two.py").resolve()
1041             with two.open("w") as fobj:
1042                 fobj.write("print('hello')")
1043             black.write_cache({}, [one], mode)
1044             self.invokeBlack([str(workspace)])
1045             with one.open("r") as fobj:
1046                 self.assertEqual(fobj.read(), "print('hello')")
1047             with two.open("r") as fobj:
1048                 self.assertEqual(fobj.read(), 'print("hello")\n')
1049             cache = black.read_cache(mode)
1050             self.assertIn(str(one), cache)
1051             self.assertIn(str(two), cache)
1052
1053     def test_no_cache_when_writeback_diff(self) -> None:
1054         mode = DEFAULT_MODE
1055         with cache_dir() as workspace:
1056             src = (workspace / "test.py").resolve()
1057             with src.open("w") as fobj:
1058                 fobj.write("print('hello')")
1059             with patch("black.read_cache") as read_cache, patch(
1060                 "black.write_cache"
1061             ) as write_cache:
1062                 self.invokeBlack([str(src), "--diff"])
1063                 cache_file = black.get_cache_file(mode)
1064                 self.assertFalse(cache_file.exists())
1065                 write_cache.assert_not_called()
1066                 read_cache.assert_not_called()
1067
1068     def test_no_cache_when_writeback_color_diff(self) -> None:
1069         mode = DEFAULT_MODE
1070         with cache_dir() as workspace:
1071             src = (workspace / "test.py").resolve()
1072             with src.open("w") as fobj:
1073                 fobj.write("print('hello')")
1074             with patch("black.read_cache") as read_cache, patch(
1075                 "black.write_cache"
1076             ) as write_cache:
1077                 self.invokeBlack([str(src), "--diff", "--color"])
1078                 cache_file = black.get_cache_file(mode)
1079                 self.assertFalse(cache_file.exists())
1080                 write_cache.assert_not_called()
1081                 read_cache.assert_not_called()
1082
1083     @event_loop()
1084     def test_output_locking_when_writeback_diff(self) -> None:
1085         with cache_dir() as workspace:
1086             for tag in range(0, 4):
1087                 src = (workspace / f"test{tag}.py").resolve()
1088                 with src.open("w") as fobj:
1089                     fobj.write("print('hello')")
1090             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1091                 self.invokeBlack(["--diff", str(workspace)], exit_code=0)
1092                 # this isn't quite doing what we want, but if it _isn't_
1093                 # called then we cannot be using the lock it provides
1094                 mgr.assert_called()
1095
1096     @event_loop()
1097     def test_output_locking_when_writeback_color_diff(self) -> None:
1098         with cache_dir() as workspace:
1099             for tag in range(0, 4):
1100                 src = (workspace / f"test{tag}.py").resolve()
1101                 with src.open("w") as fobj:
1102                     fobj.write("print('hello')")
1103             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1104                 self.invokeBlack(["--diff", "--color", str(workspace)], exit_code=0)
1105                 # this isn't quite doing what we want, but if it _isn't_
1106                 # called then we cannot be using the lock it provides
1107                 mgr.assert_called()
1108
1109     def test_no_cache_when_stdin(self) -> None:
1110         mode = DEFAULT_MODE
1111         with cache_dir():
1112             result = CliRunner().invoke(
1113                 black.main, ["-"], input=BytesIO(b"print('hello')")
1114             )
1115             self.assertEqual(result.exit_code, 0)
1116             cache_file = black.get_cache_file(mode)
1117             self.assertFalse(cache_file.exists())
1118
1119     def test_read_cache_no_cachefile(self) -> None:
1120         mode = DEFAULT_MODE
1121         with cache_dir():
1122             self.assertEqual(black.read_cache(mode), {})
1123
1124     def test_write_cache_read_cache(self) -> None:
1125         mode = DEFAULT_MODE
1126         with cache_dir() as workspace:
1127             src = (workspace / "test.py").resolve()
1128             src.touch()
1129             black.write_cache({}, [src], mode)
1130             cache = black.read_cache(mode)
1131             self.assertIn(str(src), cache)
1132             self.assertEqual(cache[str(src)], black.get_cache_info(src))
1133
1134     def test_filter_cached(self) -> None:
1135         with TemporaryDirectory() as workspace:
1136             path = Path(workspace)
1137             uncached = (path / "uncached").resolve()
1138             cached = (path / "cached").resolve()
1139             cached_but_changed = (path / "changed").resolve()
1140             uncached.touch()
1141             cached.touch()
1142             cached_but_changed.touch()
1143             cache = {
1144                 str(cached): black.get_cache_info(cached),
1145                 str(cached_but_changed): (0.0, 0),
1146             }
1147             todo, done = black.filter_cached(
1148                 cache, {uncached, cached, cached_but_changed}
1149             )
1150             self.assertEqual(todo, {uncached, cached_but_changed})
1151             self.assertEqual(done, {cached})
1152
1153     def test_write_cache_creates_directory_if_needed(self) -> None:
1154         mode = DEFAULT_MODE
1155         with cache_dir(exists=False) as workspace:
1156             self.assertFalse(workspace.exists())
1157             black.write_cache({}, [], mode)
1158             self.assertTrue(workspace.exists())
1159
1160     @event_loop()
1161     def test_failed_formatting_does_not_get_cached(self) -> None:
1162         mode = DEFAULT_MODE
1163         with cache_dir() as workspace, patch(
1164             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1165         ):
1166             failing = (workspace / "failing.py").resolve()
1167             with failing.open("w") as fobj:
1168                 fobj.write("not actually python")
1169             clean = (workspace / "clean.py").resolve()
1170             with clean.open("w") as fobj:
1171                 fobj.write('print("hello")\n')
1172             self.invokeBlack([str(workspace)], exit_code=123)
1173             cache = black.read_cache(mode)
1174             self.assertNotIn(str(failing), cache)
1175             self.assertIn(str(clean), cache)
1176
1177     def test_write_cache_write_fail(self) -> None:
1178         mode = DEFAULT_MODE
1179         with cache_dir(), patch.object(Path, "open") as mock:
1180             mock.side_effect = OSError
1181             black.write_cache({}, [], mode)
1182
1183     @event_loop()
1184     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
1185     def test_works_in_mono_process_only_environment(self) -> None:
1186         with cache_dir() as workspace:
1187             for f in [
1188                 (workspace / "one.py").resolve(),
1189                 (workspace / "two.py").resolve(),
1190             ]:
1191                 f.write_text('print("hello")\n')
1192             self.invokeBlack([str(workspace)])
1193
1194     @event_loop()
1195     def test_check_diff_use_together(self) -> None:
1196         with cache_dir():
1197             # Files which will be reformatted.
1198             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1199             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1200             # Files which will not be reformatted.
1201             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1202             self.invokeBlack([str(src2), "--diff", "--check"])
1203             # Multi file command.
1204             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1205
1206     def test_no_files(self) -> None:
1207         with cache_dir():
1208             # Without an argument, black exits with error code 0.
1209             self.invokeBlack([])
1210
1211     def test_broken_symlink(self) -> None:
1212         with cache_dir() as workspace:
1213             symlink = workspace / "broken_link.py"
1214             try:
1215                 symlink.symlink_to("nonexistent.py")
1216             except OSError as e:
1217                 self.skipTest(f"Can't create symlinks: {e}")
1218             self.invokeBlack([str(workspace.resolve())])
1219
1220     def test_read_cache_line_lengths(self) -> None:
1221         mode = DEFAULT_MODE
1222         short_mode = replace(DEFAULT_MODE, line_length=1)
1223         with cache_dir() as workspace:
1224             path = (workspace / "file.py").resolve()
1225             path.touch()
1226             black.write_cache({}, [path], mode)
1227             one = black.read_cache(mode)
1228             self.assertIn(str(path), one)
1229             two = black.read_cache(short_mode)
1230             self.assertNotIn(str(path), two)
1231
1232     def test_single_file_force_pyi(self) -> None:
1233         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1234         contents, expected = read_data("force_pyi")
1235         with cache_dir() as workspace:
1236             path = (workspace / "file.py").resolve()
1237             with open(path, "w") as fh:
1238                 fh.write(contents)
1239             self.invokeBlack([str(path), "--pyi"])
1240             with open(path, "r") as fh:
1241                 actual = fh.read()
1242             # verify cache with --pyi is separate
1243             pyi_cache = black.read_cache(pyi_mode)
1244             self.assertIn(str(path), pyi_cache)
1245             normal_cache = black.read_cache(DEFAULT_MODE)
1246             self.assertNotIn(str(path), normal_cache)
1247         self.assertFormatEqual(expected, actual)
1248         black.assert_equivalent(contents, actual)
1249         black.assert_stable(contents, actual, pyi_mode)
1250
1251     @event_loop()
1252     def test_multi_file_force_pyi(self) -> None:
1253         reg_mode = DEFAULT_MODE
1254         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1255         contents, expected = read_data("force_pyi")
1256         with cache_dir() as workspace:
1257             paths = [
1258                 (workspace / "file1.py").resolve(),
1259                 (workspace / "file2.py").resolve(),
1260             ]
1261             for path in paths:
1262                 with open(path, "w") as fh:
1263                     fh.write(contents)
1264             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1265             for path in paths:
1266                 with open(path, "r") as fh:
1267                     actual = fh.read()
1268                 self.assertEqual(actual, expected)
1269             # verify cache with --pyi is separate
1270             pyi_cache = black.read_cache(pyi_mode)
1271             normal_cache = black.read_cache(reg_mode)
1272             for path in paths:
1273                 self.assertIn(str(path), pyi_cache)
1274                 self.assertNotIn(str(path), normal_cache)
1275
1276     def test_pipe_force_pyi(self) -> None:
1277         source, expected = read_data("force_pyi")
1278         result = CliRunner().invoke(
1279             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1280         )
1281         self.assertEqual(result.exit_code, 0)
1282         actual = result.output
1283         self.assertFormatEqual(actual, expected)
1284
1285     def test_single_file_force_py36(self) -> None:
1286         reg_mode = DEFAULT_MODE
1287         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1288         source, expected = read_data("force_py36")
1289         with cache_dir() as workspace:
1290             path = (workspace / "file.py").resolve()
1291             with open(path, "w") as fh:
1292                 fh.write(source)
1293             self.invokeBlack([str(path), *PY36_ARGS])
1294             with open(path, "r") as fh:
1295                 actual = fh.read()
1296             # verify cache with --target-version is separate
1297             py36_cache = black.read_cache(py36_mode)
1298             self.assertIn(str(path), py36_cache)
1299             normal_cache = black.read_cache(reg_mode)
1300             self.assertNotIn(str(path), normal_cache)
1301         self.assertEqual(actual, expected)
1302
1303     @event_loop()
1304     def test_multi_file_force_py36(self) -> None:
1305         reg_mode = DEFAULT_MODE
1306         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1307         source, expected = read_data("force_py36")
1308         with cache_dir() as workspace:
1309             paths = [
1310                 (workspace / "file1.py").resolve(),
1311                 (workspace / "file2.py").resolve(),
1312             ]
1313             for path in paths:
1314                 with open(path, "w") as fh:
1315                     fh.write(source)
1316             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1317             for path in paths:
1318                 with open(path, "r") as fh:
1319                     actual = fh.read()
1320                 self.assertEqual(actual, expected)
1321             # verify cache with --target-version is separate
1322             pyi_cache = black.read_cache(py36_mode)
1323             normal_cache = black.read_cache(reg_mode)
1324             for path in paths:
1325                 self.assertIn(str(path), pyi_cache)
1326                 self.assertNotIn(str(path), normal_cache)
1327
1328     def test_pipe_force_py36(self) -> None:
1329         source, expected = read_data("force_py36")
1330         result = CliRunner().invoke(
1331             black.main,
1332             ["-", "-q", "--target-version=py36"],
1333             input=BytesIO(source.encode("utf8")),
1334         )
1335         self.assertEqual(result.exit_code, 0)
1336         actual = result.output
1337         self.assertFormatEqual(actual, expected)
1338
1339     def test_include_exclude(self) -> None:
1340         path = THIS_DIR / "data" / "include_exclude_tests"
1341         include = re.compile(r"\.pyi?$")
1342         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1343         report = black.Report()
1344         gitignore = PathSpec.from_lines("gitwildmatch", [])
1345         sources: List[Path] = []
1346         expected = [
1347             Path(path / "b/dont_exclude/a.py"),
1348             Path(path / "b/dont_exclude/a.pyi"),
1349         ]
1350         this_abs = THIS_DIR.resolve()
1351         sources.extend(
1352             black.gen_python_files(
1353                 path.iterdir(),
1354                 this_abs,
1355                 include,
1356                 exclude,
1357                 None,
1358                 None,
1359                 report,
1360                 gitignore,
1361             )
1362         )
1363         self.assertEqual(sorted(expected), sorted(sources))
1364
1365     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1366     def test_exclude_for_issue_1572(self) -> None:
1367         # Exclude shouldn't touch files that were explicitly given to Black through the
1368         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
1369         # https://github.com/psf/black/issues/1572
1370         path = THIS_DIR / "data" / "include_exclude_tests"
1371         include = ""
1372         exclude = r"/exclude/|a\.py"
1373         src = str(path / "b/exclude/a.py")
1374         report = black.Report()
1375         expected = [Path(path / "b/exclude/a.py")]
1376         sources = list(
1377             black.get_sources(
1378                 ctx=FakeContext(),
1379                 src=(src,),
1380                 quiet=True,
1381                 verbose=False,
1382                 include=re.compile(include),
1383                 exclude=re.compile(exclude),
1384                 extend_exclude=None,
1385                 force_exclude=None,
1386                 report=report,
1387                 stdin_filename=None,
1388             )
1389         )
1390         self.assertEqual(sorted(expected), sorted(sources))
1391
1392     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1393     def test_get_sources_with_stdin(self) -> None:
1394         include = ""
1395         exclude = r"/exclude/|a\.py"
1396         src = "-"
1397         report = black.Report()
1398         expected = [Path("-")]
1399         sources = list(
1400             black.get_sources(
1401                 ctx=FakeContext(),
1402                 src=(src,),
1403                 quiet=True,
1404                 verbose=False,
1405                 include=re.compile(include),
1406                 exclude=re.compile(exclude),
1407                 extend_exclude=None,
1408                 force_exclude=None,
1409                 report=report,
1410                 stdin_filename=None,
1411             )
1412         )
1413         self.assertEqual(sorted(expected), sorted(sources))
1414
1415     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1416     def test_get_sources_with_stdin_filename(self) -> None:
1417         include = ""
1418         exclude = r"/exclude/|a\.py"
1419         src = "-"
1420         report = black.Report()
1421         stdin_filename = str(THIS_DIR / "data/collections.py")
1422         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1423         sources = list(
1424             black.get_sources(
1425                 ctx=FakeContext(),
1426                 src=(src,),
1427                 quiet=True,
1428                 verbose=False,
1429                 include=re.compile(include),
1430                 exclude=re.compile(exclude),
1431                 extend_exclude=None,
1432                 force_exclude=None,
1433                 report=report,
1434                 stdin_filename=stdin_filename,
1435             )
1436         )
1437         self.assertEqual(sorted(expected), sorted(sources))
1438
1439     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1440     def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
1441         # Exclude shouldn't exclude stdin_filename since it is mimicing the
1442         # file being passed directly. This is the same as
1443         # test_exclude_for_issue_1572
1444         path = THIS_DIR / "data" / "include_exclude_tests"
1445         include = ""
1446         exclude = r"/exclude/|a\.py"
1447         src = "-"
1448         report = black.Report()
1449         stdin_filename = str(path / "b/exclude/a.py")
1450         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1451         sources = list(
1452             black.get_sources(
1453                 ctx=FakeContext(),
1454                 src=(src,),
1455                 quiet=True,
1456                 verbose=False,
1457                 include=re.compile(include),
1458                 exclude=re.compile(exclude),
1459                 extend_exclude=None,
1460                 force_exclude=None,
1461                 report=report,
1462                 stdin_filename=stdin_filename,
1463             )
1464         )
1465         self.assertEqual(sorted(expected), sorted(sources))
1466
1467     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1468     def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
1469         # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
1470         # file being passed directly. This is the same as
1471         # test_exclude_for_issue_1572
1472         path = THIS_DIR / "data" / "include_exclude_tests"
1473         include = ""
1474         extend_exclude = r"/exclude/|a\.py"
1475         src = "-"
1476         report = black.Report()
1477         stdin_filename = str(path / "b/exclude/a.py")
1478         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1479         sources = list(
1480             black.get_sources(
1481                 ctx=FakeContext(),
1482                 src=(src,),
1483                 quiet=True,
1484                 verbose=False,
1485                 include=re.compile(include),
1486                 exclude=re.compile(""),
1487                 extend_exclude=re.compile(extend_exclude),
1488                 force_exclude=None,
1489                 report=report,
1490                 stdin_filename=stdin_filename,
1491             )
1492         )
1493         self.assertEqual(sorted(expected), sorted(sources))
1494
1495     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1496     def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
1497         # Force exclude should exclude the file when passing it through
1498         # stdin_filename
1499         path = THIS_DIR / "data" / "include_exclude_tests"
1500         include = ""
1501         force_exclude = r"/exclude/|a\.py"
1502         src = "-"
1503         report = black.Report()
1504         stdin_filename = str(path / "b/exclude/a.py")
1505         sources = list(
1506             black.get_sources(
1507                 ctx=FakeContext(),
1508                 src=(src,),
1509                 quiet=True,
1510                 verbose=False,
1511                 include=re.compile(include),
1512                 exclude=re.compile(""),
1513                 extend_exclude=None,
1514                 force_exclude=re.compile(force_exclude),
1515                 report=report,
1516                 stdin_filename=stdin_filename,
1517             )
1518         )
1519         self.assertEqual([], sorted(sources))
1520
1521     def test_reformat_one_with_stdin(self) -> None:
1522         with patch(
1523             "black.format_stdin_to_stdout",
1524             return_value=lambda *args, **kwargs: black.Changed.YES,
1525         ) as fsts:
1526             report = MagicMock()
1527             path = Path("-")
1528             black.reformat_one(
1529                 path,
1530                 fast=True,
1531                 write_back=black.WriteBack.YES,
1532                 mode=DEFAULT_MODE,
1533                 report=report,
1534             )
1535             fsts.assert_called_once()
1536             report.done.assert_called_with(path, black.Changed.YES)
1537
1538     def test_reformat_one_with_stdin_filename(self) -> None:
1539         with patch(
1540             "black.format_stdin_to_stdout",
1541             return_value=lambda *args, **kwargs: black.Changed.YES,
1542         ) as fsts:
1543             report = MagicMock()
1544             p = "foo.py"
1545             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1546             expected = Path(p)
1547             black.reformat_one(
1548                 path,
1549                 fast=True,
1550                 write_back=black.WriteBack.YES,
1551                 mode=DEFAULT_MODE,
1552                 report=report,
1553             )
1554             fsts.assert_called_once()
1555             # __BLACK_STDIN_FILENAME__ should have been striped
1556             report.done.assert_called_with(expected, black.Changed.YES)
1557
1558     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1559         with patch(
1560             "black.format_stdin_to_stdout",
1561             return_value=lambda *args, **kwargs: black.Changed.YES,
1562         ) as fsts:
1563             report = MagicMock()
1564             # Even with an existing file, since we are forcing stdin, black
1565             # should output to stdout and not modify the file inplace
1566             p = Path(str(THIS_DIR / "data/collections.py"))
1567             # Make sure is_file actually returns True
1568             self.assertTrue(p.is_file())
1569             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1570             expected = Path(p)
1571             black.reformat_one(
1572                 path,
1573                 fast=True,
1574                 write_back=black.WriteBack.YES,
1575                 mode=DEFAULT_MODE,
1576                 report=report,
1577             )
1578             fsts.assert_called_once()
1579             # __BLACK_STDIN_FILENAME__ should have been striped
1580             report.done.assert_called_with(expected, black.Changed.YES)
1581
1582     def test_gitignore_exclude(self) -> None:
1583         path = THIS_DIR / "data" / "include_exclude_tests"
1584         include = re.compile(r"\.pyi?$")
1585         exclude = re.compile(r"")
1586         report = black.Report()
1587         gitignore = PathSpec.from_lines(
1588             "gitwildmatch", ["exclude/", ".definitely_exclude"]
1589         )
1590         sources: List[Path] = []
1591         expected = [
1592             Path(path / "b/dont_exclude/a.py"),
1593             Path(path / "b/dont_exclude/a.pyi"),
1594         ]
1595         this_abs = THIS_DIR.resolve()
1596         sources.extend(
1597             black.gen_python_files(
1598                 path.iterdir(),
1599                 this_abs,
1600                 include,
1601                 exclude,
1602                 None,
1603                 None,
1604                 report,
1605                 gitignore,
1606             )
1607         )
1608         self.assertEqual(sorted(expected), sorted(sources))
1609
1610     def test_empty_include(self) -> None:
1611         path = THIS_DIR / "data" / "include_exclude_tests"
1612         report = black.Report()
1613         gitignore = PathSpec.from_lines("gitwildmatch", [])
1614         empty = re.compile(r"")
1615         sources: List[Path] = []
1616         expected = [
1617             Path(path / "b/exclude/a.pie"),
1618             Path(path / "b/exclude/a.py"),
1619             Path(path / "b/exclude/a.pyi"),
1620             Path(path / "b/dont_exclude/a.pie"),
1621             Path(path / "b/dont_exclude/a.py"),
1622             Path(path / "b/dont_exclude/a.pyi"),
1623             Path(path / "b/.definitely_exclude/a.pie"),
1624             Path(path / "b/.definitely_exclude/a.py"),
1625             Path(path / "b/.definitely_exclude/a.pyi"),
1626         ]
1627         this_abs = THIS_DIR.resolve()
1628         sources.extend(
1629             black.gen_python_files(
1630                 path.iterdir(),
1631                 this_abs,
1632                 empty,
1633                 re.compile(black.DEFAULT_EXCLUDES),
1634                 None,
1635                 None,
1636                 report,
1637                 gitignore,
1638             )
1639         )
1640         self.assertEqual(sorted(expected), sorted(sources))
1641
1642     def test_extend_exclude(self) -> None:
1643         path = THIS_DIR / "data" / "include_exclude_tests"
1644         report = black.Report()
1645         gitignore = PathSpec.from_lines("gitwildmatch", [])
1646         sources: List[Path] = []
1647         expected = [
1648             Path(path / "b/exclude/a.py"),
1649             Path(path / "b/dont_exclude/a.py"),
1650         ]
1651         this_abs = THIS_DIR.resolve()
1652         sources.extend(
1653             black.gen_python_files(
1654                 path.iterdir(),
1655                 this_abs,
1656                 re.compile(black.DEFAULT_INCLUDES),
1657                 re.compile(r"\.pyi$"),
1658                 re.compile(r"\.definitely_exclude"),
1659                 None,
1660                 report,
1661                 gitignore,
1662             )
1663         )
1664         self.assertEqual(sorted(expected), sorted(sources))
1665
1666     def test_invalid_cli_regex(self) -> None:
1667         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1668             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1669
1670     def test_preserves_line_endings(self) -> None:
1671         with TemporaryDirectory() as workspace:
1672             test_file = Path(workspace) / "test.py"
1673             for nl in ["\n", "\r\n"]:
1674                 contents = nl.join(["def f(  ):", "    pass"])
1675                 test_file.write_bytes(contents.encode())
1676                 ff(test_file, write_back=black.WriteBack.YES)
1677                 updated_contents: bytes = test_file.read_bytes()
1678                 self.assertIn(nl.encode(), updated_contents)
1679                 if nl == "\n":
1680                     self.assertNotIn(b"\r\n", updated_contents)
1681
1682     def test_preserves_line_endings_via_stdin(self) -> None:
1683         for nl in ["\n", "\r\n"]:
1684             contents = nl.join(["def f(  ):", "    pass"])
1685             runner = BlackRunner()
1686             result = runner.invoke(
1687                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1688             )
1689             self.assertEqual(result.exit_code, 0)
1690             output = runner.stdout_bytes
1691             self.assertIn(nl.encode("utf8"), output)
1692             if nl == "\n":
1693                 self.assertNotIn(b"\r\n", output)
1694
1695     def test_assert_equivalent_different_asts(self) -> None:
1696         with self.assertRaises(AssertionError):
1697             black.assert_equivalent("{}", "None")
1698
1699     def test_symlink_out_of_root_directory(self) -> None:
1700         path = MagicMock()
1701         root = THIS_DIR.resolve()
1702         child = MagicMock()
1703         include = re.compile(black.DEFAULT_INCLUDES)
1704         exclude = re.compile(black.DEFAULT_EXCLUDES)
1705         report = black.Report()
1706         gitignore = PathSpec.from_lines("gitwildmatch", [])
1707         # `child` should behave like a symlink which resolved path is clearly
1708         # outside of the `root` directory.
1709         path.iterdir.return_value = [child]
1710         child.resolve.return_value = Path("/a/b/c")
1711         child.as_posix.return_value = "/a/b/c"
1712         child.is_symlink.return_value = True
1713         try:
1714             list(
1715                 black.gen_python_files(
1716                     path.iterdir(),
1717                     root,
1718                     include,
1719                     exclude,
1720                     None,
1721                     None,
1722                     report,
1723                     gitignore,
1724                 )
1725             )
1726         except ValueError as ve:
1727             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1728         path.iterdir.assert_called_once()
1729         child.resolve.assert_called_once()
1730         child.is_symlink.assert_called_once()
1731         # `child` should behave like a strange file which resolved path is clearly
1732         # outside of the `root` directory.
1733         child.is_symlink.return_value = False
1734         with self.assertRaises(ValueError):
1735             list(
1736                 black.gen_python_files(
1737                     path.iterdir(),
1738                     root,
1739                     include,
1740                     exclude,
1741                     None,
1742                     None,
1743                     report,
1744                     gitignore,
1745                 )
1746             )
1747         path.iterdir.assert_called()
1748         self.assertEqual(path.iterdir.call_count, 2)
1749         child.resolve.assert_called()
1750         self.assertEqual(child.resolve.call_count, 2)
1751         child.is_symlink.assert_called()
1752         self.assertEqual(child.is_symlink.call_count, 2)
1753
1754     def test_shhh_click(self) -> None:
1755         try:
1756             from click import _unicodefun  # type: ignore
1757         except ModuleNotFoundError:
1758             self.skipTest("Incompatible Click version")
1759         if not hasattr(_unicodefun, "_verify_python3_env"):
1760             self.skipTest("Incompatible Click version")
1761         # First, let's see if Click is crashing with a preferred ASCII charset.
1762         with patch("locale.getpreferredencoding") as gpe:
1763             gpe.return_value = "ASCII"
1764             with self.assertRaises(RuntimeError):
1765                 _unicodefun._verify_python3_env()
1766         # Now, let's silence Click...
1767         black.patch_click()
1768         # ...and confirm it's silent.
1769         with patch("locale.getpreferredencoding") as gpe:
1770             gpe.return_value = "ASCII"
1771             try:
1772                 _unicodefun._verify_python3_env()
1773             except RuntimeError as re:
1774                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1775
1776     def test_root_logger_not_used_directly(self) -> None:
1777         def fail(*args: Any, **kwargs: Any) -> None:
1778             self.fail("Record created with root logger")
1779
1780         with patch.multiple(
1781             logging.root,
1782             debug=fail,
1783             info=fail,
1784             warning=fail,
1785             error=fail,
1786             critical=fail,
1787             log=fail,
1788         ):
1789             ff(THIS_FILE)
1790
1791     def test_invalid_config_return_code(self) -> None:
1792         tmp_file = Path(black.dump_to_file())
1793         try:
1794             tmp_config = Path(black.dump_to_file())
1795             tmp_config.unlink()
1796             args = ["--config", str(tmp_config), str(tmp_file)]
1797             self.invokeBlack(args, exit_code=2, ignore_config=False)
1798         finally:
1799             tmp_file.unlink()
1800
1801     def test_parse_pyproject_toml(self) -> None:
1802         test_toml_file = THIS_DIR / "test.toml"
1803         config = black.parse_pyproject_toml(str(test_toml_file))
1804         self.assertEqual(config["verbose"], 1)
1805         self.assertEqual(config["check"], "no")
1806         self.assertEqual(config["diff"], "y")
1807         self.assertEqual(config["color"], True)
1808         self.assertEqual(config["line_length"], 79)
1809         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1810         self.assertEqual(config["exclude"], r"\.pyi?$")
1811         self.assertEqual(config["include"], r"\.py?$")
1812
1813     def test_read_pyproject_toml(self) -> None:
1814         test_toml_file = THIS_DIR / "test.toml"
1815         fake_ctx = FakeContext()
1816         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1817         config = fake_ctx.default_map
1818         self.assertEqual(config["verbose"], "1")
1819         self.assertEqual(config["check"], "no")
1820         self.assertEqual(config["diff"], "y")
1821         self.assertEqual(config["color"], "True")
1822         self.assertEqual(config["line_length"], "79")
1823         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1824         self.assertEqual(config["exclude"], r"\.pyi?$")
1825         self.assertEqual(config["include"], r"\.py?$")
1826
1827     def test_find_project_root(self) -> None:
1828         with TemporaryDirectory() as workspace:
1829             root = Path(workspace)
1830             test_dir = root / "test"
1831             test_dir.mkdir()
1832
1833             src_dir = root / "src"
1834             src_dir.mkdir()
1835
1836             root_pyproject = root / "pyproject.toml"
1837             root_pyproject.touch()
1838             src_pyproject = src_dir / "pyproject.toml"
1839             src_pyproject.touch()
1840             src_python = src_dir / "foo.py"
1841             src_python.touch()
1842
1843             self.assertEqual(
1844                 black.find_project_root((src_dir, test_dir)), root.resolve()
1845             )
1846             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
1847             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
1848
1849     @patch("black.find_user_pyproject_toml", black.find_user_pyproject_toml.__wrapped__)
1850     def test_find_user_pyproject_toml_linux(self) -> None:
1851         if system() == "Windows":
1852             return
1853
1854         # Test if XDG_CONFIG_HOME is checked
1855         with TemporaryDirectory() as workspace:
1856             tmp_user_config = Path(workspace) / "black"
1857             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1858                 self.assertEqual(
1859                     black.find_user_pyproject_toml(), tmp_user_config.resolve()
1860                 )
1861
1862         # Test fallback for XDG_CONFIG_HOME
1863         with patch.dict("os.environ"):
1864             os.environ.pop("XDG_CONFIG_HOME", None)
1865             fallback_user_config = Path("~/.config").expanduser() / "black"
1866             self.assertEqual(
1867                 black.find_user_pyproject_toml(), fallback_user_config.resolve()
1868             )
1869
1870     def test_find_user_pyproject_toml_windows(self) -> None:
1871         if system() != "Windows":
1872             return
1873
1874         user_config_path = Path.home() / ".black"
1875         self.assertEqual(black.find_user_pyproject_toml(), user_config_path.resolve())
1876
1877     def test_bpo_33660_workaround(self) -> None:
1878         if system() == "Windows":
1879             return
1880
1881         # https://bugs.python.org/issue33660
1882
1883         old_cwd = Path.cwd()
1884         try:
1885             root = Path("/")
1886             os.chdir(str(root))
1887             path = Path("workspace") / "project"
1888             report = black.Report(verbose=True)
1889             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
1890             self.assertEqual(normalized_path, "workspace/project")
1891         finally:
1892             os.chdir(str(old_cwd))
1893
1894     def test_newline_comment_interaction(self) -> None:
1895         source = "class A:\\\r\n# type: ignore\n pass\n"
1896         output = black.format_str(source, mode=DEFAULT_MODE)
1897         black.assert_stable(source, output, mode=DEFAULT_MODE)
1898
1899     def test_bpo_2142_workaround(self) -> None:
1900
1901         # https://bugs.python.org/issue2142
1902
1903         source, _ = read_data("missing_final_newline.py")
1904         # read_data adds a trailing newline
1905         source = source.rstrip()
1906         expected, _ = read_data("missing_final_newline.diff")
1907         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
1908         diff_header = re.compile(
1909             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
1910             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
1911         )
1912         try:
1913             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
1914             self.assertEqual(result.exit_code, 0)
1915         finally:
1916             os.unlink(tmp_file)
1917         actual = result.output
1918         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
1919         self.assertEqual(actual, expected)
1920
1921     def test_docstring_reformat_for_py27(self) -> None:
1922         """
1923         Check that stripping trailing whitespace from Python 2 docstrings
1924         doesn't trigger a "not equivalent to source" error
1925         """
1926         source = (
1927             b'def foo():\r\n    """Testing\r\n    Testing """\r\n    print "Foo"\r\n'
1928         )
1929         expected = 'def foo():\n    """Testing\n    Testing"""\n    print "Foo"\n'
1930
1931         result = CliRunner().invoke(
1932             black.main,
1933             ["-", "-q", "--target-version=py27"],
1934             input=BytesIO(source),
1935         )
1936
1937         self.assertEqual(result.exit_code, 0)
1938         actual = result.output
1939         self.assertFormatEqual(actual, expected)
1940
1941
# Keep the lines of black's own source file around for `tracefunc`.
with open(black.__file__, encoding="utf-8") as _bf:
    black_source_lines = list(_bf)
1944
1945
def tracefunc(frame: types.FrameType, event: str, arg: Any) -> Callable:
    """Show function calls `from black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    Only "call" events whose code object's file name contains
    `black/__init__.py` are printed, as `<lineno>:<signature line>` indented
    by call depth. Every other event or file is ignored. The function always
    returns itself.
    """
    if event != "call":
        return tracefunc

    filename = frame.f_code.co_filename
    if "black/__init__.py" not in filename:
        # Line numbers from other files must never index `black_source_lines`:
        # they can exceed its length and raise IndexError inside the hook.
        return tracefunc

    # Indentation grows with stack depth. NOTE(review): 19 looks like an
    # empirically chosen baseline depth for the test harness -- confirm.
    stack = (len(inspect.stack()) - 19) * 2
    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1
    funcname = black_source_lines[func_sig_lineno].strip()
    # Skip decorator lines until the line carrying the signature is reached.
    while funcname.startswith("@"):
        func_sig_lineno += 1
        funcname = black_source_lines[func_sig_lineno].strip()
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc
1966
1967
# Run the tests via unittest when this file is executed as a script.
if __name__ == "__main__":
    unittest.main(module="test_black")