git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Cover more in the usage docs (#2208)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import multiprocessing
3 import asyncio
4 import logging
5 from concurrent.futures import ThreadPoolExecutor
6 from contextlib import contextmanager
7 from dataclasses import replace
8 import inspect
9 from io import BytesIO, TextIOWrapper
10 import os
11 from pathlib import Path
12 from platform import system
13 import regex as re
14 import sys
15 from tempfile import TemporaryDirectory
16 import types
17 from typing import (
18     Any,
19     BinaryIO,
20     Callable,
21     Dict,
22     Generator,
23     List,
24     Iterator,
25     TypeVar,
26 )
27 import pytest
28 import unittest
29 from unittest.mock import patch, MagicMock
30
31 import click
32 from click import unstyle
33 from click.testing import CliRunner
34
35 import black
36 from black import Feature, TargetVersion
37 from black.cache import get_cache_file
38 from black.debug import DebugVisitor
39 from black.report import Report
40 import black.files
41
42 from pathspec import PathSpec
43
44 # Import other test classes
45 from tests.util import (
46     THIS_DIR,
47     read_data,
48     DETERMINISTIC_HEADER,
49     BlackBaseTestCase,
50     DEFAULT_MODE,
51     fs,
52     ff,
53     dump_to_stderr,
54 )
55 from .test_primer import PrimerCLITests  # noqa: F401
56
57
THIS_FILE = Path(__file__)
# Target versions that count as "Python 3.6 or newer" for these tests.
PY36_VERSIONS = {
    TargetVersion.PY36,
    TargetVersion.PY37,
    TargetVersion.PY38,
    TargetVersion.PY39,
}
# Equivalent ``--target-version=...`` CLI flags. Iterating a set of enum
# members follows hash order, which is randomized per process, so sort by
# member name to keep the argument order identical across runs.
PY36_ARGS = [
    f"--target-version={version.name.lower()}"
    for version in sorted(PY36_VERSIONS, key=lambda version: version.name)
]
T = TypeVar("T")
R = TypeVar("R")
68
69
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point ``black.cache.CACHE_DIR`` at a throwaway directory for the block.

    The temporary workspace is removed on exit. With ``exists=False`` the
    yielded path is a ``new`` subdirectory of the workspace that has not
    been created, so callers can test behaviour when the cache directory
    does not exist yet.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
78
79
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a brand-new asyncio event loop as the current loop for the block.

    On exit the current-loop slot is cleared and the loop is closed, mirroring
    what ``asyncio.run`` does. This way later code never finds a closed loop
    still installed as "current".
    """
    policy = asyncio.get_event_loop_policy()
    loop = policy.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        yield
    finally:
        # Unset first so the soon-to-be-closed loop is not left as current.
        asyncio.set_event_loop(None)
        loop.close()
90
91
class FakeContext(click.Context):
    """Stand-in click Context for helpers that insist on receiving one.

    ``click.Context.__init__`` is deliberately not called; the only state
    provided is an empty ``default_map``.
    """

    def __init__(self) -> None:
        self.default_map: Dict[str, Any] = dict()
97
98
class FakeParameter(click.Parameter):
    """Stand-in click Parameter for helpers that insist on receiving one."""

    def __init__(self) -> None:
        """Intentionally skip ``click.Parameter.__init__``; no state is set."""
104
105
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    After each ``invoke`` the captured streams are stored as raw bytes in
    ``stdout_bytes`` and ``stderr_bytes``.

    This is a hack that can be removed once we depend on Click 7.x"""

    def __init__(self) -> None:
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()  # kept for compatibility; not read here
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Run click's isolation while capturing stderr in its own buffer.

        Each call uses a fresh stderr ``BytesIO``. A discarded ``TextIOWrapper``
        closes the buffer it wraps when it is finalized, so reusing one buffer
        would make a second ``invoke`` on the same runner fail.
        """
        with super().isolation(*args, **kwargs) as output:
            self.stderrbuf = BytesIO()
            stderr_wrapper = TextIOWrapper(self.stderrbuf, encoding=self.charset)
            hold_stderr = sys.stderr
            sys.stderr = stderr_wrapper
            try:
                yield output
            finally:
                # Restore the real stderr first, so it is put back even if
                # reading the captured output below raises.
                sys.stderr = hold_stderr
                # Push text still buffered in the text wrappers down to the
                # underlying BytesIO objects before taking their contents.
                sys.stdout.flush()
                stderr_wrapper.flush()
                self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                self.stderr_bytes = self.stderrbuf.getvalue()
129
130
131 class BlackTestCase(BlackBaseTestCase):
132     def invokeBlack(
133         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
134     ) -> None:
135         runner = BlackRunner()
136         if ignore_config:
137             args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
138         result = runner.invoke(black.main, args)
139         self.assertEqual(
140             result.exit_code,
141             exit_code,
142             msg=(
143                 f"Failed with args: {args}\n"
144                 f"stdout: {runner.stdout_bytes.decode()!r}\n"
145                 f"stderr: {runner.stderr_bytes.decode()!r}\n"
146                 f"exception: {result.exception}"
147             ),
148         )
149
150     @patch("black.dump_to_file", dump_to_stderr)
151     def test_empty(self) -> None:
152         source = expected = ""
153         actual = fs(source)
154         self.assertFormatEqual(expected, actual)
155         black.assert_equivalent(source, actual)
156         black.assert_stable(source, actual, DEFAULT_MODE)
157
158     def test_empty_ff(self) -> None:
159         expected = ""
160         tmp_file = Path(black.dump_to_file())
161         try:
162             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
163             with open(tmp_file, encoding="utf8") as f:
164                 actual = f.read()
165         finally:
166             os.unlink(tmp_file)
167         self.assertFormatEqual(expected, actual)
168
169     def test_piping(self) -> None:
170         source, expected = read_data("src/black/__init__", data=False)
171         result = BlackRunner().invoke(
172             black.main,
173             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
174             input=BytesIO(source.encode("utf8")),
175         )
176         self.assertEqual(result.exit_code, 0)
177         self.assertFormatEqual(expected, result.output)
178         if source != result.output:
179             black.assert_equivalent(source, result.output)
180             black.assert_stable(source, result.output, DEFAULT_MODE)
181
182     def test_piping_diff(self) -> None:
183         diff_header = re.compile(
184             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
185             r"\+\d\d\d\d"
186         )
187         source, _ = read_data("expression.py")
188         expected, _ = read_data("expression.diff")
189         config = THIS_DIR / "data" / "empty_pyproject.toml"
190         args = [
191             "-",
192             "--fast",
193             f"--line-length={black.DEFAULT_LINE_LENGTH}",
194             "--diff",
195             f"--config={config}",
196         ]
197         result = BlackRunner().invoke(
198             black.main, args, input=BytesIO(source.encode("utf8"))
199         )
200         self.assertEqual(result.exit_code, 0)
201         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
202         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
203         self.assertEqual(expected, actual)
204
205     def test_piping_diff_with_color(self) -> None:
206         source, _ = read_data("expression.py")
207         config = THIS_DIR / "data" / "empty_pyproject.toml"
208         args = [
209             "-",
210             "--fast",
211             f"--line-length={black.DEFAULT_LINE_LENGTH}",
212             "--diff",
213             "--color",
214             f"--config={config}",
215         ]
216         result = BlackRunner().invoke(
217             black.main, args, input=BytesIO(source.encode("utf8"))
218         )
219         actual = result.output
220         # Again, the contents are checked in a different test, so only look for colors.
221         self.assertIn("\033[1;37m", actual)
222         self.assertIn("\033[36m", actual)
223         self.assertIn("\033[32m", actual)
224         self.assertIn("\033[31m", actual)
225         self.assertIn("\033[0m", actual)
226
227     @patch("black.dump_to_file", dump_to_stderr)
228     def _test_wip(self) -> None:
229         source, expected = read_data("wip")
230         sys.settrace(tracefunc)
231         mode = replace(
232             DEFAULT_MODE,
233             experimental_string_processing=False,
234             target_versions={black.TargetVersion.PY38},
235         )
236         actual = fs(source, mode=mode)
237         sys.settrace(None)
238         self.assertFormatEqual(expected, actual)
239         black.assert_equivalent(source, actual)
240         black.assert_stable(source, actual, black.FileMode())
241
242     @unittest.expectedFailure
243     @patch("black.dump_to_file", dump_to_stderr)
244     def test_trailing_comma_optional_parens_stability1(self) -> None:
245         source, _expected = read_data("trailing_comma_optional_parens1")
246         actual = fs(source)
247         black.assert_stable(source, actual, DEFAULT_MODE)
248
249     @unittest.expectedFailure
250     @patch("black.dump_to_file", dump_to_stderr)
251     def test_trailing_comma_optional_parens_stability2(self) -> None:
252         source, _expected = read_data("trailing_comma_optional_parens2")
253         actual = fs(source)
254         black.assert_stable(source, actual, DEFAULT_MODE)
255
256     @unittest.expectedFailure
257     @patch("black.dump_to_file", dump_to_stderr)
258     def test_trailing_comma_optional_parens_stability3(self) -> None:
259         source, _expected = read_data("trailing_comma_optional_parens3")
260         actual = fs(source)
261         black.assert_stable(source, actual, DEFAULT_MODE)
262
263     @patch("black.dump_to_file", dump_to_stderr)
264     def test_trailing_comma_optional_parens_stability1_pass2(self) -> None:
265         source, _expected = read_data("trailing_comma_optional_parens1")
266         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
267         black.assert_stable(source, actual, DEFAULT_MODE)
268
269     @patch("black.dump_to_file", dump_to_stderr)
270     def test_trailing_comma_optional_parens_stability2_pass2(self) -> None:
271         source, _expected = read_data("trailing_comma_optional_parens2")
272         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
273         black.assert_stable(source, actual, DEFAULT_MODE)
274
275     @patch("black.dump_to_file", dump_to_stderr)
276     def test_trailing_comma_optional_parens_stability3_pass2(self) -> None:
277         source, _expected = read_data("trailing_comma_optional_parens3")
278         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
279         black.assert_stable(source, actual, DEFAULT_MODE)
280
281     @patch("black.dump_to_file", dump_to_stderr)
282     def test_pep_572(self) -> None:
283         source, expected = read_data("pep_572")
284         actual = fs(source)
285         self.assertFormatEqual(expected, actual)
286         black.assert_stable(source, actual, DEFAULT_MODE)
287         if sys.version_info >= (3, 8):
288             black.assert_equivalent(source, actual)
289
290     @patch("black.dump_to_file", dump_to_stderr)
291     def test_pep_572_remove_parens(self) -> None:
292         source, expected = read_data("pep_572_remove_parens")
293         actual = fs(source)
294         self.assertFormatEqual(expected, actual)
295         black.assert_stable(source, actual, DEFAULT_MODE)
296         if sys.version_info >= (3, 8):
297             black.assert_equivalent(source, actual)
298
299     @patch("black.dump_to_file", dump_to_stderr)
300     def test_pep_572_do_not_remove_parens(self) -> None:
301         source, expected = read_data("pep_572_do_not_remove_parens")
302         # the AST safety checks will fail, but that's expected, just make sure no
303         # parentheses are touched
304         actual = black.format_str(source, mode=DEFAULT_MODE)
305         self.assertFormatEqual(expected, actual)
306
307     def test_pep_572_version_detection(self) -> None:
308         source, _ = read_data("pep_572")
309         root = black.lib2to3_parse(source)
310         features = black.get_features_used(root)
311         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
312         versions = black.detect_target_versions(root)
313         self.assertIn(black.TargetVersion.PY38, versions)
314
315     def test_expression_ff(self) -> None:
316         source, expected = read_data("expression")
317         tmp_file = Path(black.dump_to_file(source))
318         try:
319             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
320             with open(tmp_file, encoding="utf8") as f:
321                 actual = f.read()
322         finally:
323             os.unlink(tmp_file)
324         self.assertFormatEqual(expected, actual)
325         with patch("black.dump_to_file", dump_to_stderr):
326             black.assert_equivalent(source, actual)
327             black.assert_stable(source, actual, DEFAULT_MODE)
328
329     def test_expression_diff(self) -> None:
330         source, _ = read_data("expression.py")
331         config = THIS_DIR / "data" / "empty_pyproject.toml"
332         expected, _ = read_data("expression.diff")
333         tmp_file = Path(black.dump_to_file(source))
334         diff_header = re.compile(
335             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
336             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
337         )
338         try:
339             result = BlackRunner().invoke(
340                 black.main, ["--diff", str(tmp_file), f"--config={config}"]
341             )
342             self.assertEqual(result.exit_code, 0)
343         finally:
344             os.unlink(tmp_file)
345         actual = result.output
346         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
347         if expected != actual:
348             dump = black.dump_to_file(actual)
349             msg = (
350                 "Expected diff isn't equal to the actual. If you made changes to"
351                 " expression.py and this is an anticipated difference, overwrite"
352                 f" tests/data/expression.diff with {dump}"
353             )
354             self.assertEqual(expected, actual, msg)
355
356     def test_expression_diff_with_color(self) -> None:
357         source, _ = read_data("expression.py")
358         config = THIS_DIR / "data" / "empty_pyproject.toml"
359         expected, _ = read_data("expression.diff")
360         tmp_file = Path(black.dump_to_file(source))
361         try:
362             result = BlackRunner().invoke(
363                 black.main, ["--diff", "--color", str(tmp_file), f"--config={config}"]
364             )
365         finally:
366             os.unlink(tmp_file)
367         actual = result.output
368         # We check the contents of the diff in `test_expression_diff`. All
369         # we need to check here is that color codes exist in the result.
370         self.assertIn("\033[1;37m", actual)
371         self.assertIn("\033[36m", actual)
372         self.assertIn("\033[32m", actual)
373         self.assertIn("\033[31m", actual)
374         self.assertIn("\033[0m", actual)
375
376     @patch("black.dump_to_file", dump_to_stderr)
377     def test_pep_570(self) -> None:
378         source, expected = read_data("pep_570")
379         actual = fs(source)
380         self.assertFormatEqual(expected, actual)
381         black.assert_stable(source, actual, DEFAULT_MODE)
382         if sys.version_info >= (3, 8):
383             black.assert_equivalent(source, actual)
384
385     def test_detect_pos_only_arguments(self) -> None:
386         source, _ = read_data("pep_570")
387         root = black.lib2to3_parse(source)
388         features = black.get_features_used(root)
389         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
390         versions = black.detect_target_versions(root)
391         self.assertIn(black.TargetVersion.PY38, versions)
392
393     @patch("black.dump_to_file", dump_to_stderr)
394     def test_string_quotes(self) -> None:
395         source, expected = read_data("string_quotes")
396         mode = black.Mode(experimental_string_processing=True)
397         actual = fs(source, mode=mode)
398         self.assertFormatEqual(expected, actual)
399         black.assert_equivalent(source, actual)
400         black.assert_stable(source, actual, mode)
401         mode = replace(mode, string_normalization=False)
402         not_normalized = fs(source, mode=mode)
403         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
404         black.assert_equivalent(source, not_normalized)
405         black.assert_stable(source, not_normalized, mode=mode)
406
407     @patch("black.dump_to_file", dump_to_stderr)
408     def test_docstring_no_string_normalization(self) -> None:
409         """Like test_docstring but with string normalization off."""
410         source, expected = read_data("docstring_no_string_normalization")
411         mode = replace(DEFAULT_MODE, string_normalization=False)
412         actual = fs(source, mode=mode)
413         self.assertFormatEqual(expected, actual)
414         black.assert_equivalent(source, actual)
415         black.assert_stable(source, actual, mode)
416
417     def test_long_strings_flag_disabled(self) -> None:
418         """Tests for turning off the string processing logic."""
419         source, expected = read_data("long_strings_flag_disabled")
420         mode = replace(DEFAULT_MODE, experimental_string_processing=False)
421         actual = fs(source, mode=mode)
422         self.assertFormatEqual(expected, actual)
423         black.assert_stable(expected, actual, mode)
424
425     @patch("black.dump_to_file", dump_to_stderr)
426     def test_numeric_literals(self) -> None:
427         source, expected = read_data("numeric_literals")
428         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
429         actual = fs(source, mode=mode)
430         self.assertFormatEqual(expected, actual)
431         black.assert_equivalent(source, actual)
432         black.assert_stable(source, actual, mode)
433
434     @patch("black.dump_to_file", dump_to_stderr)
435     def test_numeric_literals_ignoring_underscores(self) -> None:
436         source, expected = read_data("numeric_literals_skip_underscores")
437         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
438         actual = fs(source, mode=mode)
439         self.assertFormatEqual(expected, actual)
440         black.assert_equivalent(source, actual)
441         black.assert_stable(source, actual, mode)
442
443     def test_skip_magic_trailing_comma(self) -> None:
444         source, _ = read_data("expression.py")
445         expected, _ = read_data("expression_skip_magic_trailing_comma.diff")
446         tmp_file = Path(black.dump_to_file(source))
447         diff_header = re.compile(
448             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
449             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
450         )
451         try:
452             result = BlackRunner().invoke(black.main, ["-C", "--diff", str(tmp_file)])
453             self.assertEqual(result.exit_code, 0)
454         finally:
455             os.unlink(tmp_file)
456         actual = result.output
457         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
458         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
459         if expected != actual:
460             dump = black.dump_to_file(actual)
461             msg = (
462                 "Expected diff isn't equal to the actual. If you made changes to"
463                 " expression.py and this is an anticipated difference, overwrite"
464                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
465             )
466             self.assertEqual(expected, actual, msg)
467
468     @pytest.mark.no_python2
469     def test_python2_should_fail_without_optional_install(self) -> None:
470         if sys.version_info < (3, 8):
471             self.skipTest(
472                 "Python 3.6 and 3.7 will install typed-ast to work and as such will be"
473                 " able to parse Python 2 syntax without explicitly specifying the"
474                 " python2 extra"
475             )
476
477         source = "x = 1234l"
478         tmp_file = Path(black.dump_to_file(source))
479         try:
480             runner = BlackRunner()
481             result = runner.invoke(black.main, [str(tmp_file)])
482             self.assertEqual(result.exit_code, 123)
483         finally:
484             os.unlink(tmp_file)
485         actual = (
486             runner.stderr_bytes.decode()
487             .replace("\n", "")
488             .replace("\\n", "")
489             .replace("\\r", "")
490             .replace("\r", "")
491         )
492         msg = (
493             "The requested source code has invalid Python 3 syntax."
494             "If you are trying to format Python 2 files please reinstall Black"
495             " with the 'python2' extra: `python3 -m pip install black[python2]`."
496         )
497         self.assertIn(msg, actual)
498
499     @pytest.mark.python2
500     @patch("black.dump_to_file", dump_to_stderr)
501     def test_python2_print_function(self) -> None:
502         source, expected = read_data("python2_print_function")
503         mode = replace(DEFAULT_MODE, target_versions={TargetVersion.PY27})
504         actual = fs(source, mode=mode)
505         self.assertFormatEqual(expected, actual)
506         black.assert_equivalent(source, actual)
507         black.assert_stable(source, actual, mode)
508
509     @patch("black.dump_to_file", dump_to_stderr)
510     def test_stub(self) -> None:
511         mode = replace(DEFAULT_MODE, is_pyi=True)
512         source, expected = read_data("stub.pyi")
513         actual = fs(source, mode=mode)
514         self.assertFormatEqual(expected, actual)
515         black.assert_stable(source, actual, mode)
516
517     @patch("black.dump_to_file", dump_to_stderr)
518     def test_async_as_identifier(self) -> None:
519         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
520         source, expected = read_data("async_as_identifier")
521         actual = fs(source)
522         self.assertFormatEqual(expected, actual)
523         major, minor = sys.version_info[:2]
524         if major < 3 or (major <= 3 and minor < 7):
525             black.assert_equivalent(source, actual)
526         black.assert_stable(source, actual, DEFAULT_MODE)
527         # ensure black can parse this when the target is 3.6
528         self.invokeBlack([str(source_path), "--target-version", "py36"])
529         # but not on 3.7, because async/await is no longer an identifier
530         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
531
532     @patch("black.dump_to_file", dump_to_stderr)
533     def test_python37(self) -> None:
534         source_path = (THIS_DIR / "data" / "python37.py").resolve()
535         source, expected = read_data("python37")
536         actual = fs(source)
537         self.assertFormatEqual(expected, actual)
538         major, minor = sys.version_info[:2]
539         if major > 3 or (major == 3 and minor >= 7):
540             black.assert_equivalent(source, actual)
541         black.assert_stable(source, actual, DEFAULT_MODE)
542         # ensure black can parse this when the target is 3.7
543         self.invokeBlack([str(source_path), "--target-version", "py37"])
544         # but not on 3.6, because we use async as a reserved keyword
545         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
546
547     @patch("black.dump_to_file", dump_to_stderr)
548     def test_python38(self) -> None:
549         source, expected = read_data("python38")
550         actual = fs(source)
551         self.assertFormatEqual(expected, actual)
552         major, minor = sys.version_info[:2]
553         if major > 3 or (major == 3 and minor >= 8):
554             black.assert_equivalent(source, actual)
555         black.assert_stable(source, actual, DEFAULT_MODE)
556
557     @patch("black.dump_to_file", dump_to_stderr)
558     def test_python39(self) -> None:
559         source, expected = read_data("python39")
560         actual = fs(source)
561         self.assertFormatEqual(expected, actual)
562         major, minor = sys.version_info[:2]
563         if major > 3 or (major == 3 and minor >= 9):
564             black.assert_equivalent(source, actual)
565         black.assert_stable(source, actual, DEFAULT_MODE)
566
567     def test_tab_comment_indentation(self) -> None:
568         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
569         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
570         self.assertFormatEqual(contents_spc, fs(contents_spc))
571         self.assertFormatEqual(contents_spc, fs(contents_tab))
572
573         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
574         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
575         self.assertFormatEqual(contents_spc, fs(contents_spc))
576         self.assertFormatEqual(contents_spc, fs(contents_tab))
577
578         # mixed tabs and spaces (valid Python 2 code)
579         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
580         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
581         self.assertFormatEqual(contents_spc, fs(contents_spc))
582         self.assertFormatEqual(contents_spc, fs(contents_tab))
583
584         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
585         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
586         self.assertFormatEqual(contents_spc, fs(contents_spc))
587         self.assertFormatEqual(contents_spc, fs(contents_tab))
588
589     def test_report_verbose(self) -> None:
590         report = Report(verbose=True)
591         out_lines = []
592         err_lines = []
593
594         def out(msg: str, **kwargs: Any) -> None:
595             out_lines.append(msg)
596
597         def err(msg: str, **kwargs: Any) -> None:
598             err_lines.append(msg)
599
600         with patch("black.output._out", out), patch("black.output._err", err):
601             report.done(Path("f1"), black.Changed.NO)
602             self.assertEqual(len(out_lines), 1)
603             self.assertEqual(len(err_lines), 0)
604             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
605             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
606             self.assertEqual(report.return_code, 0)
607             report.done(Path("f2"), black.Changed.YES)
608             self.assertEqual(len(out_lines), 2)
609             self.assertEqual(len(err_lines), 0)
610             self.assertEqual(out_lines[-1], "reformatted f2")
611             self.assertEqual(
612                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
613             )
614             report.done(Path("f3"), black.Changed.CACHED)
615             self.assertEqual(len(out_lines), 3)
616             self.assertEqual(len(err_lines), 0)
617             self.assertEqual(
618                 out_lines[-1], "f3 wasn't modified on disk since last run."
619             )
620             self.assertEqual(
621                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
622             )
623             self.assertEqual(report.return_code, 0)
624             report.check = True
625             self.assertEqual(report.return_code, 1)
626             report.check = False
627             report.failed(Path("e1"), "boom")
628             self.assertEqual(len(out_lines), 3)
629             self.assertEqual(len(err_lines), 1)
630             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
631             self.assertEqual(
632                 unstyle(str(report)),
633                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
634                 " reformat.",
635             )
636             self.assertEqual(report.return_code, 123)
637             report.done(Path("f3"), black.Changed.YES)
638             self.assertEqual(len(out_lines), 4)
639             self.assertEqual(len(err_lines), 1)
640             self.assertEqual(out_lines[-1], "reformatted f3")
641             self.assertEqual(
642                 unstyle(str(report)),
643                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
644                 " reformat.",
645             )
646             self.assertEqual(report.return_code, 123)
647             report.failed(Path("e2"), "boom")
648             self.assertEqual(len(out_lines), 4)
649             self.assertEqual(len(err_lines), 2)
650             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
651             self.assertEqual(
652                 unstyle(str(report)),
653                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
654                 " reformat.",
655             )
656             self.assertEqual(report.return_code, 123)
657             report.path_ignored(Path("wat"), "no match")
658             self.assertEqual(len(out_lines), 5)
659             self.assertEqual(len(err_lines), 2)
660             self.assertEqual(out_lines[-1], "wat ignored: no match")
661             self.assertEqual(
662                 unstyle(str(report)),
663                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
664                 " reformat.",
665             )
666             self.assertEqual(report.return_code, 123)
667             report.done(Path("f4"), black.Changed.NO)
668             self.assertEqual(len(out_lines), 6)
669             self.assertEqual(len(err_lines), 2)
670             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
671             self.assertEqual(
672                 unstyle(str(report)),
673                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
674                 " reformat.",
675             )
676             self.assertEqual(report.return_code, 123)
677             report.check = True
678             self.assertEqual(
679                 unstyle(str(report)),
680                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
681                 " would fail to reformat.",
682             )
683             report.check = False
684             report.diff = True
685             self.assertEqual(
686                 unstyle(str(report)),
687                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
688                 " would fail to reformat.",
689             )
690
    def test_report_quiet(self) -> None:
        """Check that a quiet Report only emits failure messages.

        Output routed through ``black.output._out``/``_err`` is captured. With
        ``quiet=True`` nothing is written to ``_out`` for reformatted, unchanged,
        cached or ignored files, while failures are still written to ``_err``.
        The summary string and ``return_code`` are checked after every event,
        including the conditional wording used with ``check``/``diff``.
        """
        report = Report(quiet=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        # Collect messages instead of printing them.
        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # CACHED files are counted as "left unchanged".
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode a reformatted file turns the return code into 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures are reported on _err even in quiet mode and force 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # f3 is reported again; only the "reformatted" counter changes.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths leave both the output and the summary untouched.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff switch the summary to "would be ..." wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
784
    def test_report_normal(self) -> None:
        """Check messages and summaries of a Report built with default arguments.

        Output routed through ``black.output._out``/``_err`` is captured.
        Reformatted files are announced on ``_out`` as ``reformatted <path>``.
        Failures go to ``_err``. Unchanged, cached and ignored files produce no
        message. The summary string and ``return_code`` are checked after every
        event, including the conditional wording used with ``check``/``diff``.
        """
        report = black.Report()
        out_lines: List[str] = []
        err_lines: List[str] = []

        # Collect messages instead of printing them.
        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # CACHED counts as "left unchanged" and prints nothing new.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode a reformatted file turns the return code into 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure is written to _err and forces return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # f3 is reported again; only the "reformatted" counter changes.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths leave both the output and the summary untouched.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff switch the summary to "would be ..." wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
881
882     def test_lib2to3_parse(self) -> None:
883         with self.assertRaises(black.InvalidInput):
884             black.lib2to3_parse("invalid syntax")
885
886         straddling = "x + y"
887         black.lib2to3_parse(straddling)
888         black.lib2to3_parse(straddling, {TargetVersion.PY27})
889         black.lib2to3_parse(straddling, {TargetVersion.PY36})
890         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
891
892         py2_only = "print x"
893         black.lib2to3_parse(py2_only)
894         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
895         with self.assertRaises(black.InvalidInput):
896             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
897         with self.assertRaises(black.InvalidInput):
898             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
899
900         py3_only = "exec(x, end=y)"
901         black.lib2to3_parse(py3_only)
902         with self.assertRaises(black.InvalidInput):
903             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
904         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
905         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
906
907     def test_get_features_used_decorator(self) -> None:
908         # Test the feature detection of new decorator syntax
909         # since this makes some test cases of test_get_features_used()
910         # fails if it fails, this is tested first so that a useful case
911         # is identified
912         simples, relaxed = read_data("decorators")
913         # skip explanation comments at the top of the file
914         for simple_test in simples.split("##")[1:]:
915             node = black.lib2to3_parse(simple_test)
916             decorator = str(node.children[0].children[0]).strip()
917             self.assertNotIn(
918                 Feature.RELAXED_DECORATORS,
919                 black.get_features_used(node),
920                 msg=(
921                     f"decorator '{decorator}' follows python<=3.8 syntax"
922                     "but is detected as 3.9+"
923                     # f"The full node is\n{node!r}"
924                 ),
925             )
926         # skip the '# output' comment at the top of the output part
927         for relaxed_test in relaxed.split("##")[1:]:
928             node = black.lib2to3_parse(relaxed_test)
929             decorator = str(node.children[0].children[0]).strip()
930             self.assertIn(
931                 Feature.RELAXED_DECORATORS,
932                 black.get_features_used(node),
933                 msg=(
934                     f"decorator '{decorator}' uses python3.9+ syntax"
935                     "but is detected as python<=3.8"
936                     # f"The full node is\n{node!r}"
937                 ),
938             )
939
940     def test_get_features_used(self) -> None:
941         node = black.lib2to3_parse("def f(*, arg): ...\n")
942         self.assertEqual(black.get_features_used(node), set())
943         node = black.lib2to3_parse("def f(*, arg,): ...\n")
944         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
945         node = black.lib2to3_parse("f(*arg,)\n")
946         self.assertEqual(
947             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
948         )
949         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
950         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
951         node = black.lib2to3_parse("123_456\n")
952         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
953         node = black.lib2to3_parse("123456\n")
954         self.assertEqual(black.get_features_used(node), set())
955         source, expected = read_data("function")
956         node = black.lib2to3_parse(source)
957         expected_features = {
958             Feature.TRAILING_COMMA_IN_CALL,
959             Feature.TRAILING_COMMA_IN_DEF,
960             Feature.F_STRINGS,
961         }
962         self.assertEqual(black.get_features_used(node), expected_features)
963         node = black.lib2to3_parse(expected)
964         self.assertEqual(black.get_features_used(node), expected_features)
965         source, expected = read_data("expression")
966         node = black.lib2to3_parse(source)
967         self.assertEqual(black.get_features_used(node), set())
968         node = black.lib2to3_parse(expected)
969         self.assertEqual(black.get_features_used(node), set())
970
971     def test_get_future_imports(self) -> None:
972         node = black.lib2to3_parse("\n")
973         self.assertEqual(set(), black.get_future_imports(node))
974         node = black.lib2to3_parse("from __future__ import black\n")
975         self.assertEqual({"black"}, black.get_future_imports(node))
976         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
977         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
978         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
979         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
980         node = black.lib2to3_parse(
981             "from __future__ import multiple\nfrom __future__ import imports\n"
982         )
983         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
984         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
985         self.assertEqual({"black"}, black.get_future_imports(node))
986         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
987         self.assertEqual({"black"}, black.get_future_imports(node))
988         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
989         self.assertEqual(set(), black.get_future_imports(node))
990         node = black.lib2to3_parse("from some.module import black\n")
991         self.assertEqual(set(), black.get_future_imports(node))
992         node = black.lib2to3_parse(
993             "from __future__ import unicode_literals as _unicode_literals"
994         )
995         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
996         node = black.lib2to3_parse(
997             "from __future__ import unicode_literals as _lol, print"
998         )
999         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
1000
1001     def test_debug_visitor(self) -> None:
1002         source, _ = read_data("debug_visitor.py")
1003         expected, _ = read_data("debug_visitor.out")
1004         out_lines = []
1005         err_lines = []
1006
1007         def out(msg: str, **kwargs: Any) -> None:
1008             out_lines.append(msg)
1009
1010         def err(msg: str, **kwargs: Any) -> None:
1011             err_lines.append(msg)
1012
1013         with patch("black.debug.out", out):
1014             DebugVisitor.show(source)
1015         actual = "\n".join(out_lines) + "\n"
1016         log_name = ""
1017         if expected != actual:
1018             log_name = black.dump_to_file(*out_lines)
1019         self.assertEqual(
1020             expected,
1021             actual,
1022             f"AST print out is different. Actual version dumped to {log_name}",
1023         )
1024
1025     def test_format_file_contents(self) -> None:
1026         empty = ""
1027         mode = DEFAULT_MODE
1028         with self.assertRaises(black.NothingChanged):
1029             black.format_file_contents(empty, mode=mode, fast=False)
1030         just_nl = "\n"
1031         with self.assertRaises(black.NothingChanged):
1032             black.format_file_contents(just_nl, mode=mode, fast=False)
1033         same = "j = [1, 2, 3]\n"
1034         with self.assertRaises(black.NothingChanged):
1035             black.format_file_contents(same, mode=mode, fast=False)
1036         different = "j = [1,2,3]"
1037         expected = same
1038         actual = black.format_file_contents(different, mode=mode, fast=False)
1039         self.assertEqual(expected, actual)
1040         invalid = "return if you can"
1041         with self.assertRaises(black.InvalidInput) as e:
1042             black.format_file_contents(invalid, mode=mode, fast=False)
1043         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1044
1045     def test_endmarker(self) -> None:
1046         n = black.lib2to3_parse("\n")
1047         self.assertEqual(n.type, black.syms.file_input)
1048         self.assertEqual(len(n.children), 1)
1049         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1050
1051     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1052     def test_assertFormatEqual(self) -> None:
1053         out_lines = []
1054         err_lines = []
1055
1056         def out(msg: str, **kwargs: Any) -> None:
1057             out_lines.append(msg)
1058
1059         def err(msg: str, **kwargs: Any) -> None:
1060             err_lines.append(msg)
1061
1062         with patch("black.output._out", out), patch("black.output._err", err):
1063             with self.assertRaises(AssertionError):
1064                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1065
1066         out_str = "".join(out_lines)
1067         self.assertTrue("Expected tree:" in out_str)
1068         self.assertTrue("Actual tree:" in out_str)
1069         self.assertEqual("".join(err_lines), "")
1070
1071     def test_cache_broken_file(self) -> None:
1072         mode = DEFAULT_MODE
1073         with cache_dir() as workspace:
1074             cache_file = get_cache_file(mode)
1075             with cache_file.open("w") as fobj:
1076                 fobj.write("this is not a pickle")
1077             self.assertEqual(black.read_cache(mode), {})
1078             src = (workspace / "test.py").resolve()
1079             with src.open("w") as fobj:
1080                 fobj.write("print('hello')")
1081             self.invokeBlack([str(src)])
1082             cache = black.read_cache(mode)
1083             self.assertIn(str(src), cache)
1084
1085     def test_cache_single_file_already_cached(self) -> None:
1086         mode = DEFAULT_MODE
1087         with cache_dir() as workspace:
1088             src = (workspace / "test.py").resolve()
1089             with src.open("w") as fobj:
1090                 fobj.write("print('hello')")
1091             black.write_cache({}, [src], mode)
1092             self.invokeBlack([str(src)])
1093             with src.open("r") as fobj:
1094                 self.assertEqual(fobj.read(), "print('hello')")
1095
1096     @event_loop()
1097     def test_cache_multiple_files(self) -> None:
1098         mode = DEFAULT_MODE
1099         with cache_dir() as workspace, patch(
1100             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1101         ):
1102             one = (workspace / "one.py").resolve()
1103             with one.open("w") as fobj:
1104                 fobj.write("print('hello')")
1105             two = (workspace / "two.py").resolve()
1106             with two.open("w") as fobj:
1107                 fobj.write("print('hello')")
1108             black.write_cache({}, [one], mode)
1109             self.invokeBlack([str(workspace)])
1110             with one.open("r") as fobj:
1111                 self.assertEqual(fobj.read(), "print('hello')")
1112             with two.open("r") as fobj:
1113                 self.assertEqual(fobj.read(), 'print("hello")\n')
1114             cache = black.read_cache(mode)
1115             self.assertIn(str(one), cache)
1116             self.assertIn(str(two), cache)
1117
1118     def test_no_cache_when_writeback_diff(self) -> None:
1119         mode = DEFAULT_MODE
1120         with cache_dir() as workspace:
1121             src = (workspace / "test.py").resolve()
1122             with src.open("w") as fobj:
1123                 fobj.write("print('hello')")
1124             with patch("black.read_cache") as read_cache, patch(
1125                 "black.write_cache"
1126             ) as write_cache:
1127                 self.invokeBlack([str(src), "--diff"])
1128                 cache_file = get_cache_file(mode)
1129                 self.assertFalse(cache_file.exists())
1130                 write_cache.assert_not_called()
1131                 read_cache.assert_not_called()
1132
1133     def test_no_cache_when_writeback_color_diff(self) -> None:
1134         mode = DEFAULT_MODE
1135         with cache_dir() as workspace:
1136             src = (workspace / "test.py").resolve()
1137             with src.open("w") as fobj:
1138                 fobj.write("print('hello')")
1139             with patch("black.read_cache") as read_cache, patch(
1140                 "black.write_cache"
1141             ) as write_cache:
1142                 self.invokeBlack([str(src), "--diff", "--color"])
1143                 cache_file = get_cache_file(mode)
1144                 self.assertFalse(cache_file.exists())
1145                 write_cache.assert_not_called()
1146                 read_cache.assert_not_called()
1147
1148     @event_loop()
1149     def test_output_locking_when_writeback_diff(self) -> None:
1150         with cache_dir() as workspace:
1151             for tag in range(0, 4):
1152                 src = (workspace / f"test{tag}.py").resolve()
1153                 with src.open("w") as fobj:
1154                     fobj.write("print('hello')")
1155             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1156                 self.invokeBlack(["--diff", str(workspace)], exit_code=0)
1157                 # this isn't quite doing what we want, but if it _isn't_
1158                 # called then we cannot be using the lock it provides
1159                 mgr.assert_called()
1160
1161     @event_loop()
1162     def test_output_locking_when_writeback_color_diff(self) -> None:
1163         with cache_dir() as workspace:
1164             for tag in range(0, 4):
1165                 src = (workspace / f"test{tag}.py").resolve()
1166                 with src.open("w") as fobj:
1167                     fobj.write("print('hello')")
1168             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1169                 self.invokeBlack(["--diff", "--color", str(workspace)], exit_code=0)
1170                 # this isn't quite doing what we want, but if it _isn't_
1171                 # called then we cannot be using the lock it provides
1172                 mgr.assert_called()
1173
1174     def test_no_cache_when_stdin(self) -> None:
1175         mode = DEFAULT_MODE
1176         with cache_dir():
1177             result = CliRunner().invoke(
1178                 black.main, ["-"], input=BytesIO(b"print('hello')")
1179             )
1180             self.assertEqual(result.exit_code, 0)
1181             cache_file = get_cache_file(mode)
1182             self.assertFalse(cache_file.exists())
1183
1184     def test_read_cache_no_cachefile(self) -> None:
1185         mode = DEFAULT_MODE
1186         with cache_dir():
1187             self.assertEqual(black.read_cache(mode), {})
1188
1189     def test_write_cache_read_cache(self) -> None:
1190         mode = DEFAULT_MODE
1191         with cache_dir() as workspace:
1192             src = (workspace / "test.py").resolve()
1193             src.touch()
1194             black.write_cache({}, [src], mode)
1195             cache = black.read_cache(mode)
1196             self.assertIn(str(src), cache)
1197             self.assertEqual(cache[str(src)], black.get_cache_info(src))
1198
1199     def test_filter_cached(self) -> None:
1200         with TemporaryDirectory() as workspace:
1201             path = Path(workspace)
1202             uncached = (path / "uncached").resolve()
1203             cached = (path / "cached").resolve()
1204             cached_but_changed = (path / "changed").resolve()
1205             uncached.touch()
1206             cached.touch()
1207             cached_but_changed.touch()
1208             cache = {
1209                 str(cached): black.get_cache_info(cached),
1210                 str(cached_but_changed): (0.0, 0),
1211             }
1212             todo, done = black.filter_cached(
1213                 cache, {uncached, cached, cached_but_changed}
1214             )
1215             self.assertEqual(todo, {uncached, cached_but_changed})
1216             self.assertEqual(done, {cached})
1217
1218     def test_write_cache_creates_directory_if_needed(self) -> None:
1219         mode = DEFAULT_MODE
1220         with cache_dir(exists=False) as workspace:
1221             self.assertFalse(workspace.exists())
1222             black.write_cache({}, [], mode)
1223             self.assertTrue(workspace.exists())
1224
1225     @event_loop()
1226     def test_failed_formatting_does_not_get_cached(self) -> None:
1227         mode = DEFAULT_MODE
1228         with cache_dir() as workspace, patch(
1229             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1230         ):
1231             failing = (workspace / "failing.py").resolve()
1232             with failing.open("w") as fobj:
1233                 fobj.write("not actually python")
1234             clean = (workspace / "clean.py").resolve()
1235             with clean.open("w") as fobj:
1236                 fobj.write('print("hello")\n')
1237             self.invokeBlack([str(workspace)], exit_code=123)
1238             cache = black.read_cache(mode)
1239             self.assertNotIn(str(failing), cache)
1240             self.assertIn(str(clean), cache)
1241
1242     def test_write_cache_write_fail(self) -> None:
1243         mode = DEFAULT_MODE
1244         with cache_dir(), patch.object(Path, "open") as mock:
1245             mock.side_effect = OSError
1246             black.write_cache({}, [], mode)
1247
1248     @event_loop()
1249     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
1250     def test_works_in_mono_process_only_environment(self) -> None:
1251         with cache_dir() as workspace:
1252             for f in [
1253                 (workspace / "one.py").resolve(),
1254                 (workspace / "two.py").resolve(),
1255             ]:
1256                 f.write_text('print("hello")\n')
1257             self.invokeBlack([str(workspace)])
1258
1259     @event_loop()
1260     def test_check_diff_use_together(self) -> None:
1261         with cache_dir():
1262             # Files which will be reformatted.
1263             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1264             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1265             # Files which will not be reformatted.
1266             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1267             self.invokeBlack([str(src2), "--diff", "--check"])
1268             # Multi file command.
1269             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1270
1271     def test_no_files(self) -> None:
1272         with cache_dir():
1273             # Without an argument, black exits with error code 0.
1274             self.invokeBlack([])
1275
1276     def test_broken_symlink(self) -> None:
1277         with cache_dir() as workspace:
1278             symlink = workspace / "broken_link.py"
1279             try:
1280                 symlink.symlink_to("nonexistent.py")
1281             except OSError as e:
1282                 self.skipTest(f"Can't create symlinks: {e}")
1283             self.invokeBlack([str(workspace.resolve())])
1284
1285     def test_read_cache_line_lengths(self) -> None:
1286         mode = DEFAULT_MODE
1287         short_mode = replace(DEFAULT_MODE, line_length=1)
1288         with cache_dir() as workspace:
1289             path = (workspace / "file.py").resolve()
1290             path.touch()
1291             black.write_cache({}, [path], mode)
1292             one = black.read_cache(mode)
1293             self.assertIn(str(path), one)
1294             two = black.read_cache(short_mode)
1295             self.assertNotIn(str(path), two)
1296
1297     def test_single_file_force_pyi(self) -> None:
1298         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1299         contents, expected = read_data("force_pyi")
1300         with cache_dir() as workspace:
1301             path = (workspace / "file.py").resolve()
1302             with open(path, "w") as fh:
1303                 fh.write(contents)
1304             self.invokeBlack([str(path), "--pyi"])
1305             with open(path, "r") as fh:
1306                 actual = fh.read()
1307             # verify cache with --pyi is separate
1308             pyi_cache = black.read_cache(pyi_mode)
1309             self.assertIn(str(path), pyi_cache)
1310             normal_cache = black.read_cache(DEFAULT_MODE)
1311             self.assertNotIn(str(path), normal_cache)
1312         self.assertFormatEqual(expected, actual)
1313         black.assert_equivalent(contents, actual)
1314         black.assert_stable(contents, actual, pyi_mode)
1315
1316     @event_loop()
1317     def test_multi_file_force_pyi(self) -> None:
1318         reg_mode = DEFAULT_MODE
1319         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1320         contents, expected = read_data("force_pyi")
1321         with cache_dir() as workspace:
1322             paths = [
1323                 (workspace / "file1.py").resolve(),
1324                 (workspace / "file2.py").resolve(),
1325             ]
1326             for path in paths:
1327                 with open(path, "w") as fh:
1328                     fh.write(contents)
1329             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1330             for path in paths:
1331                 with open(path, "r") as fh:
1332                     actual = fh.read()
1333                 self.assertEqual(actual, expected)
1334             # verify cache with --pyi is separate
1335             pyi_cache = black.read_cache(pyi_mode)
1336             normal_cache = black.read_cache(reg_mode)
1337             for path in paths:
1338                 self.assertIn(str(path), pyi_cache)
1339                 self.assertNotIn(str(path), normal_cache)
1340
1341     def test_pipe_force_pyi(self) -> None:
1342         source, expected = read_data("force_pyi")
1343         result = CliRunner().invoke(
1344             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1345         )
1346         self.assertEqual(result.exit_code, 0)
1347         actual = result.output
1348         self.assertFormatEqual(actual, expected)
1349
1350     def test_single_file_force_py36(self) -> None:
1351         reg_mode = DEFAULT_MODE
1352         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1353         source, expected = read_data("force_py36")
1354         with cache_dir() as workspace:
1355             path = (workspace / "file.py").resolve()
1356             with open(path, "w") as fh:
1357                 fh.write(source)
1358             self.invokeBlack([str(path), *PY36_ARGS])
1359             with open(path, "r") as fh:
1360                 actual = fh.read()
1361             # verify cache with --target-version is separate
1362             py36_cache = black.read_cache(py36_mode)
1363             self.assertIn(str(path), py36_cache)
1364             normal_cache = black.read_cache(reg_mode)
1365             self.assertNotIn(str(path), normal_cache)
1366         self.assertEqual(actual, expected)
1367
1368     @event_loop()
1369     def test_multi_file_force_py36(self) -> None:
1370         reg_mode = DEFAULT_MODE
1371         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1372         source, expected = read_data("force_py36")
1373         with cache_dir() as workspace:
1374             paths = [
1375                 (workspace / "file1.py").resolve(),
1376                 (workspace / "file2.py").resolve(),
1377             ]
1378             for path in paths:
1379                 with open(path, "w") as fh:
1380                     fh.write(source)
1381             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1382             for path in paths:
1383                 with open(path, "r") as fh:
1384                     actual = fh.read()
1385                 self.assertEqual(actual, expected)
1386             # verify cache with --target-version is separate
1387             pyi_cache = black.read_cache(py36_mode)
1388             normal_cache = black.read_cache(reg_mode)
1389             for path in paths:
1390                 self.assertIn(str(path), pyi_cache)
1391                 self.assertNotIn(str(path), normal_cache)
1392
1393     def test_pipe_force_py36(self) -> None:
1394         source, expected = read_data("force_py36")
1395         result = CliRunner().invoke(
1396             black.main,
1397             ["-", "-q", "--target-version=py36"],
1398             input=BytesIO(source.encode("utf8")),
1399         )
1400         self.assertEqual(result.exit_code, 0)
1401         actual = result.output
1402         self.assertFormatEqual(actual, expected)
1403
1404     def test_include_exclude(self) -> None:
1405         path = THIS_DIR / "data" / "include_exclude_tests"
1406         include = re.compile(r"\.pyi?$")
1407         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1408         report = black.Report()
1409         gitignore = PathSpec.from_lines("gitwildmatch", [])
1410         sources: List[Path] = []
1411         expected = [
1412             Path(path / "b/dont_exclude/a.py"),
1413             Path(path / "b/dont_exclude/a.pyi"),
1414         ]
1415         this_abs = THIS_DIR.resolve()
1416         sources.extend(
1417             black.gen_python_files(
1418                 path.iterdir(),
1419                 this_abs,
1420                 include,
1421                 exclude,
1422                 None,
1423                 None,
1424                 report,
1425                 gitignore,
1426             )
1427         )
1428         self.assertEqual(sorted(expected), sorted(sources))
1429
1430     def test_gitingore_used_as_default(self) -> None:
1431         path = Path(THIS_DIR / "data" / "include_exclude_tests")
1432         include = re.compile(r"\.pyi?$")
1433         extend_exclude = re.compile(r"/exclude/")
1434         src = str(path / "b/")
1435         report = black.Report()
1436         expected: List[Path] = [
1437             path / "b/.definitely_exclude/a.py",
1438             path / "b/.definitely_exclude/a.pyi",
1439         ]
1440         sources = list(
1441             black.get_sources(
1442                 ctx=FakeContext(),
1443                 src=(src,),
1444                 quiet=True,
1445                 verbose=False,
1446                 include=include,
1447                 exclude=None,
1448                 extend_exclude=extend_exclude,
1449                 force_exclude=None,
1450                 report=report,
1451                 stdin_filename=None,
1452             )
1453         )
1454         self.assertEqual(sorted(expected), sorted(sources))
1455
1456     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1457     def test_exclude_for_issue_1572(self) -> None:
1458         # Exclude shouldn't touch files that were explicitly given to Black through the
1459         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
1460         # https://github.com/psf/black/issues/1572
1461         path = THIS_DIR / "data" / "include_exclude_tests"
1462         include = ""
1463         exclude = r"/exclude/|a\.py"
1464         src = str(path / "b/exclude/a.py")
1465         report = black.Report()
1466         expected = [Path(path / "b/exclude/a.py")]
1467         sources = list(
1468             black.get_sources(
1469                 ctx=FakeContext(),
1470                 src=(src,),
1471                 quiet=True,
1472                 verbose=False,
1473                 include=re.compile(include),
1474                 exclude=re.compile(exclude),
1475                 extend_exclude=None,
1476                 force_exclude=None,
1477                 report=report,
1478                 stdin_filename=None,
1479             )
1480         )
1481         self.assertEqual(sorted(expected), sorted(sources))
1482
1483     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1484     def test_get_sources_with_stdin(self) -> None:
1485         include = ""
1486         exclude = r"/exclude/|a\.py"
1487         src = "-"
1488         report = black.Report()
1489         expected = [Path("-")]
1490         sources = list(
1491             black.get_sources(
1492                 ctx=FakeContext(),
1493                 src=(src,),
1494                 quiet=True,
1495                 verbose=False,
1496                 include=re.compile(include),
1497                 exclude=re.compile(exclude),
1498                 extend_exclude=None,
1499                 force_exclude=None,
1500                 report=report,
1501                 stdin_filename=None,
1502             )
1503         )
1504         self.assertEqual(sorted(expected), sorted(sources))
1505
1506     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1507     def test_get_sources_with_stdin_filename(self) -> None:
1508         include = ""
1509         exclude = r"/exclude/|a\.py"
1510         src = "-"
1511         report = black.Report()
1512         stdin_filename = str(THIS_DIR / "data/collections.py")
1513         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1514         sources = list(
1515             black.get_sources(
1516                 ctx=FakeContext(),
1517                 src=(src,),
1518                 quiet=True,
1519                 verbose=False,
1520                 include=re.compile(include),
1521                 exclude=re.compile(exclude),
1522                 extend_exclude=None,
1523                 force_exclude=None,
1524                 report=report,
1525                 stdin_filename=stdin_filename,
1526             )
1527         )
1528         self.assertEqual(sorted(expected), sorted(sources))
1529
1530     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1531     def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
1532         # Exclude shouldn't exclude stdin_filename since it is mimicing the
1533         # file being passed directly. This is the same as
1534         # test_exclude_for_issue_1572
1535         path = THIS_DIR / "data" / "include_exclude_tests"
1536         include = ""
1537         exclude = r"/exclude/|a\.py"
1538         src = "-"
1539         report = black.Report()
1540         stdin_filename = str(path / "b/exclude/a.py")
1541         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1542         sources = list(
1543             black.get_sources(
1544                 ctx=FakeContext(),
1545                 src=(src,),
1546                 quiet=True,
1547                 verbose=False,
1548                 include=re.compile(include),
1549                 exclude=re.compile(exclude),
1550                 extend_exclude=None,
1551                 force_exclude=None,
1552                 report=report,
1553                 stdin_filename=stdin_filename,
1554             )
1555         )
1556         self.assertEqual(sorted(expected), sorted(sources))
1557
1558     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1559     def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
1560         # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
1561         # file being passed directly. This is the same as
1562         # test_exclude_for_issue_1572
1563         path = THIS_DIR / "data" / "include_exclude_tests"
1564         include = ""
1565         extend_exclude = r"/exclude/|a\.py"
1566         src = "-"
1567         report = black.Report()
1568         stdin_filename = str(path / "b/exclude/a.py")
1569         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1570         sources = list(
1571             black.get_sources(
1572                 ctx=FakeContext(),
1573                 src=(src,),
1574                 quiet=True,
1575                 verbose=False,
1576                 include=re.compile(include),
1577                 exclude=re.compile(""),
1578                 extend_exclude=re.compile(extend_exclude),
1579                 force_exclude=None,
1580                 report=report,
1581                 stdin_filename=stdin_filename,
1582             )
1583         )
1584         self.assertEqual(sorted(expected), sorted(sources))
1585
1586     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1587     def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
1588         # Force exclude should exclude the file when passing it through
1589         # stdin_filename
1590         path = THIS_DIR / "data" / "include_exclude_tests"
1591         include = ""
1592         force_exclude = r"/exclude/|a\.py"
1593         src = "-"
1594         report = black.Report()
1595         stdin_filename = str(path / "b/exclude/a.py")
1596         sources = list(
1597             black.get_sources(
1598                 ctx=FakeContext(),
1599                 src=(src,),
1600                 quiet=True,
1601                 verbose=False,
1602                 include=re.compile(include),
1603                 exclude=re.compile(""),
1604                 extend_exclude=None,
1605                 force_exclude=re.compile(force_exclude),
1606                 report=report,
1607                 stdin_filename=stdin_filename,
1608             )
1609         )
1610         self.assertEqual([], sorted(sources))
1611
1612     def test_reformat_one_with_stdin(self) -> None:
1613         with patch(
1614             "black.format_stdin_to_stdout",
1615             return_value=lambda *args, **kwargs: black.Changed.YES,
1616         ) as fsts:
1617             report = MagicMock()
1618             path = Path("-")
1619             black.reformat_one(
1620                 path,
1621                 fast=True,
1622                 write_back=black.WriteBack.YES,
1623                 mode=DEFAULT_MODE,
1624                 report=report,
1625             )
1626             fsts.assert_called_once()
1627             report.done.assert_called_with(path, black.Changed.YES)
1628
1629     def test_reformat_one_with_stdin_filename(self) -> None:
1630         with patch(
1631             "black.format_stdin_to_stdout",
1632             return_value=lambda *args, **kwargs: black.Changed.YES,
1633         ) as fsts:
1634             report = MagicMock()
1635             p = "foo.py"
1636             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1637             expected = Path(p)
1638             black.reformat_one(
1639                 path,
1640                 fast=True,
1641                 write_back=black.WriteBack.YES,
1642                 mode=DEFAULT_MODE,
1643                 report=report,
1644             )
1645             fsts.assert_called_once_with(
1646                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1647             )
1648             # __BLACK_STDIN_FILENAME__ should have been stripped
1649             report.done.assert_called_with(expected, black.Changed.YES)
1650
1651     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1652         with patch(
1653             "black.format_stdin_to_stdout",
1654             return_value=lambda *args, **kwargs: black.Changed.YES,
1655         ) as fsts:
1656             report = MagicMock()
1657             p = "foo.pyi"
1658             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1659             expected = Path(p)
1660             black.reformat_one(
1661                 path,
1662                 fast=True,
1663                 write_back=black.WriteBack.YES,
1664                 mode=DEFAULT_MODE,
1665                 report=report,
1666             )
1667             fsts.assert_called_once_with(
1668                 fast=True,
1669                 write_back=black.WriteBack.YES,
1670                 mode=replace(DEFAULT_MODE, is_pyi=True),
1671             )
1672             # __BLACK_STDIN_FILENAME__ should have been stripped
1673             report.done.assert_called_with(expected, black.Changed.YES)
1674
1675     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1676         with patch(
1677             "black.format_stdin_to_stdout",
1678             return_value=lambda *args, **kwargs: black.Changed.YES,
1679         ) as fsts:
1680             report = MagicMock()
1681             # Even with an existing file, since we are forcing stdin, black
1682             # should output to stdout and not modify the file inplace
1683             p = Path(str(THIS_DIR / "data/collections.py"))
1684             # Make sure is_file actually returns True
1685             self.assertTrue(p.is_file())
1686             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1687             expected = Path(p)
1688             black.reformat_one(
1689                 path,
1690                 fast=True,
1691                 write_back=black.WriteBack.YES,
1692                 mode=DEFAULT_MODE,
1693                 report=report,
1694             )
1695             fsts.assert_called_once()
1696             # __BLACK_STDIN_FILENAME__ should have been stripped
1697             report.done.assert_called_with(expected, black.Changed.YES)
1698
1699     def test_gitignore_exclude(self) -> None:
1700         path = THIS_DIR / "data" / "include_exclude_tests"
1701         include = re.compile(r"\.pyi?$")
1702         exclude = re.compile(r"")
1703         report = black.Report()
1704         gitignore = PathSpec.from_lines(
1705             "gitwildmatch", ["exclude/", ".definitely_exclude"]
1706         )
1707         sources: List[Path] = []
1708         expected = [
1709             Path(path / "b/dont_exclude/a.py"),
1710             Path(path / "b/dont_exclude/a.pyi"),
1711         ]
1712         this_abs = THIS_DIR.resolve()
1713         sources.extend(
1714             black.gen_python_files(
1715                 path.iterdir(),
1716                 this_abs,
1717                 include,
1718                 exclude,
1719                 None,
1720                 None,
1721                 report,
1722                 gitignore,
1723             )
1724         )
1725         self.assertEqual(sorted(expected), sorted(sources))
1726
1727     def test_empty_include(self) -> None:
1728         path = THIS_DIR / "data" / "include_exclude_tests"
1729         report = black.Report()
1730         gitignore = PathSpec.from_lines("gitwildmatch", [])
1731         empty = re.compile(r"")
1732         sources: List[Path] = []
1733         expected = [
1734             Path(path / "b/exclude/a.pie"),
1735             Path(path / "b/exclude/a.py"),
1736             Path(path / "b/exclude/a.pyi"),
1737             Path(path / "b/dont_exclude/a.pie"),
1738             Path(path / "b/dont_exclude/a.py"),
1739             Path(path / "b/dont_exclude/a.pyi"),
1740             Path(path / "b/.definitely_exclude/a.pie"),
1741             Path(path / "b/.definitely_exclude/a.py"),
1742             Path(path / "b/.definitely_exclude/a.pyi"),
1743             Path(path / ".gitignore"),
1744             Path(path / "pyproject.toml"),
1745         ]
1746         this_abs = THIS_DIR.resolve()
1747         sources.extend(
1748             black.gen_python_files(
1749                 path.iterdir(),
1750                 this_abs,
1751                 empty,
1752                 re.compile(black.DEFAULT_EXCLUDES),
1753                 None,
1754                 None,
1755                 report,
1756                 gitignore,
1757             )
1758         )
1759         self.assertEqual(sorted(expected), sorted(sources))
1760
1761     def test_extend_exclude(self) -> None:
1762         path = THIS_DIR / "data" / "include_exclude_tests"
1763         report = black.Report()
1764         gitignore = PathSpec.from_lines("gitwildmatch", [])
1765         sources: List[Path] = []
1766         expected = [
1767             Path(path / "b/exclude/a.py"),
1768             Path(path / "b/dont_exclude/a.py"),
1769         ]
1770         this_abs = THIS_DIR.resolve()
1771         sources.extend(
1772             black.gen_python_files(
1773                 path.iterdir(),
1774                 this_abs,
1775                 re.compile(black.DEFAULT_INCLUDES),
1776                 re.compile(r"\.pyi$"),
1777                 re.compile(r"\.definitely_exclude"),
1778                 None,
1779                 report,
1780                 gitignore,
1781             )
1782         )
1783         self.assertEqual(sorted(expected), sorted(sources))
1784
1785     def test_invalid_cli_regex(self) -> None:
1786         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1787             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1788
1789     def test_preserves_line_endings(self) -> None:
1790         with TemporaryDirectory() as workspace:
1791             test_file = Path(workspace) / "test.py"
1792             for nl in ["\n", "\r\n"]:
1793                 contents = nl.join(["def f(  ):", "    pass"])
1794                 test_file.write_bytes(contents.encode())
1795                 ff(test_file, write_back=black.WriteBack.YES)
1796                 updated_contents: bytes = test_file.read_bytes()
1797                 self.assertIn(nl.encode(), updated_contents)
1798                 if nl == "\n":
1799                     self.assertNotIn(b"\r\n", updated_contents)
1800
1801     def test_preserves_line_endings_via_stdin(self) -> None:
1802         for nl in ["\n", "\r\n"]:
1803             contents = nl.join(["def f(  ):", "    pass"])
1804             runner = BlackRunner()
1805             result = runner.invoke(
1806                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1807             )
1808             self.assertEqual(result.exit_code, 0)
1809             output = runner.stdout_bytes
1810             self.assertIn(nl.encode("utf8"), output)
1811             if nl == "\n":
1812                 self.assertNotIn(b"\r\n", output)
1813
1814     def test_assert_equivalent_different_asts(self) -> None:
1815         with self.assertRaises(AssertionError):
1816             black.assert_equivalent("{}", "None")
1817
1818     def test_symlink_out_of_root_directory(self) -> None:
1819         path = MagicMock()
1820         root = THIS_DIR.resolve()
1821         child = MagicMock()
1822         include = re.compile(black.DEFAULT_INCLUDES)
1823         exclude = re.compile(black.DEFAULT_EXCLUDES)
1824         report = black.Report()
1825         gitignore = PathSpec.from_lines("gitwildmatch", [])
1826         # `child` should behave like a symlink which resolved path is clearly
1827         # outside of the `root` directory.
1828         path.iterdir.return_value = [child]
1829         child.resolve.return_value = Path("/a/b/c")
1830         child.as_posix.return_value = "/a/b/c"
1831         child.is_symlink.return_value = True
1832         try:
1833             list(
1834                 black.gen_python_files(
1835                     path.iterdir(),
1836                     root,
1837                     include,
1838                     exclude,
1839                     None,
1840                     None,
1841                     report,
1842                     gitignore,
1843                 )
1844             )
1845         except ValueError as ve:
1846             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1847         path.iterdir.assert_called_once()
1848         child.resolve.assert_called_once()
1849         child.is_symlink.assert_called_once()
1850         # `child` should behave like a strange file which resolved path is clearly
1851         # outside of the `root` directory.
1852         child.is_symlink.return_value = False
1853         with self.assertRaises(ValueError):
1854             list(
1855                 black.gen_python_files(
1856                     path.iterdir(),
1857                     root,
1858                     include,
1859                     exclude,
1860                     None,
1861                     None,
1862                     report,
1863                     gitignore,
1864                 )
1865             )
1866         path.iterdir.assert_called()
1867         self.assertEqual(path.iterdir.call_count, 2)
1868         child.resolve.assert_called()
1869         self.assertEqual(child.resolve.call_count, 2)
1870         child.is_symlink.assert_called()
1871         self.assertEqual(child.is_symlink.call_count, 2)
1872
1873     def test_shhh_click(self) -> None:
1874         try:
1875             from click import _unicodefun  # type: ignore
1876         except ModuleNotFoundError:
1877             self.skipTest("Incompatible Click version")
1878         if not hasattr(_unicodefun, "_verify_python3_env"):
1879             self.skipTest("Incompatible Click version")
1880         # First, let's see if Click is crashing with a preferred ASCII charset.
1881         with patch("locale.getpreferredencoding") as gpe:
1882             gpe.return_value = "ASCII"
1883             with self.assertRaises(RuntimeError):
1884                 _unicodefun._verify_python3_env()
1885         # Now, let's silence Click...
1886         black.patch_click()
1887         # ...and confirm it's silent.
1888         with patch("locale.getpreferredencoding") as gpe:
1889             gpe.return_value = "ASCII"
1890             try:
1891                 _unicodefun._verify_python3_env()
1892             except RuntimeError as re:
1893                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1894
1895     def test_root_logger_not_used_directly(self) -> None:
1896         def fail(*args: Any, **kwargs: Any) -> None:
1897             self.fail("Record created with root logger")
1898
1899         with patch.multiple(
1900             logging.root,
1901             debug=fail,
1902             info=fail,
1903             warning=fail,
1904             error=fail,
1905             critical=fail,
1906             log=fail,
1907         ):
1908             ff(THIS_DIR / "util.py")
1909
1910     def test_invalid_config_return_code(self) -> None:
1911         tmp_file = Path(black.dump_to_file())
1912         try:
1913             tmp_config = Path(black.dump_to_file())
1914             tmp_config.unlink()
1915             args = ["--config", str(tmp_config), str(tmp_file)]
1916             self.invokeBlack(args, exit_code=2, ignore_config=False)
1917         finally:
1918             tmp_file.unlink()
1919
1920     def test_parse_pyproject_toml(self) -> None:
1921         test_toml_file = THIS_DIR / "test.toml"
1922         config = black.parse_pyproject_toml(str(test_toml_file))
1923         self.assertEqual(config["verbose"], 1)
1924         self.assertEqual(config["check"], "no")
1925         self.assertEqual(config["diff"], "y")
1926         self.assertEqual(config["color"], True)
1927         self.assertEqual(config["line_length"], 79)
1928         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1929         self.assertEqual(config["exclude"], r"\.pyi?$")
1930         self.assertEqual(config["include"], r"\.py?$")
1931
1932     def test_read_pyproject_toml(self) -> None:
1933         test_toml_file = THIS_DIR / "test.toml"
1934         fake_ctx = FakeContext()
1935         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1936         config = fake_ctx.default_map
1937         self.assertEqual(config["verbose"], "1")
1938         self.assertEqual(config["check"], "no")
1939         self.assertEqual(config["diff"], "y")
1940         self.assertEqual(config["color"], "True")
1941         self.assertEqual(config["line_length"], "79")
1942         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1943         self.assertEqual(config["exclude"], r"\.pyi?$")
1944         self.assertEqual(config["include"], r"\.py?$")
1945
1946     def test_find_project_root(self) -> None:
1947         with TemporaryDirectory() as workspace:
1948             root = Path(workspace)
1949             test_dir = root / "test"
1950             test_dir.mkdir()
1951
1952             src_dir = root / "src"
1953             src_dir.mkdir()
1954
1955             root_pyproject = root / "pyproject.toml"
1956             root_pyproject.touch()
1957             src_pyproject = src_dir / "pyproject.toml"
1958             src_pyproject.touch()
1959             src_python = src_dir / "foo.py"
1960             src_python.touch()
1961
1962             self.assertEqual(
1963                 black.find_project_root((src_dir, test_dir)), root.resolve()
1964             )
1965             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
1966             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
1967
1968     @patch(
1969         "black.files.find_user_pyproject_toml",
1970         black.files.find_user_pyproject_toml.__wrapped__,
1971     )
1972     def test_find_user_pyproject_toml_linux(self) -> None:
1973         if system() == "Windows":
1974             return
1975
1976         # Test if XDG_CONFIG_HOME is checked
1977         with TemporaryDirectory() as workspace:
1978             tmp_user_config = Path(workspace) / "black"
1979             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1980                 self.assertEqual(
1981                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
1982                 )
1983
1984         # Test fallback for XDG_CONFIG_HOME
1985         with patch.dict("os.environ"):
1986             os.environ.pop("XDG_CONFIG_HOME", None)
1987             fallback_user_config = Path("~/.config").expanduser() / "black"
1988             self.assertEqual(
1989                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
1990             )
1991
1992     def test_find_user_pyproject_toml_windows(self) -> None:
1993         if system() != "Windows":
1994             return
1995
1996         user_config_path = Path.home() / ".black"
1997         self.assertEqual(
1998             black.files.find_user_pyproject_toml(), user_config_path.resolve()
1999         )
2000
2001     def test_bpo_33660_workaround(self) -> None:
2002         if system() == "Windows":
2003             return
2004
2005         # https://bugs.python.org/issue33660
2006
2007         old_cwd = Path.cwd()
2008         try:
2009             root = Path("/")
2010             os.chdir(str(root))
2011             path = Path("workspace") / "project"
2012             report = black.Report(verbose=True)
2013             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
2014             self.assertEqual(normalized_path, "workspace/project")
2015         finally:
2016             os.chdir(str(old_cwd))
2017
2018     def test_newline_comment_interaction(self) -> None:
2019         source = "class A:\\\r\n# type: ignore\n pass\n"
2020         output = black.format_str(source, mode=DEFAULT_MODE)
2021         black.assert_stable(source, output, mode=DEFAULT_MODE)
2022
2023     def test_bpo_2142_workaround(self) -> None:
2024
2025         # https://bugs.python.org/issue2142
2026
2027         source, _ = read_data("missing_final_newline.py")
2028         # read_data adds a trailing newline
2029         source = source.rstrip()
2030         expected, _ = read_data("missing_final_newline.diff")
2031         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
2032         diff_header = re.compile(
2033             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
2034             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
2035         )
2036         try:
2037             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
2038             self.assertEqual(result.exit_code, 0)
2039         finally:
2040             os.unlink(tmp_file)
2041         actual = result.output
2042         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
2043         self.assertEqual(actual, expected)
2044
2045     @pytest.mark.python2
2046     def test_docstring_reformat_for_py27(self) -> None:
2047         """
2048         Check that stripping trailing whitespace from Python 2 docstrings
2049         doesn't trigger a "not equivalent to source" error
2050         """
2051         source = (
2052             b'def foo():\r\n    """Testing\r\n    Testing """\r\n    print "Foo"\r\n'
2053         )
2054         expected = 'def foo():\n    """Testing\n    Testing"""\n    print "Foo"\n'
2055
2056         result = CliRunner().invoke(
2057             black.main,
2058             ["-", "-q", "--target-version=py27"],
2059             input=BytesIO(source),
2060         )
2061
2062         self.assertEqual(result.exit_code, 0)
2063         actual = result.output
2064         self.assertFormatEqual(actual, expected)
2065
2066
# Lines of black's package __init__ module; tracefunc() indexes into this list
# to print the signature line of each traced call.
with open(black.__file__, encoding="utf-8") as _bf:
    black_source_lines = list(_bf)
2069
2070
def tracefunc(frame: types.FrameType, event: str, arg: Any) -> Callable:
    """Show function calls `from black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    Only "call" events for frames whose code lives in ``black/__init__.py`` are
    printed, as ``<indent><lineno>:<signature line>``. Every other event or
    frame is ignored, and the function always returns itself so tracing
    continues.
    """
    if event != "call":
        return tracefunc

    # Filter by file BEFORE touching black_source_lines: the line number of a
    # frame from any other module is meaningless as an index into black's
    # source and could raise IndexError. Backslashes are normalized so the
    # check also matches Windows paths.
    filename = frame.f_code.co_filename
    if "black/__init__.py" not in filename.replace("\\", "/"):
        return tracefunc

    # Indent by call depth. NOTE(review): 19 looks like an empirically chosen
    # count of harness frames to discount; a negative result just prints no
    # indent.
    stack = (len(inspect.stack()) - 19) * 2
    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1
    funcname = black_source_lines[func_sig_lineno].strip()
    # Skip decorator lines to reach the actual `def` line.
    while funcname.startswith("@"):
        func_sig_lineno += 1
        funcname = black_source_lines[func_sig_lineno].strip()
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc
2091
2092
if __name__ == "__main__":
    # Allow running this module directly as a script via unittest's runner.
    unittest.main(module="test_black")