git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Use optional tests for "no_python2" to simplify local testing (#2203)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import multiprocessing
3 import asyncio
4 import logging
5 from concurrent.futures import ThreadPoolExecutor
6 from contextlib import contextmanager
7 from dataclasses import replace
8 import inspect
9 from io import BytesIO, TextIOWrapper
10 import os
11 from pathlib import Path
12 from platform import system
13 import regex as re
14 import sys
15 from tempfile import TemporaryDirectory
16 import types
17 from typing import (
18     Any,
19     BinaryIO,
20     Callable,
21     Dict,
22     Generator,
23     List,
24     Iterator,
25     TypeVar,
26 )
27 import pytest
28 import unittest
29 from unittest.mock import patch, MagicMock
30
31 import click
32 from click import unstyle
33 from click.testing import CliRunner
34
35 import black
36 from black import Feature, TargetVersion
37
38 from pathspec import PathSpec
39
40 # Import other test classes
41 from tests.util import (
42     THIS_DIR,
43     read_data,
44     DETERMINISTIC_HEADER,
45     BlackBaseTestCase,
46     DEFAULT_MODE,
47     fs,
48     ff,
49     dump_to_stderr,
50 )
51 from .test_primer import PrimerCLITests  # noqa: F401
52
53
THIS_FILE = Path(__file__)
# Target versions 3.6 through 3.9, and the matching ``--target-version`` CLI flags.
PY36_VERSIONS = {
    TargetVersion.PY36,
    TargetVersion.PY37,
    TargetVersion.PY38,
    TargetVersion.PY39,
}
PY36_ARGS = [
    "--target-version=" + target.name.lower() for target in PY36_VERSIONS
]
# Module-level type variables.
T = TypeVar("T")
R = TypeVar("R")
64
65
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch ``black.CACHE_DIR`` to a path inside a throwaway temporary directory.

    With ``exists=False`` the patched path is a ``new`` subdirectory that is
    never created, so code under test sees a missing cache directory.
    Yields the patched path; the temporary directory is removed on exit.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
74
75
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a brand-new asyncio event loop for the duration of the block.

    The loop is created from the current event loop policy and is closed on
    exit, even if the body raises. It is not unset afterwards.
    """
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        fresh_loop.close()
86
87
class FakeContext(click.Context):
    """Stand-in click Context for calling functions that require one.

    The initializer deliberately skips ``click.Context.__init__`` and only
    provides an empty ``default_map``.
    """

    def __init__(self) -> None:
        # No super().__init__() on purpose: only default_map is needed.
        self.default_map: Dict[str, Any] = dict()
93
94
class FakeParameter(click.Parameter):
    """Stand-in click Parameter for calling functions that require one.

    Construction bypasses ``click.Parameter.__init__`` entirely, so the
    instance sets no attributes of its own.
    """

    def __init__(self) -> None:
        """Do nothing; the parent initializer is intentionally not called."""
100
101
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    This is a hack that can be removed once we depend on Click 7.x"""

    def __init__(self) -> None:
        # Byte sink that receives everything written to sys.stderr while
        # ``isolation`` is active.
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        # Snapshots taken when ``isolation`` exits; tests read these after
        # ``invoke`` returns.
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Run Click's isolation, but route ``sys.stderr`` to ``self.stderrbuf``.

        On exit, the captured stdout and stderr bytes are stored in
        ``self.stdout_bytes`` and ``self.stderr_bytes``, and the previous
        ``sys.stderr`` is restored.
        """
        with super().isolation(*args, **kwargs) as output:
            # Swap stderr before entering ``try``. If the swap itself failed,
            # the ``finally`` below would otherwise read the real stderr.
            hold_stderr = sys.stderr
            captured_stderr = TextIOWrapper(self.stderrbuf, encoding=self.charset)
            sys.stderr = captured_stderr
            try:
                yield output
            finally:
                # TextIOWrapper keeps pending text in its own buffer. Flush it so
                # everything written so far reaches the BytesIO before we read it.
                captured_stderr.flush()
                self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                # Read our own buffer directly, in case the code under test
                # replaced sys.stderr.
                self.stderr_bytes = self.stderrbuf.getvalue()
                sys.stderr = hold_stderr
125
126
127 class BlackTestCase(BlackBaseTestCase):
128     def invokeBlack(
129         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
130     ) -> None:
131         runner = BlackRunner()
132         if ignore_config:
133             args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
134         result = runner.invoke(black.main, args)
135         self.assertEqual(
136             result.exit_code,
137             exit_code,
138             msg=(
139                 f"Failed with args: {args}\n"
140                 f"stdout: {runner.stdout_bytes.decode()!r}\n"
141                 f"stderr: {runner.stderr_bytes.decode()!r}\n"
142                 f"exception: {result.exception}"
143             ),
144         )
145
146     @patch("black.dump_to_file", dump_to_stderr)
147     def test_empty(self) -> None:
148         source = expected = ""
149         actual = fs(source)
150         self.assertFormatEqual(expected, actual)
151         black.assert_equivalent(source, actual)
152         black.assert_stable(source, actual, DEFAULT_MODE)
153
154     def test_empty_ff(self) -> None:
155         expected = ""
156         tmp_file = Path(black.dump_to_file())
157         try:
158             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
159             with open(tmp_file, encoding="utf8") as f:
160                 actual = f.read()
161         finally:
162             os.unlink(tmp_file)
163         self.assertFormatEqual(expected, actual)
164
165     def test_piping(self) -> None:
166         source, expected = read_data("src/black/__init__", data=False)
167         result = BlackRunner().invoke(
168             black.main,
169             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
170             input=BytesIO(source.encode("utf8")),
171         )
172         self.assertEqual(result.exit_code, 0)
173         self.assertFormatEqual(expected, result.output)
174         black.assert_equivalent(source, result.output)
175         black.assert_stable(source, result.output, DEFAULT_MODE)
176
177     def test_piping_diff(self) -> None:
178         diff_header = re.compile(
179             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
180             r"\+\d\d\d\d"
181         )
182         source, _ = read_data("expression.py")
183         expected, _ = read_data("expression.diff")
184         config = THIS_DIR / "data" / "empty_pyproject.toml"
185         args = [
186             "-",
187             "--fast",
188             f"--line-length={black.DEFAULT_LINE_LENGTH}",
189             "--diff",
190             f"--config={config}",
191         ]
192         result = BlackRunner().invoke(
193             black.main, args, input=BytesIO(source.encode("utf8"))
194         )
195         self.assertEqual(result.exit_code, 0)
196         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
197         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
198         self.assertEqual(expected, actual)
199
200     def test_piping_diff_with_color(self) -> None:
201         source, _ = read_data("expression.py")
202         config = THIS_DIR / "data" / "empty_pyproject.toml"
203         args = [
204             "-",
205             "--fast",
206             f"--line-length={black.DEFAULT_LINE_LENGTH}",
207             "--diff",
208             "--color",
209             f"--config={config}",
210         ]
211         result = BlackRunner().invoke(
212             black.main, args, input=BytesIO(source.encode("utf8"))
213         )
214         actual = result.output
215         # Again, the contents are checked in a different test, so only look for colors.
216         self.assertIn("\033[1;37m", actual)
217         self.assertIn("\033[36m", actual)
218         self.assertIn("\033[32m", actual)
219         self.assertIn("\033[31m", actual)
220         self.assertIn("\033[0m", actual)
221
222     @patch("black.dump_to_file", dump_to_stderr)
223     def _test_wip(self) -> None:
224         source, expected = read_data("wip")
225         sys.settrace(tracefunc)
226         mode = replace(
227             DEFAULT_MODE,
228             experimental_string_processing=False,
229             target_versions={black.TargetVersion.PY38},
230         )
231         actual = fs(source, mode=mode)
232         sys.settrace(None)
233         self.assertFormatEqual(expected, actual)
234         black.assert_equivalent(source, actual)
235         black.assert_stable(source, actual, black.FileMode())
236
237     @unittest.expectedFailure
238     @patch("black.dump_to_file", dump_to_stderr)
239     def test_trailing_comma_optional_parens_stability1(self) -> None:
240         source, _expected = read_data("trailing_comma_optional_parens1")
241         actual = fs(source)
242         black.assert_stable(source, actual, DEFAULT_MODE)
243
244     @unittest.expectedFailure
245     @patch("black.dump_to_file", dump_to_stderr)
246     def test_trailing_comma_optional_parens_stability2(self) -> None:
247         source, _expected = read_data("trailing_comma_optional_parens2")
248         actual = fs(source)
249         black.assert_stable(source, actual, DEFAULT_MODE)
250
251     @unittest.expectedFailure
252     @patch("black.dump_to_file", dump_to_stderr)
253     def test_trailing_comma_optional_parens_stability3(self) -> None:
254         source, _expected = read_data("trailing_comma_optional_parens3")
255         actual = fs(source)
256         black.assert_stable(source, actual, DEFAULT_MODE)
257
258     @patch("black.dump_to_file", dump_to_stderr)
259     def test_trailing_comma_optional_parens_stability1_pass2(self) -> None:
260         source, _expected = read_data("trailing_comma_optional_parens1")
261         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
262         black.assert_stable(source, actual, DEFAULT_MODE)
263
264     @patch("black.dump_to_file", dump_to_stderr)
265     def test_trailing_comma_optional_parens_stability2_pass2(self) -> None:
266         source, _expected = read_data("trailing_comma_optional_parens2")
267         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
268         black.assert_stable(source, actual, DEFAULT_MODE)
269
270     @patch("black.dump_to_file", dump_to_stderr)
271     def test_trailing_comma_optional_parens_stability3_pass2(self) -> None:
272         source, _expected = read_data("trailing_comma_optional_parens3")
273         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
274         black.assert_stable(source, actual, DEFAULT_MODE)
275
276     @patch("black.dump_to_file", dump_to_stderr)
277     def test_pep_572(self) -> None:
278         source, expected = read_data("pep_572")
279         actual = fs(source)
280         self.assertFormatEqual(expected, actual)
281         black.assert_stable(source, actual, DEFAULT_MODE)
282         if sys.version_info >= (3, 8):
283             black.assert_equivalent(source, actual)
284
285     @patch("black.dump_to_file", dump_to_stderr)
286     def test_pep_572_remove_parens(self) -> None:
287         source, expected = read_data("pep_572_remove_parens")
288         actual = fs(source)
289         self.assertFormatEqual(expected, actual)
290         black.assert_stable(source, actual, DEFAULT_MODE)
291         if sys.version_info >= (3, 8):
292             black.assert_equivalent(source, actual)
293
294     @patch("black.dump_to_file", dump_to_stderr)
295     def test_pep_572_do_not_remove_parens(self) -> None:
296         source, expected = read_data("pep_572_do_not_remove_parens")
297         # the AST safety checks will fail, but that's expected, just make sure no
298         # parentheses are touched
299         actual = black.format_str(source, mode=DEFAULT_MODE)
300         self.assertFormatEqual(expected, actual)
301
302     def test_pep_572_version_detection(self) -> None:
303         source, _ = read_data("pep_572")
304         root = black.lib2to3_parse(source)
305         features = black.get_features_used(root)
306         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
307         versions = black.detect_target_versions(root)
308         self.assertIn(black.TargetVersion.PY38, versions)
309
310     def test_expression_ff(self) -> None:
311         source, expected = read_data("expression")
312         tmp_file = Path(black.dump_to_file(source))
313         try:
314             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
315             with open(tmp_file, encoding="utf8") as f:
316                 actual = f.read()
317         finally:
318             os.unlink(tmp_file)
319         self.assertFormatEqual(expected, actual)
320         with patch("black.dump_to_file", dump_to_stderr):
321             black.assert_equivalent(source, actual)
322             black.assert_stable(source, actual, DEFAULT_MODE)
323
324     def test_expression_diff(self) -> None:
325         source, _ = read_data("expression.py")
326         config = THIS_DIR / "data" / "empty_pyproject.toml"
327         expected, _ = read_data("expression.diff")
328         tmp_file = Path(black.dump_to_file(source))
329         diff_header = re.compile(
330             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
331             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
332         )
333         try:
334             result = BlackRunner().invoke(
335                 black.main, ["--diff", str(tmp_file), f"--config={config}"]
336             )
337             self.assertEqual(result.exit_code, 0)
338         finally:
339             os.unlink(tmp_file)
340         actual = result.output
341         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
342         if expected != actual:
343             dump = black.dump_to_file(actual)
344             msg = (
345                 "Expected diff isn't equal to the actual. If you made changes to"
346                 " expression.py and this is an anticipated difference, overwrite"
347                 f" tests/data/expression.diff with {dump}"
348             )
349             self.assertEqual(expected, actual, msg)
350
351     def test_expression_diff_with_color(self) -> None:
352         source, _ = read_data("expression.py")
353         config = THIS_DIR / "data" / "empty_pyproject.toml"
354         expected, _ = read_data("expression.diff")
355         tmp_file = Path(black.dump_to_file(source))
356         try:
357             result = BlackRunner().invoke(
358                 black.main, ["--diff", "--color", str(tmp_file), f"--config={config}"]
359             )
360         finally:
361             os.unlink(tmp_file)
362         actual = result.output
363         # We check the contents of the diff in `test_expression_diff`. All
364         # we need to check here is that color codes exist in the result.
365         self.assertIn("\033[1;37m", actual)
366         self.assertIn("\033[36m", actual)
367         self.assertIn("\033[32m", actual)
368         self.assertIn("\033[31m", actual)
369         self.assertIn("\033[0m", actual)
370
371     @patch("black.dump_to_file", dump_to_stderr)
372     def test_pep_570(self) -> None:
373         source, expected = read_data("pep_570")
374         actual = fs(source)
375         self.assertFormatEqual(expected, actual)
376         black.assert_stable(source, actual, DEFAULT_MODE)
377         if sys.version_info >= (3, 8):
378             black.assert_equivalent(source, actual)
379
380     def test_detect_pos_only_arguments(self) -> None:
381         source, _ = read_data("pep_570")
382         root = black.lib2to3_parse(source)
383         features = black.get_features_used(root)
384         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
385         versions = black.detect_target_versions(root)
386         self.assertIn(black.TargetVersion.PY38, versions)
387
388     @patch("black.dump_to_file", dump_to_stderr)
389     def test_string_quotes(self) -> None:
390         source, expected = read_data("string_quotes")
391         mode = black.Mode(experimental_string_processing=True)
392         actual = fs(source, mode=mode)
393         self.assertFormatEqual(expected, actual)
394         black.assert_equivalent(source, actual)
395         black.assert_stable(source, actual, mode)
396         mode = replace(mode, string_normalization=False)
397         not_normalized = fs(source, mode=mode)
398         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
399         black.assert_equivalent(source, not_normalized)
400         black.assert_stable(source, not_normalized, mode=mode)
401
402     @patch("black.dump_to_file", dump_to_stderr)
403     def test_docstring_no_string_normalization(self) -> None:
404         """Like test_docstring but with string normalization off."""
405         source, expected = read_data("docstring_no_string_normalization")
406         mode = replace(DEFAULT_MODE, string_normalization=False)
407         actual = fs(source, mode=mode)
408         self.assertFormatEqual(expected, actual)
409         black.assert_equivalent(source, actual)
410         black.assert_stable(source, actual, mode)
411
412     def test_long_strings_flag_disabled(self) -> None:
413         """Tests for turning off the string processing logic."""
414         source, expected = read_data("long_strings_flag_disabled")
415         mode = replace(DEFAULT_MODE, experimental_string_processing=False)
416         actual = fs(source, mode=mode)
417         self.assertFormatEqual(expected, actual)
418         black.assert_stable(expected, actual, mode)
419
420     @patch("black.dump_to_file", dump_to_stderr)
421     def test_numeric_literals(self) -> None:
422         source, expected = read_data("numeric_literals")
423         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
424         actual = fs(source, mode=mode)
425         self.assertFormatEqual(expected, actual)
426         black.assert_equivalent(source, actual)
427         black.assert_stable(source, actual, mode)
428
429     @patch("black.dump_to_file", dump_to_stderr)
430     def test_numeric_literals_ignoring_underscores(self) -> None:
431         source, expected = read_data("numeric_literals_skip_underscores")
432         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
433         actual = fs(source, mode=mode)
434         self.assertFormatEqual(expected, actual)
435         black.assert_equivalent(source, actual)
436         black.assert_stable(source, actual, mode)
437
438     def test_skip_magic_trailing_comma(self) -> None:
439         source, _ = read_data("expression.py")
440         expected, _ = read_data("expression_skip_magic_trailing_comma.diff")
441         tmp_file = Path(black.dump_to_file(source))
442         diff_header = re.compile(
443             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
444             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
445         )
446         try:
447             result = BlackRunner().invoke(black.main, ["-C", "--diff", str(tmp_file)])
448             self.assertEqual(result.exit_code, 0)
449         finally:
450             os.unlink(tmp_file)
451         actual = result.output
452         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
453         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
454         if expected != actual:
455             dump = black.dump_to_file(actual)
456             msg = (
457                 "Expected diff isn't equal to the actual. If you made changes to"
458                 " expression.py and this is an anticipated difference, overwrite"
459                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
460             )
461             self.assertEqual(expected, actual, msg)
462
463     @pytest.mark.no_python2
464     def test_python2_should_fail_without_optional_install(self) -> None:
465         if sys.version_info < (3, 8):
466             self.skipTest(
467                 "Python 3.6 and 3.7 will install typed-ast to work and as such will be"
468                 " able to parse Python 2 syntax without explicitly specifying the"
469                 " python2 extra"
470             )
471
472         source = "x = 1234l"
473         tmp_file = Path(black.dump_to_file(source))
474         try:
475             runner = BlackRunner()
476             result = runner.invoke(black.main, [str(tmp_file)])
477             self.assertEqual(result.exit_code, 123)
478         finally:
479             os.unlink(tmp_file)
480         actual = (
481             runner.stderr_bytes.decode()
482             .replace("\n", "")
483             .replace("\\n", "")
484             .replace("\\r", "")
485             .replace("\r", "")
486         )
487         msg = (
488             "The requested source code has invalid Python 3 syntax."
489             "If you are trying to format Python 2 files please reinstall Black"
490             " with the 'python2' extra: `python3 -m pip install black[python2]`."
491         )
492         self.assertIn(msg, actual)
493
494     @pytest.mark.python2
495     @patch("black.dump_to_file", dump_to_stderr)
496     def test_python2_print_function(self) -> None:
497         source, expected = read_data("python2_print_function")
498         mode = replace(DEFAULT_MODE, target_versions={TargetVersion.PY27})
499         actual = fs(source, mode=mode)
500         self.assertFormatEqual(expected, actual)
501         black.assert_equivalent(source, actual)
502         black.assert_stable(source, actual, mode)
503
504     @patch("black.dump_to_file", dump_to_stderr)
505     def test_stub(self) -> None:
506         mode = replace(DEFAULT_MODE, is_pyi=True)
507         source, expected = read_data("stub.pyi")
508         actual = fs(source, mode=mode)
509         self.assertFormatEqual(expected, actual)
510         black.assert_stable(source, actual, mode)
511
512     @patch("black.dump_to_file", dump_to_stderr)
513     def test_async_as_identifier(self) -> None:
514         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
515         source, expected = read_data("async_as_identifier")
516         actual = fs(source)
517         self.assertFormatEqual(expected, actual)
518         major, minor = sys.version_info[:2]
519         if major < 3 or (major <= 3 and minor < 7):
520             black.assert_equivalent(source, actual)
521         black.assert_stable(source, actual, DEFAULT_MODE)
522         # ensure black can parse this when the target is 3.6
523         self.invokeBlack([str(source_path), "--target-version", "py36"])
524         # but not on 3.7, because async/await is no longer an identifier
525         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
526
527     @patch("black.dump_to_file", dump_to_stderr)
528     def test_python37(self) -> None:
529         source_path = (THIS_DIR / "data" / "python37.py").resolve()
530         source, expected = read_data("python37")
531         actual = fs(source)
532         self.assertFormatEqual(expected, actual)
533         major, minor = sys.version_info[:2]
534         if major > 3 or (major == 3 and minor >= 7):
535             black.assert_equivalent(source, actual)
536         black.assert_stable(source, actual, DEFAULT_MODE)
537         # ensure black can parse this when the target is 3.7
538         self.invokeBlack([str(source_path), "--target-version", "py37"])
539         # but not on 3.6, because we use async as a reserved keyword
540         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
541
542     @patch("black.dump_to_file", dump_to_stderr)
543     def test_python38(self) -> None:
544         source, expected = read_data("python38")
545         actual = fs(source)
546         self.assertFormatEqual(expected, actual)
547         major, minor = sys.version_info[:2]
548         if major > 3 or (major == 3 and minor >= 8):
549             black.assert_equivalent(source, actual)
550         black.assert_stable(source, actual, DEFAULT_MODE)
551
552     @patch("black.dump_to_file", dump_to_stderr)
553     def test_python39(self) -> None:
554         source, expected = read_data("python39")
555         actual = fs(source)
556         self.assertFormatEqual(expected, actual)
557         major, minor = sys.version_info[:2]
558         if major > 3 or (major == 3 and minor >= 9):
559             black.assert_equivalent(source, actual)
560         black.assert_stable(source, actual, DEFAULT_MODE)
561
562     def test_tab_comment_indentation(self) -> None:
563         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
564         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
565         self.assertFormatEqual(contents_spc, fs(contents_spc))
566         self.assertFormatEqual(contents_spc, fs(contents_tab))
567
568         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
569         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
570         self.assertFormatEqual(contents_spc, fs(contents_spc))
571         self.assertFormatEqual(contents_spc, fs(contents_tab))
572
573         # mixed tabs and spaces (valid Python 2 code)
574         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
575         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
576         self.assertFormatEqual(contents_spc, fs(contents_spc))
577         self.assertFormatEqual(contents_spc, fs(contents_tab))
578
579         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
580         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
581         self.assertFormatEqual(contents_spc, fs(contents_spc))
582         self.assertFormatEqual(contents_spc, fs(contents_tab))
583
584     def test_report_verbose(self) -> None:
585         report = black.Report(verbose=True)
586         out_lines = []
587         err_lines = []
588
589         def out(msg: str, **kwargs: Any) -> None:
590             out_lines.append(msg)
591
592         def err(msg: str, **kwargs: Any) -> None:
593             err_lines.append(msg)
594
595         with patch("black.out", out), patch("black.err", err):
596             report.done(Path("f1"), black.Changed.NO)
597             self.assertEqual(len(out_lines), 1)
598             self.assertEqual(len(err_lines), 0)
599             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
600             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
601             self.assertEqual(report.return_code, 0)
602             report.done(Path("f2"), black.Changed.YES)
603             self.assertEqual(len(out_lines), 2)
604             self.assertEqual(len(err_lines), 0)
605             self.assertEqual(out_lines[-1], "reformatted f2")
606             self.assertEqual(
607                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
608             )
609             report.done(Path("f3"), black.Changed.CACHED)
610             self.assertEqual(len(out_lines), 3)
611             self.assertEqual(len(err_lines), 0)
612             self.assertEqual(
613                 out_lines[-1], "f3 wasn't modified on disk since last run."
614             )
615             self.assertEqual(
616                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
617             )
618             self.assertEqual(report.return_code, 0)
619             report.check = True
620             self.assertEqual(report.return_code, 1)
621             report.check = False
622             report.failed(Path("e1"), "boom")
623             self.assertEqual(len(out_lines), 3)
624             self.assertEqual(len(err_lines), 1)
625             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
626             self.assertEqual(
627                 unstyle(str(report)),
628                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
629                 " reformat.",
630             )
631             self.assertEqual(report.return_code, 123)
632             report.done(Path("f3"), black.Changed.YES)
633             self.assertEqual(len(out_lines), 4)
634             self.assertEqual(len(err_lines), 1)
635             self.assertEqual(out_lines[-1], "reformatted f3")
636             self.assertEqual(
637                 unstyle(str(report)),
638                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
639                 " reformat.",
640             )
641             self.assertEqual(report.return_code, 123)
642             report.failed(Path("e2"), "boom")
643             self.assertEqual(len(out_lines), 4)
644             self.assertEqual(len(err_lines), 2)
645             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
646             self.assertEqual(
647                 unstyle(str(report)),
648                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
649                 " reformat.",
650             )
651             self.assertEqual(report.return_code, 123)
652             report.path_ignored(Path("wat"), "no match")
653             self.assertEqual(len(out_lines), 5)
654             self.assertEqual(len(err_lines), 2)
655             self.assertEqual(out_lines[-1], "wat ignored: no match")
656             self.assertEqual(
657                 unstyle(str(report)),
658                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
659                 " reformat.",
660             )
661             self.assertEqual(report.return_code, 123)
662             report.done(Path("f4"), black.Changed.NO)
663             self.assertEqual(len(out_lines), 6)
664             self.assertEqual(len(err_lines), 2)
665             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
666             self.assertEqual(
667                 unstyle(str(report)),
668                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
669                 " reformat.",
670             )
671             self.assertEqual(report.return_code, 123)
672             report.check = True
673             self.assertEqual(
674                 unstyle(str(report)),
675                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
676                 " would fail to reformat.",
677             )
678             report.check = False
679             report.diff = True
680             self.assertEqual(
681                 unstyle(str(report)),
682                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
683                 " would fail to reformat.",
684             )
685
686     def test_report_quiet(self) -> None:
687         report = black.Report(quiet=True)
688         out_lines = []
689         err_lines = []
690
691         def out(msg: str, **kwargs: Any) -> None:
692             out_lines.append(msg)
693
694         def err(msg: str, **kwargs: Any) -> None:
695             err_lines.append(msg)
696
697         with patch("black.out", out), patch("black.err", err):
698             report.done(Path("f1"), black.Changed.NO)
699             self.assertEqual(len(out_lines), 0)
700             self.assertEqual(len(err_lines), 0)
701             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
702             self.assertEqual(report.return_code, 0)
703             report.done(Path("f2"), black.Changed.YES)
704             self.assertEqual(len(out_lines), 0)
705             self.assertEqual(len(err_lines), 0)
706             self.assertEqual(
707                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
708             )
709             report.done(Path("f3"), black.Changed.CACHED)
710             self.assertEqual(len(out_lines), 0)
711             self.assertEqual(len(err_lines), 0)
712             self.assertEqual(
713                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
714             )
715             self.assertEqual(report.return_code, 0)
716             report.check = True
717             self.assertEqual(report.return_code, 1)
718             report.check = False
719             report.failed(Path("e1"), "boom")
720             self.assertEqual(len(out_lines), 0)
721             self.assertEqual(len(err_lines), 1)
722             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
723             self.assertEqual(
724                 unstyle(str(report)),
725                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
726                 " reformat.",
727             )
728             self.assertEqual(report.return_code, 123)
729             report.done(Path("f3"), black.Changed.YES)
730             self.assertEqual(len(out_lines), 0)
731             self.assertEqual(len(err_lines), 1)
732             self.assertEqual(
733                 unstyle(str(report)),
734                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
735                 " reformat.",
736             )
737             self.assertEqual(report.return_code, 123)
738             report.failed(Path("e2"), "boom")
739             self.assertEqual(len(out_lines), 0)
740             self.assertEqual(len(err_lines), 2)
741             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
742             self.assertEqual(
743                 unstyle(str(report)),
744                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
745                 " reformat.",
746             )
747             self.assertEqual(report.return_code, 123)
748             report.path_ignored(Path("wat"), "no match")
749             self.assertEqual(len(out_lines), 0)
750             self.assertEqual(len(err_lines), 2)
751             self.assertEqual(
752                 unstyle(str(report)),
753                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
754                 " reformat.",
755             )
756             self.assertEqual(report.return_code, 123)
757             report.done(Path("f4"), black.Changed.NO)
758             self.assertEqual(len(out_lines), 0)
759             self.assertEqual(len(err_lines), 2)
760             self.assertEqual(
761                 unstyle(str(report)),
762                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
763                 " reformat.",
764             )
765             self.assertEqual(report.return_code, 123)
766             report.check = True
767             self.assertEqual(
768                 unstyle(str(report)),
769                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
770                 " would fail to reformat.",
771             )
772             report.check = False
773             report.diff = True
774             self.assertEqual(
775                 unstyle(str(report)),
776                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
777                 " would fail to reformat.",
778             )
779
780     def test_report_normal(self) -> None:
781         report = black.Report()
782         out_lines = []
783         err_lines = []
784
785         def out(msg: str, **kwargs: Any) -> None:
786             out_lines.append(msg)
787
788         def err(msg: str, **kwargs: Any) -> None:
789             err_lines.append(msg)
790
791         with patch("black.out", out), patch("black.err", err):
792             report.done(Path("f1"), black.Changed.NO)
793             self.assertEqual(len(out_lines), 0)
794             self.assertEqual(len(err_lines), 0)
795             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
796             self.assertEqual(report.return_code, 0)
797             report.done(Path("f2"), black.Changed.YES)
798             self.assertEqual(len(out_lines), 1)
799             self.assertEqual(len(err_lines), 0)
800             self.assertEqual(out_lines[-1], "reformatted f2")
801             self.assertEqual(
802                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
803             )
804             report.done(Path("f3"), black.Changed.CACHED)
805             self.assertEqual(len(out_lines), 1)
806             self.assertEqual(len(err_lines), 0)
807             self.assertEqual(out_lines[-1], "reformatted f2")
808             self.assertEqual(
809                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
810             )
811             self.assertEqual(report.return_code, 0)
812             report.check = True
813             self.assertEqual(report.return_code, 1)
814             report.check = False
815             report.failed(Path("e1"), "boom")
816             self.assertEqual(len(out_lines), 1)
817             self.assertEqual(len(err_lines), 1)
818             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
819             self.assertEqual(
820                 unstyle(str(report)),
821                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
822                 " reformat.",
823             )
824             self.assertEqual(report.return_code, 123)
825             report.done(Path("f3"), black.Changed.YES)
826             self.assertEqual(len(out_lines), 2)
827             self.assertEqual(len(err_lines), 1)
828             self.assertEqual(out_lines[-1], "reformatted f3")
829             self.assertEqual(
830                 unstyle(str(report)),
831                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
832                 " reformat.",
833             )
834             self.assertEqual(report.return_code, 123)
835             report.failed(Path("e2"), "boom")
836             self.assertEqual(len(out_lines), 2)
837             self.assertEqual(len(err_lines), 2)
838             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
839             self.assertEqual(
840                 unstyle(str(report)),
841                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
842                 " reformat.",
843             )
844             self.assertEqual(report.return_code, 123)
845             report.path_ignored(Path("wat"), "no match")
846             self.assertEqual(len(out_lines), 2)
847             self.assertEqual(len(err_lines), 2)
848             self.assertEqual(
849                 unstyle(str(report)),
850                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
851                 " reformat.",
852             )
853             self.assertEqual(report.return_code, 123)
854             report.done(Path("f4"), black.Changed.NO)
855             self.assertEqual(len(out_lines), 2)
856             self.assertEqual(len(err_lines), 2)
857             self.assertEqual(
858                 unstyle(str(report)),
859                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
860                 " reformat.",
861             )
862             self.assertEqual(report.return_code, 123)
863             report.check = True
864             self.assertEqual(
865                 unstyle(str(report)),
866                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
867                 " would fail to reformat.",
868             )
869             report.check = False
870             report.diff = True
871             self.assertEqual(
872                 unstyle(str(report)),
873                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
874                 " would fail to reformat.",
875             )
876
877     def test_lib2to3_parse(self) -> None:
878         with self.assertRaises(black.InvalidInput):
879             black.lib2to3_parse("invalid syntax")
880
881         straddling = "x + y"
882         black.lib2to3_parse(straddling)
883         black.lib2to3_parse(straddling, {TargetVersion.PY27})
884         black.lib2to3_parse(straddling, {TargetVersion.PY36})
885         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
886
887         py2_only = "print x"
888         black.lib2to3_parse(py2_only)
889         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
890         with self.assertRaises(black.InvalidInput):
891             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
892         with self.assertRaises(black.InvalidInput):
893             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
894
895         py3_only = "exec(x, end=y)"
896         black.lib2to3_parse(py3_only)
897         with self.assertRaises(black.InvalidInput):
898             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
899         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
900         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
901
902     def test_get_features_used_decorator(self) -> None:
903         # Test the feature detection of new decorator syntax
904         # since this makes some test cases of test_get_features_used()
905         # fails if it fails, this is tested first so that a useful case
906         # is identified
907         simples, relaxed = read_data("decorators")
908         # skip explanation comments at the top of the file
909         for simple_test in simples.split("##")[1:]:
910             node = black.lib2to3_parse(simple_test)
911             decorator = str(node.children[0].children[0]).strip()
912             self.assertNotIn(
913                 Feature.RELAXED_DECORATORS,
914                 black.get_features_used(node),
915                 msg=(
916                     f"decorator '{decorator}' follows python<=3.8 syntax"
917                     "but is detected as 3.9+"
918                     # f"The full node is\n{node!r}"
919                 ),
920             )
921         # skip the '# output' comment at the top of the output part
922         for relaxed_test in relaxed.split("##")[1:]:
923             node = black.lib2to3_parse(relaxed_test)
924             decorator = str(node.children[0].children[0]).strip()
925             self.assertIn(
926                 Feature.RELAXED_DECORATORS,
927                 black.get_features_used(node),
928                 msg=(
929                     f"decorator '{decorator}' uses python3.9+ syntax"
930                     "but is detected as python<=3.8"
931                     # f"The full node is\n{node!r}"
932                 ),
933             )
934
935     def test_get_features_used(self) -> None:
936         node = black.lib2to3_parse("def f(*, arg): ...\n")
937         self.assertEqual(black.get_features_used(node), set())
938         node = black.lib2to3_parse("def f(*, arg,): ...\n")
939         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
940         node = black.lib2to3_parse("f(*arg,)\n")
941         self.assertEqual(
942             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
943         )
944         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
945         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
946         node = black.lib2to3_parse("123_456\n")
947         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
948         node = black.lib2to3_parse("123456\n")
949         self.assertEqual(black.get_features_used(node), set())
950         source, expected = read_data("function")
951         node = black.lib2to3_parse(source)
952         expected_features = {
953             Feature.TRAILING_COMMA_IN_CALL,
954             Feature.TRAILING_COMMA_IN_DEF,
955             Feature.F_STRINGS,
956         }
957         self.assertEqual(black.get_features_used(node), expected_features)
958         node = black.lib2to3_parse(expected)
959         self.assertEqual(black.get_features_used(node), expected_features)
960         source, expected = read_data("expression")
961         node = black.lib2to3_parse(source)
962         self.assertEqual(black.get_features_used(node), set())
963         node = black.lib2to3_parse(expected)
964         self.assertEqual(black.get_features_used(node), set())
965
966     def test_get_future_imports(self) -> None:
967         node = black.lib2to3_parse("\n")
968         self.assertEqual(set(), black.get_future_imports(node))
969         node = black.lib2to3_parse("from __future__ import black\n")
970         self.assertEqual({"black"}, black.get_future_imports(node))
971         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
972         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
973         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
974         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
975         node = black.lib2to3_parse(
976             "from __future__ import multiple\nfrom __future__ import imports\n"
977         )
978         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
979         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
980         self.assertEqual({"black"}, black.get_future_imports(node))
981         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
982         self.assertEqual({"black"}, black.get_future_imports(node))
983         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
984         self.assertEqual(set(), black.get_future_imports(node))
985         node = black.lib2to3_parse("from some.module import black\n")
986         self.assertEqual(set(), black.get_future_imports(node))
987         node = black.lib2to3_parse(
988             "from __future__ import unicode_literals as _unicode_literals"
989         )
990         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
991         node = black.lib2to3_parse(
992             "from __future__ import unicode_literals as _lol, print"
993         )
994         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
995
996     def test_debug_visitor(self) -> None:
997         source, _ = read_data("debug_visitor.py")
998         expected, _ = read_data("debug_visitor.out")
999         out_lines = []
1000         err_lines = []
1001
1002         def out(msg: str, **kwargs: Any) -> None:
1003             out_lines.append(msg)
1004
1005         def err(msg: str, **kwargs: Any) -> None:
1006             err_lines.append(msg)
1007
1008         with patch("black.out", out), patch("black.err", err):
1009             black.DebugVisitor.show(source)
1010         actual = "\n".join(out_lines) + "\n"
1011         log_name = ""
1012         if expected != actual:
1013             log_name = black.dump_to_file(*out_lines)
1014         self.assertEqual(
1015             expected,
1016             actual,
1017             f"AST print out is different. Actual version dumped to {log_name}",
1018         )
1019
1020     def test_format_file_contents(self) -> None:
1021         empty = ""
1022         mode = DEFAULT_MODE
1023         with self.assertRaises(black.NothingChanged):
1024             black.format_file_contents(empty, mode=mode, fast=False)
1025         just_nl = "\n"
1026         with self.assertRaises(black.NothingChanged):
1027             black.format_file_contents(just_nl, mode=mode, fast=False)
1028         same = "j = [1, 2, 3]\n"
1029         with self.assertRaises(black.NothingChanged):
1030             black.format_file_contents(same, mode=mode, fast=False)
1031         different = "j = [1,2,3]"
1032         expected = same
1033         actual = black.format_file_contents(different, mode=mode, fast=False)
1034         self.assertEqual(expected, actual)
1035         invalid = "return if you can"
1036         with self.assertRaises(black.InvalidInput) as e:
1037             black.format_file_contents(invalid, mode=mode, fast=False)
1038         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1039
1040     def test_endmarker(self) -> None:
1041         n = black.lib2to3_parse("\n")
1042         self.assertEqual(n.type, black.syms.file_input)
1043         self.assertEqual(len(n.children), 1)
1044         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1045
1046     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1047     def test_assertFormatEqual(self) -> None:
1048         out_lines = []
1049         err_lines = []
1050
1051         def out(msg: str, **kwargs: Any) -> None:
1052             out_lines.append(msg)
1053
1054         def err(msg: str, **kwargs: Any) -> None:
1055             err_lines.append(msg)
1056
1057         with patch("black.out", out), patch("black.err", err):
1058             with self.assertRaises(AssertionError):
1059                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1060
1061         out_str = "".join(out_lines)
1062         self.assertTrue("Expected tree:" in out_str)
1063         self.assertTrue("Actual tree:" in out_str)
1064         self.assertEqual("".join(err_lines), "")
1065
1066     def test_cache_broken_file(self) -> None:
1067         mode = DEFAULT_MODE
1068         with cache_dir() as workspace:
1069             cache_file = black.get_cache_file(mode)
1070             with cache_file.open("w") as fobj:
1071                 fobj.write("this is not a pickle")
1072             self.assertEqual(black.read_cache(mode), {})
1073             src = (workspace / "test.py").resolve()
1074             with src.open("w") as fobj:
1075                 fobj.write("print('hello')")
1076             self.invokeBlack([str(src)])
1077             cache = black.read_cache(mode)
1078             self.assertIn(str(src), cache)
1079
1080     def test_cache_single_file_already_cached(self) -> None:
1081         mode = DEFAULT_MODE
1082         with cache_dir() as workspace:
1083             src = (workspace / "test.py").resolve()
1084             with src.open("w") as fobj:
1085                 fobj.write("print('hello')")
1086             black.write_cache({}, [src], mode)
1087             self.invokeBlack([str(src)])
1088             with src.open("r") as fobj:
1089                 self.assertEqual(fobj.read(), "print('hello')")
1090
1091     @event_loop()
1092     def test_cache_multiple_files(self) -> None:
1093         mode = DEFAULT_MODE
1094         with cache_dir() as workspace, patch(
1095             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1096         ):
1097             one = (workspace / "one.py").resolve()
1098             with one.open("w") as fobj:
1099                 fobj.write("print('hello')")
1100             two = (workspace / "two.py").resolve()
1101             with two.open("w") as fobj:
1102                 fobj.write("print('hello')")
1103             black.write_cache({}, [one], mode)
1104             self.invokeBlack([str(workspace)])
1105             with one.open("r") as fobj:
1106                 self.assertEqual(fobj.read(), "print('hello')")
1107             with two.open("r") as fobj:
1108                 self.assertEqual(fobj.read(), 'print("hello")\n')
1109             cache = black.read_cache(mode)
1110             self.assertIn(str(one), cache)
1111             self.assertIn(str(two), cache)
1112
1113     def test_no_cache_when_writeback_diff(self) -> None:
1114         mode = DEFAULT_MODE
1115         with cache_dir() as workspace:
1116             src = (workspace / "test.py").resolve()
1117             with src.open("w") as fobj:
1118                 fobj.write("print('hello')")
1119             with patch("black.read_cache") as read_cache, patch(
1120                 "black.write_cache"
1121             ) as write_cache:
1122                 self.invokeBlack([str(src), "--diff"])
1123                 cache_file = black.get_cache_file(mode)
1124                 self.assertFalse(cache_file.exists())
1125                 write_cache.assert_not_called()
1126                 read_cache.assert_not_called()
1127
1128     def test_no_cache_when_writeback_color_diff(self) -> None:
1129         mode = DEFAULT_MODE
1130         with cache_dir() as workspace:
1131             src = (workspace / "test.py").resolve()
1132             with src.open("w") as fobj:
1133                 fobj.write("print('hello')")
1134             with patch("black.read_cache") as read_cache, patch(
1135                 "black.write_cache"
1136             ) as write_cache:
1137                 self.invokeBlack([str(src), "--diff", "--color"])
1138                 cache_file = black.get_cache_file(mode)
1139                 self.assertFalse(cache_file.exists())
1140                 write_cache.assert_not_called()
1141                 read_cache.assert_not_called()
1142
1143     @event_loop()
1144     def test_output_locking_when_writeback_diff(self) -> None:
1145         with cache_dir() as workspace:
1146             for tag in range(0, 4):
1147                 src = (workspace / f"test{tag}.py").resolve()
1148                 with src.open("w") as fobj:
1149                     fobj.write("print('hello')")
1150             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1151                 self.invokeBlack(["--diff", str(workspace)], exit_code=0)
1152                 # this isn't quite doing what we want, but if it _isn't_
1153                 # called then we cannot be using the lock it provides
1154                 mgr.assert_called()
1155
1156     @event_loop()
1157     def test_output_locking_when_writeback_color_diff(self) -> None:
1158         with cache_dir() as workspace:
1159             for tag in range(0, 4):
1160                 src = (workspace / f"test{tag}.py").resolve()
1161                 with src.open("w") as fobj:
1162                     fobj.write("print('hello')")
1163             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1164                 self.invokeBlack(["--diff", "--color", str(workspace)], exit_code=0)
1165                 # this isn't quite doing what we want, but if it _isn't_
1166                 # called then we cannot be using the lock it provides
1167                 mgr.assert_called()
1168
1169     def test_no_cache_when_stdin(self) -> None:
1170         mode = DEFAULT_MODE
1171         with cache_dir():
1172             result = CliRunner().invoke(
1173                 black.main, ["-"], input=BytesIO(b"print('hello')")
1174             )
1175             self.assertEqual(result.exit_code, 0)
1176             cache_file = black.get_cache_file(mode)
1177             self.assertFalse(cache_file.exists())
1178
1179     def test_read_cache_no_cachefile(self) -> None:
1180         mode = DEFAULT_MODE
1181         with cache_dir():
1182             self.assertEqual(black.read_cache(mode), {})
1183
1184     def test_write_cache_read_cache(self) -> None:
1185         mode = DEFAULT_MODE
1186         with cache_dir() as workspace:
1187             src = (workspace / "test.py").resolve()
1188             src.touch()
1189             black.write_cache({}, [src], mode)
1190             cache = black.read_cache(mode)
1191             self.assertIn(str(src), cache)
1192             self.assertEqual(cache[str(src)], black.get_cache_info(src))
1193
1194     def test_filter_cached(self) -> None:
1195         with TemporaryDirectory() as workspace:
1196             path = Path(workspace)
1197             uncached = (path / "uncached").resolve()
1198             cached = (path / "cached").resolve()
1199             cached_but_changed = (path / "changed").resolve()
1200             uncached.touch()
1201             cached.touch()
1202             cached_but_changed.touch()
1203             cache = {
1204                 str(cached): black.get_cache_info(cached),
1205                 str(cached_but_changed): (0.0, 0),
1206             }
1207             todo, done = black.filter_cached(
1208                 cache, {uncached, cached, cached_but_changed}
1209             )
1210             self.assertEqual(todo, {uncached, cached_but_changed})
1211             self.assertEqual(done, {cached})
1212
1213     def test_write_cache_creates_directory_if_needed(self) -> None:
1214         mode = DEFAULT_MODE
1215         with cache_dir(exists=False) as workspace:
1216             self.assertFalse(workspace.exists())
1217             black.write_cache({}, [], mode)
1218             self.assertTrue(workspace.exists())
1219
1220     @event_loop()
1221     def test_failed_formatting_does_not_get_cached(self) -> None:
1222         mode = DEFAULT_MODE
1223         with cache_dir() as workspace, patch(
1224             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1225         ):
1226             failing = (workspace / "failing.py").resolve()
1227             with failing.open("w") as fobj:
1228                 fobj.write("not actually python")
1229             clean = (workspace / "clean.py").resolve()
1230             with clean.open("w") as fobj:
1231                 fobj.write('print("hello")\n')
1232             self.invokeBlack([str(workspace)], exit_code=123)
1233             cache = black.read_cache(mode)
1234             self.assertNotIn(str(failing), cache)
1235             self.assertIn(str(clean), cache)
1236
1237     def test_write_cache_write_fail(self) -> None:
1238         mode = DEFAULT_MODE
1239         with cache_dir(), patch.object(Path, "open") as mock:
1240             mock.side_effect = OSError
1241             black.write_cache({}, [], mode)
1242
1243     @event_loop()
1244     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
1245     def test_works_in_mono_process_only_environment(self) -> None:
1246         with cache_dir() as workspace:
1247             for f in [
1248                 (workspace / "one.py").resolve(),
1249                 (workspace / "two.py").resolve(),
1250             ]:
1251                 f.write_text('print("hello")\n')
1252             self.invokeBlack([str(workspace)])
1253
1254     @event_loop()
1255     def test_check_diff_use_together(self) -> None:
1256         with cache_dir():
1257             # Files which will be reformatted.
1258             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1259             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1260             # Files which will not be reformatted.
1261             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1262             self.invokeBlack([str(src2), "--diff", "--check"])
1263             # Multi file command.
1264             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1265
1266     def test_no_files(self) -> None:
1267         with cache_dir():
1268             # Without an argument, black exits with error code 0.
1269             self.invokeBlack([])
1270
1271     def test_broken_symlink(self) -> None:
1272         with cache_dir() as workspace:
1273             symlink = workspace / "broken_link.py"
1274             try:
1275                 symlink.symlink_to("nonexistent.py")
1276             except OSError as e:
1277                 self.skipTest(f"Can't create symlinks: {e}")
1278             self.invokeBlack([str(workspace.resolve())])
1279
1280     def test_read_cache_line_lengths(self) -> None:
1281         mode = DEFAULT_MODE
1282         short_mode = replace(DEFAULT_MODE, line_length=1)
1283         with cache_dir() as workspace:
1284             path = (workspace / "file.py").resolve()
1285             path.touch()
1286             black.write_cache({}, [path], mode)
1287             one = black.read_cache(mode)
1288             self.assertIn(str(path), one)
1289             two = black.read_cache(short_mode)
1290             self.assertNotIn(str(path), two)
1291
1292     def test_single_file_force_pyi(self) -> None:
1293         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1294         contents, expected = read_data("force_pyi")
1295         with cache_dir() as workspace:
1296             path = (workspace / "file.py").resolve()
1297             with open(path, "w") as fh:
1298                 fh.write(contents)
1299             self.invokeBlack([str(path), "--pyi"])
1300             with open(path, "r") as fh:
1301                 actual = fh.read()
1302             # verify cache with --pyi is separate
1303             pyi_cache = black.read_cache(pyi_mode)
1304             self.assertIn(str(path), pyi_cache)
1305             normal_cache = black.read_cache(DEFAULT_MODE)
1306             self.assertNotIn(str(path), normal_cache)
1307         self.assertFormatEqual(expected, actual)
1308         black.assert_equivalent(contents, actual)
1309         black.assert_stable(contents, actual, pyi_mode)
1310
1311     @event_loop()
1312     def test_multi_file_force_pyi(self) -> None:
1313         reg_mode = DEFAULT_MODE
1314         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1315         contents, expected = read_data("force_pyi")
1316         with cache_dir() as workspace:
1317             paths = [
1318                 (workspace / "file1.py").resolve(),
1319                 (workspace / "file2.py").resolve(),
1320             ]
1321             for path in paths:
1322                 with open(path, "w") as fh:
1323                     fh.write(contents)
1324             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1325             for path in paths:
1326                 with open(path, "r") as fh:
1327                     actual = fh.read()
1328                 self.assertEqual(actual, expected)
1329             # verify cache with --pyi is separate
1330             pyi_cache = black.read_cache(pyi_mode)
1331             normal_cache = black.read_cache(reg_mode)
1332             for path in paths:
1333                 self.assertIn(str(path), pyi_cache)
1334                 self.assertNotIn(str(path), normal_cache)
1335
1336     def test_pipe_force_pyi(self) -> None:
1337         source, expected = read_data("force_pyi")
1338         result = CliRunner().invoke(
1339             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1340         )
1341         self.assertEqual(result.exit_code, 0)
1342         actual = result.output
1343         self.assertFormatEqual(actual, expected)
1344
1345     def test_single_file_force_py36(self) -> None:
1346         reg_mode = DEFAULT_MODE
1347         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1348         source, expected = read_data("force_py36")
1349         with cache_dir() as workspace:
1350             path = (workspace / "file.py").resolve()
1351             with open(path, "w") as fh:
1352                 fh.write(source)
1353             self.invokeBlack([str(path), *PY36_ARGS])
1354             with open(path, "r") as fh:
1355                 actual = fh.read()
1356             # verify cache with --target-version is separate
1357             py36_cache = black.read_cache(py36_mode)
1358             self.assertIn(str(path), py36_cache)
1359             normal_cache = black.read_cache(reg_mode)
1360             self.assertNotIn(str(path), normal_cache)
1361         self.assertEqual(actual, expected)
1362
1363     @event_loop()
1364     def test_multi_file_force_py36(self) -> None:
1365         reg_mode = DEFAULT_MODE
1366         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1367         source, expected = read_data("force_py36")
1368         with cache_dir() as workspace:
1369             paths = [
1370                 (workspace / "file1.py").resolve(),
1371                 (workspace / "file2.py").resolve(),
1372             ]
1373             for path in paths:
1374                 with open(path, "w") as fh:
1375                     fh.write(source)
1376             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1377             for path in paths:
1378                 with open(path, "r") as fh:
1379                     actual = fh.read()
1380                 self.assertEqual(actual, expected)
1381             # verify cache with --target-version is separate
1382             pyi_cache = black.read_cache(py36_mode)
1383             normal_cache = black.read_cache(reg_mode)
1384             for path in paths:
1385                 self.assertIn(str(path), pyi_cache)
1386                 self.assertNotIn(str(path), normal_cache)
1387
1388     def test_pipe_force_py36(self) -> None:
1389         source, expected = read_data("force_py36")
1390         result = CliRunner().invoke(
1391             black.main,
1392             ["-", "-q", "--target-version=py36"],
1393             input=BytesIO(source.encode("utf8")),
1394         )
1395         self.assertEqual(result.exit_code, 0)
1396         actual = result.output
1397         self.assertFormatEqual(actual, expected)
1398
1399     def test_include_exclude(self) -> None:
1400         path = THIS_DIR / "data" / "include_exclude_tests"
1401         include = re.compile(r"\.pyi?$")
1402         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1403         report = black.Report()
1404         gitignore = PathSpec.from_lines("gitwildmatch", [])
1405         sources: List[Path] = []
1406         expected = [
1407             Path(path / "b/dont_exclude/a.py"),
1408             Path(path / "b/dont_exclude/a.pyi"),
1409         ]
1410         this_abs = THIS_DIR.resolve()
1411         sources.extend(
1412             black.gen_python_files(
1413                 path.iterdir(),
1414                 this_abs,
1415                 include,
1416                 exclude,
1417                 None,
1418                 None,
1419                 report,
1420                 gitignore,
1421             )
1422         )
1423         self.assertEqual(sorted(expected), sorted(sources))
1424
1425     def test_gitingore_used_as_default(self) -> None:
1426         path = Path(THIS_DIR / "data" / "include_exclude_tests")
1427         include = re.compile(r"\.pyi?$")
1428         extend_exclude = re.compile(r"/exclude/")
1429         src = str(path / "b/")
1430         report = black.Report()
1431         expected: List[Path] = [
1432             path / "b/.definitely_exclude/a.py",
1433             path / "b/.definitely_exclude/a.pyi",
1434         ]
1435         sources = list(
1436             black.get_sources(
1437                 ctx=FakeContext(),
1438                 src=(src,),
1439                 quiet=True,
1440                 verbose=False,
1441                 include=include,
1442                 exclude=None,
1443                 extend_exclude=extend_exclude,
1444                 force_exclude=None,
1445                 report=report,
1446                 stdin_filename=None,
1447             )
1448         )
1449         self.assertEqual(sorted(expected), sorted(sources))
1450
1451     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1452     def test_exclude_for_issue_1572(self) -> None:
1453         # Exclude shouldn't touch files that were explicitly given to Black through the
1454         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
1455         # https://github.com/psf/black/issues/1572
1456         path = THIS_DIR / "data" / "include_exclude_tests"
1457         include = ""
1458         exclude = r"/exclude/|a\.py"
1459         src = str(path / "b/exclude/a.py")
1460         report = black.Report()
1461         expected = [Path(path / "b/exclude/a.py")]
1462         sources = list(
1463             black.get_sources(
1464                 ctx=FakeContext(),
1465                 src=(src,),
1466                 quiet=True,
1467                 verbose=False,
1468                 include=re.compile(include),
1469                 exclude=re.compile(exclude),
1470                 extend_exclude=None,
1471                 force_exclude=None,
1472                 report=report,
1473                 stdin_filename=None,
1474             )
1475         )
1476         self.assertEqual(sorted(expected), sorted(sources))
1477
1478     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1479     def test_get_sources_with_stdin(self) -> None:
1480         include = ""
1481         exclude = r"/exclude/|a\.py"
1482         src = "-"
1483         report = black.Report()
1484         expected = [Path("-")]
1485         sources = list(
1486             black.get_sources(
1487                 ctx=FakeContext(),
1488                 src=(src,),
1489                 quiet=True,
1490                 verbose=False,
1491                 include=re.compile(include),
1492                 exclude=re.compile(exclude),
1493                 extend_exclude=None,
1494                 force_exclude=None,
1495                 report=report,
1496                 stdin_filename=None,
1497             )
1498         )
1499         self.assertEqual(sorted(expected), sorted(sources))
1500
1501     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1502     def test_get_sources_with_stdin_filename(self) -> None:
1503         include = ""
1504         exclude = r"/exclude/|a\.py"
1505         src = "-"
1506         report = black.Report()
1507         stdin_filename = str(THIS_DIR / "data/collections.py")
1508         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1509         sources = list(
1510             black.get_sources(
1511                 ctx=FakeContext(),
1512                 src=(src,),
1513                 quiet=True,
1514                 verbose=False,
1515                 include=re.compile(include),
1516                 exclude=re.compile(exclude),
1517                 extend_exclude=None,
1518                 force_exclude=None,
1519                 report=report,
1520                 stdin_filename=stdin_filename,
1521             )
1522         )
1523         self.assertEqual(sorted(expected), sorted(sources))
1524
1525     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1526     def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
1527         # Exclude shouldn't exclude stdin_filename since it is mimicing the
1528         # file being passed directly. This is the same as
1529         # test_exclude_for_issue_1572
1530         path = THIS_DIR / "data" / "include_exclude_tests"
1531         include = ""
1532         exclude = r"/exclude/|a\.py"
1533         src = "-"
1534         report = black.Report()
1535         stdin_filename = str(path / "b/exclude/a.py")
1536         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1537         sources = list(
1538             black.get_sources(
1539                 ctx=FakeContext(),
1540                 src=(src,),
1541                 quiet=True,
1542                 verbose=False,
1543                 include=re.compile(include),
1544                 exclude=re.compile(exclude),
1545                 extend_exclude=None,
1546                 force_exclude=None,
1547                 report=report,
1548                 stdin_filename=stdin_filename,
1549             )
1550         )
1551         self.assertEqual(sorted(expected), sorted(sources))
1552
1553     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1554     def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
1555         # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
1556         # file being passed directly. This is the same as
1557         # test_exclude_for_issue_1572
1558         path = THIS_DIR / "data" / "include_exclude_tests"
1559         include = ""
1560         extend_exclude = r"/exclude/|a\.py"
1561         src = "-"
1562         report = black.Report()
1563         stdin_filename = str(path / "b/exclude/a.py")
1564         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1565         sources = list(
1566             black.get_sources(
1567                 ctx=FakeContext(),
1568                 src=(src,),
1569                 quiet=True,
1570                 verbose=False,
1571                 include=re.compile(include),
1572                 exclude=re.compile(""),
1573                 extend_exclude=re.compile(extend_exclude),
1574                 force_exclude=None,
1575                 report=report,
1576                 stdin_filename=stdin_filename,
1577             )
1578         )
1579         self.assertEqual(sorted(expected), sorted(sources))
1580
1581     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1582     def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
1583         # Force exclude should exclude the file when passing it through
1584         # stdin_filename
1585         path = THIS_DIR / "data" / "include_exclude_tests"
1586         include = ""
1587         force_exclude = r"/exclude/|a\.py"
1588         src = "-"
1589         report = black.Report()
1590         stdin_filename = str(path / "b/exclude/a.py")
1591         sources = list(
1592             black.get_sources(
1593                 ctx=FakeContext(),
1594                 src=(src,),
1595                 quiet=True,
1596                 verbose=False,
1597                 include=re.compile(include),
1598                 exclude=re.compile(""),
1599                 extend_exclude=None,
1600                 force_exclude=re.compile(force_exclude),
1601                 report=report,
1602                 stdin_filename=stdin_filename,
1603             )
1604         )
1605         self.assertEqual([], sorted(sources))
1606
1607     def test_reformat_one_with_stdin(self) -> None:
1608         with patch(
1609             "black.format_stdin_to_stdout",
1610             return_value=lambda *args, **kwargs: black.Changed.YES,
1611         ) as fsts:
1612             report = MagicMock()
1613             path = Path("-")
1614             black.reformat_one(
1615                 path,
1616                 fast=True,
1617                 write_back=black.WriteBack.YES,
1618                 mode=DEFAULT_MODE,
1619                 report=report,
1620             )
1621             fsts.assert_called_once()
1622             report.done.assert_called_with(path, black.Changed.YES)
1623
1624     def test_reformat_one_with_stdin_filename(self) -> None:
1625         with patch(
1626             "black.format_stdin_to_stdout",
1627             return_value=lambda *args, **kwargs: black.Changed.YES,
1628         ) as fsts:
1629             report = MagicMock()
1630             p = "foo.py"
1631             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1632             expected = Path(p)
1633             black.reformat_one(
1634                 path,
1635                 fast=True,
1636                 write_back=black.WriteBack.YES,
1637                 mode=DEFAULT_MODE,
1638                 report=report,
1639             )
1640             fsts.assert_called_once_with(
1641                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1642             )
1643             # __BLACK_STDIN_FILENAME__ should have been stripped
1644             report.done.assert_called_with(expected, black.Changed.YES)
1645
1646     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1647         with patch(
1648             "black.format_stdin_to_stdout",
1649             return_value=lambda *args, **kwargs: black.Changed.YES,
1650         ) as fsts:
1651             report = MagicMock()
1652             p = "foo.pyi"
1653             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1654             expected = Path(p)
1655             black.reformat_one(
1656                 path,
1657                 fast=True,
1658                 write_back=black.WriteBack.YES,
1659                 mode=DEFAULT_MODE,
1660                 report=report,
1661             )
1662             fsts.assert_called_once_with(
1663                 fast=True,
1664                 write_back=black.WriteBack.YES,
1665                 mode=replace(DEFAULT_MODE, is_pyi=True),
1666             )
1667             # __BLACK_STDIN_FILENAME__ should have been stripped
1668             report.done.assert_called_with(expected, black.Changed.YES)
1669
1670     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1671         with patch(
1672             "black.format_stdin_to_stdout",
1673             return_value=lambda *args, **kwargs: black.Changed.YES,
1674         ) as fsts:
1675             report = MagicMock()
1676             # Even with an existing file, since we are forcing stdin, black
1677             # should output to stdout and not modify the file inplace
1678             p = Path(str(THIS_DIR / "data/collections.py"))
1679             # Make sure is_file actually returns True
1680             self.assertTrue(p.is_file())
1681             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1682             expected = Path(p)
1683             black.reformat_one(
1684                 path,
1685                 fast=True,
1686                 write_back=black.WriteBack.YES,
1687                 mode=DEFAULT_MODE,
1688                 report=report,
1689             )
1690             fsts.assert_called_once()
1691             # __BLACK_STDIN_FILENAME__ should have been stripped
1692             report.done.assert_called_with(expected, black.Changed.YES)
1693
1694     def test_gitignore_exclude(self) -> None:
1695         path = THIS_DIR / "data" / "include_exclude_tests"
1696         include = re.compile(r"\.pyi?$")
1697         exclude = re.compile(r"")
1698         report = black.Report()
1699         gitignore = PathSpec.from_lines(
1700             "gitwildmatch", ["exclude/", ".definitely_exclude"]
1701         )
1702         sources: List[Path] = []
1703         expected = [
1704             Path(path / "b/dont_exclude/a.py"),
1705             Path(path / "b/dont_exclude/a.pyi"),
1706         ]
1707         this_abs = THIS_DIR.resolve()
1708         sources.extend(
1709             black.gen_python_files(
1710                 path.iterdir(),
1711                 this_abs,
1712                 include,
1713                 exclude,
1714                 None,
1715                 None,
1716                 report,
1717                 gitignore,
1718             )
1719         )
1720         self.assertEqual(sorted(expected), sorted(sources))
1721
1722     def test_empty_include(self) -> None:
1723         path = THIS_DIR / "data" / "include_exclude_tests"
1724         report = black.Report()
1725         gitignore = PathSpec.from_lines("gitwildmatch", [])
1726         empty = re.compile(r"")
1727         sources: List[Path] = []
1728         expected = [
1729             Path(path / "b/exclude/a.pie"),
1730             Path(path / "b/exclude/a.py"),
1731             Path(path / "b/exclude/a.pyi"),
1732             Path(path / "b/dont_exclude/a.pie"),
1733             Path(path / "b/dont_exclude/a.py"),
1734             Path(path / "b/dont_exclude/a.pyi"),
1735             Path(path / "b/.definitely_exclude/a.pie"),
1736             Path(path / "b/.definitely_exclude/a.py"),
1737             Path(path / "b/.definitely_exclude/a.pyi"),
1738             Path(path / ".gitignore"),
1739             Path(path / "pyproject.toml"),
1740         ]
1741         this_abs = THIS_DIR.resolve()
1742         sources.extend(
1743             black.gen_python_files(
1744                 path.iterdir(),
1745                 this_abs,
1746                 empty,
1747                 re.compile(black.DEFAULT_EXCLUDES),
1748                 None,
1749                 None,
1750                 report,
1751                 gitignore,
1752             )
1753         )
1754         self.assertEqual(sorted(expected), sorted(sources))
1755
1756     def test_extend_exclude(self) -> None:
1757         path = THIS_DIR / "data" / "include_exclude_tests"
1758         report = black.Report()
1759         gitignore = PathSpec.from_lines("gitwildmatch", [])
1760         sources: List[Path] = []
1761         expected = [
1762             Path(path / "b/exclude/a.py"),
1763             Path(path / "b/dont_exclude/a.py"),
1764         ]
1765         this_abs = THIS_DIR.resolve()
1766         sources.extend(
1767             black.gen_python_files(
1768                 path.iterdir(),
1769                 this_abs,
1770                 re.compile(black.DEFAULT_INCLUDES),
1771                 re.compile(r"\.pyi$"),
1772                 re.compile(r"\.definitely_exclude"),
1773                 None,
1774                 report,
1775                 gitignore,
1776             )
1777         )
1778         self.assertEqual(sorted(expected), sorted(sources))
1779
1780     def test_invalid_cli_regex(self) -> None:
1781         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1782             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1783
1784     def test_preserves_line_endings(self) -> None:
1785         with TemporaryDirectory() as workspace:
1786             test_file = Path(workspace) / "test.py"
1787             for nl in ["\n", "\r\n"]:
1788                 contents = nl.join(["def f(  ):", "    pass"])
1789                 test_file.write_bytes(contents.encode())
1790                 ff(test_file, write_back=black.WriteBack.YES)
1791                 updated_contents: bytes = test_file.read_bytes()
1792                 self.assertIn(nl.encode(), updated_contents)
1793                 if nl == "\n":
1794                     self.assertNotIn(b"\r\n", updated_contents)
1795
1796     def test_preserves_line_endings_via_stdin(self) -> None:
1797         for nl in ["\n", "\r\n"]:
1798             contents = nl.join(["def f(  ):", "    pass"])
1799             runner = BlackRunner()
1800             result = runner.invoke(
1801                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1802             )
1803             self.assertEqual(result.exit_code, 0)
1804             output = runner.stdout_bytes
1805             self.assertIn(nl.encode("utf8"), output)
1806             if nl == "\n":
1807                 self.assertNotIn(b"\r\n", output)
1808
1809     def test_assert_equivalent_different_asts(self) -> None:
1810         with self.assertRaises(AssertionError):
1811             black.assert_equivalent("{}", "None")
1812
1813     def test_symlink_out_of_root_directory(self) -> None:
1814         path = MagicMock()
1815         root = THIS_DIR.resolve()
1816         child = MagicMock()
1817         include = re.compile(black.DEFAULT_INCLUDES)
1818         exclude = re.compile(black.DEFAULT_EXCLUDES)
1819         report = black.Report()
1820         gitignore = PathSpec.from_lines("gitwildmatch", [])
1821         # `child` should behave like a symlink which resolved path is clearly
1822         # outside of the `root` directory.
1823         path.iterdir.return_value = [child]
1824         child.resolve.return_value = Path("/a/b/c")
1825         child.as_posix.return_value = "/a/b/c"
1826         child.is_symlink.return_value = True
1827         try:
1828             list(
1829                 black.gen_python_files(
1830                     path.iterdir(),
1831                     root,
1832                     include,
1833                     exclude,
1834                     None,
1835                     None,
1836                     report,
1837                     gitignore,
1838                 )
1839             )
1840         except ValueError as ve:
1841             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1842         path.iterdir.assert_called_once()
1843         child.resolve.assert_called_once()
1844         child.is_symlink.assert_called_once()
1845         # `child` should behave like a strange file which resolved path is clearly
1846         # outside of the `root` directory.
1847         child.is_symlink.return_value = False
1848         with self.assertRaises(ValueError):
1849             list(
1850                 black.gen_python_files(
1851                     path.iterdir(),
1852                     root,
1853                     include,
1854                     exclude,
1855                     None,
1856                     None,
1857                     report,
1858                     gitignore,
1859                 )
1860             )
1861         path.iterdir.assert_called()
1862         self.assertEqual(path.iterdir.call_count, 2)
1863         child.resolve.assert_called()
1864         self.assertEqual(child.resolve.call_count, 2)
1865         child.is_symlink.assert_called()
1866         self.assertEqual(child.is_symlink.call_count, 2)
1867
1868     def test_shhh_click(self) -> None:
1869         try:
1870             from click import _unicodefun  # type: ignore
1871         except ModuleNotFoundError:
1872             self.skipTest("Incompatible Click version")
1873         if not hasattr(_unicodefun, "_verify_python3_env"):
1874             self.skipTest("Incompatible Click version")
1875         # First, let's see if Click is crashing with a preferred ASCII charset.
1876         with patch("locale.getpreferredencoding") as gpe:
1877             gpe.return_value = "ASCII"
1878             with self.assertRaises(RuntimeError):
1879                 _unicodefun._verify_python3_env()
1880         # Now, let's silence Click...
1881         black.patch_click()
1882         # ...and confirm it's silent.
1883         with patch("locale.getpreferredencoding") as gpe:
1884             gpe.return_value = "ASCII"
1885             try:
1886                 _unicodefun._verify_python3_env()
1887             except RuntimeError as re:
1888                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1889
1890     def test_root_logger_not_used_directly(self) -> None:
1891         def fail(*args: Any, **kwargs: Any) -> None:
1892             self.fail("Record created with root logger")
1893
1894         with patch.multiple(
1895             logging.root,
1896             debug=fail,
1897             info=fail,
1898             warning=fail,
1899             error=fail,
1900             critical=fail,
1901             log=fail,
1902         ):
1903             ff(THIS_FILE)
1904
1905     def test_invalid_config_return_code(self) -> None:
1906         tmp_file = Path(black.dump_to_file())
1907         try:
1908             tmp_config = Path(black.dump_to_file())
1909             tmp_config.unlink()
1910             args = ["--config", str(tmp_config), str(tmp_file)]
1911             self.invokeBlack(args, exit_code=2, ignore_config=False)
1912         finally:
1913             tmp_file.unlink()
1914
1915     def test_parse_pyproject_toml(self) -> None:
1916         test_toml_file = THIS_DIR / "test.toml"
1917         config = black.parse_pyproject_toml(str(test_toml_file))
1918         self.assertEqual(config["verbose"], 1)
1919         self.assertEqual(config["check"], "no")
1920         self.assertEqual(config["diff"], "y")
1921         self.assertEqual(config["color"], True)
1922         self.assertEqual(config["line_length"], 79)
1923         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1924         self.assertEqual(config["exclude"], r"\.pyi?$")
1925         self.assertEqual(config["include"], r"\.py?$")
1926
1927     def test_read_pyproject_toml(self) -> None:
1928         test_toml_file = THIS_DIR / "test.toml"
1929         fake_ctx = FakeContext()
1930         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1931         config = fake_ctx.default_map
1932         self.assertEqual(config["verbose"], "1")
1933         self.assertEqual(config["check"], "no")
1934         self.assertEqual(config["diff"], "y")
1935         self.assertEqual(config["color"], "True")
1936         self.assertEqual(config["line_length"], "79")
1937         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1938         self.assertEqual(config["exclude"], r"\.pyi?$")
1939         self.assertEqual(config["include"], r"\.py?$")
1940
1941     def test_find_project_root(self) -> None:
1942         with TemporaryDirectory() as workspace:
1943             root = Path(workspace)
1944             test_dir = root / "test"
1945             test_dir.mkdir()
1946
1947             src_dir = root / "src"
1948             src_dir.mkdir()
1949
1950             root_pyproject = root / "pyproject.toml"
1951             root_pyproject.touch()
1952             src_pyproject = src_dir / "pyproject.toml"
1953             src_pyproject.touch()
1954             src_python = src_dir / "foo.py"
1955             src_python.touch()
1956
1957             self.assertEqual(
1958                 black.find_project_root((src_dir, test_dir)), root.resolve()
1959             )
1960             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
1961             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
1962
1963     @patch("black.find_user_pyproject_toml", black.find_user_pyproject_toml.__wrapped__)
1964     def test_find_user_pyproject_toml_linux(self) -> None:
1965         if system() == "Windows":
1966             return
1967
1968         # Test if XDG_CONFIG_HOME is checked
1969         with TemporaryDirectory() as workspace:
1970             tmp_user_config = Path(workspace) / "black"
1971             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1972                 self.assertEqual(
1973                     black.find_user_pyproject_toml(), tmp_user_config.resolve()
1974                 )
1975
1976         # Test fallback for XDG_CONFIG_HOME
1977         with patch.dict("os.environ"):
1978             os.environ.pop("XDG_CONFIG_HOME", None)
1979             fallback_user_config = Path("~/.config").expanduser() / "black"
1980             self.assertEqual(
1981                 black.find_user_pyproject_toml(), fallback_user_config.resolve()
1982             )
1983
1984     def test_find_user_pyproject_toml_windows(self) -> None:
1985         if system() != "Windows":
1986             return
1987
1988         user_config_path = Path.home() / ".black"
1989         self.assertEqual(black.find_user_pyproject_toml(), user_config_path.resolve())
1990
1991     def test_bpo_33660_workaround(self) -> None:
1992         if system() == "Windows":
1993             return
1994
1995         # https://bugs.python.org/issue33660
1996
1997         old_cwd = Path.cwd()
1998         try:
1999             root = Path("/")
2000             os.chdir(str(root))
2001             path = Path("workspace") / "project"
2002             report = black.Report(verbose=True)
2003             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
2004             self.assertEqual(normalized_path, "workspace/project")
2005         finally:
2006             os.chdir(str(old_cwd))
2007
2008     def test_newline_comment_interaction(self) -> None:
2009         source = "class A:\\\r\n# type: ignore\n pass\n"
2010         output = black.format_str(source, mode=DEFAULT_MODE)
2011         black.assert_stable(source, output, mode=DEFAULT_MODE)
2012
2013     def test_bpo_2142_workaround(self) -> None:
2014
2015         # https://bugs.python.org/issue2142
2016
2017         source, _ = read_data("missing_final_newline.py")
2018         # read_data adds a trailing newline
2019         source = source.rstrip()
2020         expected, _ = read_data("missing_final_newline.diff")
2021         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
2022         diff_header = re.compile(
2023             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
2024             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
2025         )
2026         try:
2027             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
2028             self.assertEqual(result.exit_code, 0)
2029         finally:
2030             os.unlink(tmp_file)
2031         actual = result.output
2032         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
2033         self.assertEqual(actual, expected)
2034
2035     @pytest.mark.python2
2036     def test_docstring_reformat_for_py27(self) -> None:
2037         """
2038         Check that stripping trailing whitespace from Python 2 docstrings
2039         doesn't trigger a "not equivalent to source" error
2040         """
2041         source = (
2042             b'def foo():\r\n    """Testing\r\n    Testing """\r\n    print "Foo"\r\n'
2043         )
2044         expected = 'def foo():\n    """Testing\n    Testing"""\n    print "Foo"\n'
2045
2046         result = CliRunner().invoke(
2047             black.main,
2048             ["-", "-q", "--target-version=py27"],
2049             input=BytesIO(source),
2050         )
2051
2052         self.assertEqual(result.exit_code, 0)
2053         actual = result.output
2054         self.assertFormatEqual(actual, expected)
2055
2056
# Lines of black's main module source (line endings kept), read once at import
# time so ``tracefunc`` can look up the ``def`` line of each traced call.
with open(black.__file__, "r", encoding="utf-8") as _black_src_file:
    black_source_lines = list(_black_src_file)
2059
2060
def tracefunc(frame: types.FrameType, event: str, arg: Any) -> Callable:
    """Show function calls `from black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    Only "call" events whose code object comes from ``black/__init__.py`` are
    printed, as ``<indent><lineno>:<def line>``. The indent grows with the
    interpreter stack depth. Always returns itself so tracing continues.
    """
    if event != "call":
        return tracefunc

    # Filter by file first. ``black_source_lines`` only describes
    # black/__init__.py, so indexing it with a line number from any other file
    # is meaningless and can raise IndexError. This also avoids the costly
    # ``inspect.stack()`` call for unrelated frames. ``as_posix`` makes the
    # match work with Windows path separators too.
    filename = Path(frame.f_code.co_filename).as_posix()
    if "black/__init__.py" not in filename:
        return tracefunc

    # Hard-coded offset, presumably the number of frames contributed by the
    # test harness itself; a negative result just means no indentation.
    stack = (len(inspect.stack()) - 19) * 2
    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1
    last_index = len(black_source_lines) - 1
    funcname = black_source_lines[func_sig_lineno].strip()
    # Skip decorator lines to reach the actual ``def`` line, without running
    # past the end of the file.
    while funcname.startswith("@") and func_sig_lineno < last_index:
        func_sig_lineno += 1
        funcname = black_source_lines[func_sig_lineno].strip()
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc
2081
2082
if __name__ == "__main__":
    # Running this file directly hands off to unittest, loading the tests from
    # the module named "test_black".
    unittest.main(module="test_black")