]> git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

5c720507216c9efcdd4844d3e95063c011c59f5c
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import multiprocessing
3 import asyncio
4 import logging
5 from concurrent.futures import ThreadPoolExecutor
6 from contextlib import contextmanager
7 from dataclasses import replace
8 import inspect
9 import io
10 from io import BytesIO
11 import os
12 from pathlib import Path
13 from platform import system
14 import regex as re
15 import sys
16 from tempfile import TemporaryDirectory
17 import types
18 from typing import (
19     Any,
20     Callable,
21     Dict,
22     List,
23     Iterator,
24     TypeVar,
25 )
26 import pytest
27 import unittest
28 from unittest.mock import patch, MagicMock
29
30 import click
31 from click import unstyle
32 from click.testing import CliRunner
33
34 import black
35 from black import Feature, TargetVersion
36 from black.cache import get_cache_file
37 from black.debug import DebugVisitor
38 from black.output import diff, color_diff
39 from black.report import Report
40 import black.files
41
42 from pathspec import PathSpec
43
44 # Import other test classes
45 from tests.util import (
46     THIS_DIR,
47     change_directory,
48     read_data,
49     DETERMINISTIC_HEADER,
50     BlackBaseTestCase,
51     DEFAULT_MODE,
52     fs,
53     ff,
54     dump_to_stderr,
55 )
56
57
58 THIS_FILE = Path(__file__)
59 PY36_VERSIONS = {
60     TargetVersion.PY36,
61     TargetVersion.PY37,
62     TargetVersion.PY38,
63     TargetVersion.PY39,
64 }
65 PY36_ARGS = [f"--target-version={version.name.lower()}" for version in PY36_VERSIONS]
66 T = TypeVar("T")
67 R = TypeVar("R")
68
69 # Match the time output in a diff, but nothing else
70 DIFF_TIME = re.compile(r"\t[\d-:+\. ]+")
71
72
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point ``black.cache.CACHE_DIR`` at a throwaway directory while in use.

    A temporary directory is created for the duration of the ``with`` block.
    If *exists* is true, that directory itself is patched in and yielded;
    otherwise a ``new`` subdirectory path inside it (not created here) is used.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
81
82
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a brand-new asyncio event loop for the body of a ``with`` block.

    The loop comes from the current event loop policy and is made the current
    loop. It is closed on exit, whether or not the body raised. Nothing is
    yielded to the caller.
    """
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        fresh_loop.close()
93
94
class FakeContext(click.Context):
    """Stand-in click Context for calling functions that require one.

    ``click.Context.__init__`` is deliberately not called; the only attribute
    set up is an empty ``default_map``.
    """

    def __init__(self) -> None:
        self.default_map: Dict[str, Any] = dict()
100
101
class FakeParameter(click.Parameter):
    """Stand-in click Parameter for calling functions that require one.

    ``click.Parameter.__init__`` is intentionally bypassed, so instances carry
    no parameter configuration.
    """

    def __init__(self) -> None:
        """Do nothing: skip the base-class initializer on purpose."""
107
108
class BlackRunner(CliRunner):
    """CliRunner that captures Black's STDOUT and STDERR in separate buffers."""

    def __init__(self) -> None:
        # mix_stderr=False keeps stderr out of result.stdout so tests can
        # inspect each stream on its own.
        super().__init__(mix_stderr=False)
114
115
116 class BlackTestCase(BlackBaseTestCase):
117     def invokeBlack(
118         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
119     ) -> None:
120         runner = BlackRunner()
121         if ignore_config:
122             args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
123         result = runner.invoke(black.main, args)
124         assert result.stdout_bytes is not None
125         assert result.stderr_bytes is not None
126         self.assertEqual(
127             result.exit_code,
128             exit_code,
129             msg=(
130                 f"Failed with args: {args}\n"
131                 f"stdout: {result.stdout_bytes.decode()!r}\n"
132                 f"stderr: {result.stderr_bytes.decode()!r}\n"
133                 f"exception: {result.exception}"
134             ),
135         )
136
137     @patch("black.dump_to_file", dump_to_stderr)
138     def test_empty(self) -> None:
139         source = expected = ""
140         actual = fs(source)
141         self.assertFormatEqual(expected, actual)
142         black.assert_equivalent(source, actual)
143         black.assert_stable(source, actual, DEFAULT_MODE)
144
145     def test_empty_ff(self) -> None:
146         expected = ""
147         tmp_file = Path(black.dump_to_file())
148         try:
149             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
150             with open(tmp_file, encoding="utf8") as f:
151                 actual = f.read()
152         finally:
153             os.unlink(tmp_file)
154         self.assertFormatEqual(expected, actual)
155
156     def test_piping(self) -> None:
157         source, expected = read_data("src/black/__init__", data=False)
158         result = BlackRunner().invoke(
159             black.main,
160             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
161             input=BytesIO(source.encode("utf8")),
162         )
163         self.assertEqual(result.exit_code, 0)
164         self.assertFormatEqual(expected, result.output)
165         if source != result.output:
166             black.assert_equivalent(source, result.output)
167             black.assert_stable(source, result.output, DEFAULT_MODE)
168
169     def test_piping_diff(self) -> None:
170         diff_header = re.compile(
171             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
172             r"\+\d\d\d\d"
173         )
174         source, _ = read_data("expression.py")
175         expected, _ = read_data("expression.diff")
176         config = THIS_DIR / "data" / "empty_pyproject.toml"
177         args = [
178             "-",
179             "--fast",
180             f"--line-length={black.DEFAULT_LINE_LENGTH}",
181             "--diff",
182             f"--config={config}",
183         ]
184         result = BlackRunner().invoke(
185             black.main, args, input=BytesIO(source.encode("utf8"))
186         )
187         self.assertEqual(result.exit_code, 0)
188         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
189         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
190         self.assertEqual(expected, actual)
191
192     def test_piping_diff_with_color(self) -> None:
193         source, _ = read_data("expression.py")
194         config = THIS_DIR / "data" / "empty_pyproject.toml"
195         args = [
196             "-",
197             "--fast",
198             f"--line-length={black.DEFAULT_LINE_LENGTH}",
199             "--diff",
200             "--color",
201             f"--config={config}",
202         ]
203         result = BlackRunner().invoke(
204             black.main, args, input=BytesIO(source.encode("utf8"))
205         )
206         actual = result.output
207         # Again, the contents are checked in a different test, so only look for colors.
208         self.assertIn("\033[1;37m", actual)
209         self.assertIn("\033[36m", actual)
210         self.assertIn("\033[32m", actual)
211         self.assertIn("\033[31m", actual)
212         self.assertIn("\033[0m", actual)
213
214     @patch("black.dump_to_file", dump_to_stderr)
215     def _test_wip(self) -> None:
216         source, expected = read_data("wip")
217         sys.settrace(tracefunc)
218         mode = replace(
219             DEFAULT_MODE,
220             experimental_string_processing=False,
221             target_versions={black.TargetVersion.PY38},
222         )
223         actual = fs(source, mode=mode)
224         sys.settrace(None)
225         self.assertFormatEqual(expected, actual)
226         black.assert_equivalent(source, actual)
227         black.assert_stable(source, actual, black.FileMode())
228
229     @unittest.expectedFailure
230     @patch("black.dump_to_file", dump_to_stderr)
231     def test_trailing_comma_optional_parens_stability1(self) -> None:
232         source, _expected = read_data("trailing_comma_optional_parens1")
233         actual = fs(source)
234         black.assert_stable(source, actual, DEFAULT_MODE)
235
236     @unittest.expectedFailure
237     @patch("black.dump_to_file", dump_to_stderr)
238     def test_trailing_comma_optional_parens_stability2(self) -> None:
239         source, _expected = read_data("trailing_comma_optional_parens2")
240         actual = fs(source)
241         black.assert_stable(source, actual, DEFAULT_MODE)
242
243     @unittest.expectedFailure
244     @patch("black.dump_to_file", dump_to_stderr)
245     def test_trailing_comma_optional_parens_stability3(self) -> None:
246         source, _expected = read_data("trailing_comma_optional_parens3")
247         actual = fs(source)
248         black.assert_stable(source, actual, DEFAULT_MODE)
249
250     @patch("black.dump_to_file", dump_to_stderr)
251     def test_trailing_comma_optional_parens_stability1_pass2(self) -> None:
252         source, _expected = read_data("trailing_comma_optional_parens1")
253         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
254         black.assert_stable(source, actual, DEFAULT_MODE)
255
256     @patch("black.dump_to_file", dump_to_stderr)
257     def test_trailing_comma_optional_parens_stability2_pass2(self) -> None:
258         source, _expected = read_data("trailing_comma_optional_parens2")
259         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
260         black.assert_stable(source, actual, DEFAULT_MODE)
261
262     @patch("black.dump_to_file", dump_to_stderr)
263     def test_trailing_comma_optional_parens_stability3_pass2(self) -> None:
264         source, _expected = read_data("trailing_comma_optional_parens3")
265         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
266         black.assert_stable(source, actual, DEFAULT_MODE)
267
268     @patch("black.dump_to_file", dump_to_stderr)
269     def test_pep_572(self) -> None:
270         source, expected = read_data("pep_572")
271         actual = fs(source)
272         self.assertFormatEqual(expected, actual)
273         black.assert_stable(source, actual, DEFAULT_MODE)
274         if sys.version_info >= (3, 8):
275             black.assert_equivalent(source, actual)
276
277     @patch("black.dump_to_file", dump_to_stderr)
278     def test_pep_572_remove_parens(self) -> None:
279         source, expected = read_data("pep_572_remove_parens")
280         actual = fs(source)
281         self.assertFormatEqual(expected, actual)
282         black.assert_stable(source, actual, DEFAULT_MODE)
283         if sys.version_info >= (3, 8):
284             black.assert_equivalent(source, actual)
285
286     @patch("black.dump_to_file", dump_to_stderr)
287     def test_pep_572_do_not_remove_parens(self) -> None:
288         source, expected = read_data("pep_572_do_not_remove_parens")
289         # the AST safety checks will fail, but that's expected, just make sure no
290         # parentheses are touched
291         actual = black.format_str(source, mode=DEFAULT_MODE)
292         self.assertFormatEqual(expected, actual)
293
294     def test_pep_572_version_detection(self) -> None:
295         source, _ = read_data("pep_572")
296         root = black.lib2to3_parse(source)
297         features = black.get_features_used(root)
298         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
299         versions = black.detect_target_versions(root)
300         self.assertIn(black.TargetVersion.PY38, versions)
301
302     def test_expression_ff(self) -> None:
303         source, expected = read_data("expression")
304         tmp_file = Path(black.dump_to_file(source))
305         try:
306             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
307             with open(tmp_file, encoding="utf8") as f:
308                 actual = f.read()
309         finally:
310             os.unlink(tmp_file)
311         self.assertFormatEqual(expected, actual)
312         with patch("black.dump_to_file", dump_to_stderr):
313             black.assert_equivalent(source, actual)
314             black.assert_stable(source, actual, DEFAULT_MODE)
315
316     def test_expression_diff(self) -> None:
317         source, _ = read_data("expression.py")
318         config = THIS_DIR / "data" / "empty_pyproject.toml"
319         expected, _ = read_data("expression.diff")
320         tmp_file = Path(black.dump_to_file(source))
321         diff_header = re.compile(
322             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
323             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
324         )
325         try:
326             result = BlackRunner().invoke(
327                 black.main, ["--diff", str(tmp_file), f"--config={config}"]
328             )
329             self.assertEqual(result.exit_code, 0)
330         finally:
331             os.unlink(tmp_file)
332         actual = result.output
333         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
334         if expected != actual:
335             dump = black.dump_to_file(actual)
336             msg = (
337                 "Expected diff isn't equal to the actual. If you made changes to"
338                 " expression.py and this is an anticipated difference, overwrite"
339                 f" tests/data/expression.diff with {dump}"
340             )
341             self.assertEqual(expected, actual, msg)
342
343     def test_expression_diff_with_color(self) -> None:
344         source, _ = read_data("expression.py")
345         config = THIS_DIR / "data" / "empty_pyproject.toml"
346         expected, _ = read_data("expression.diff")
347         tmp_file = Path(black.dump_to_file(source))
348         try:
349             result = BlackRunner().invoke(
350                 black.main, ["--diff", "--color", str(tmp_file), f"--config={config}"]
351             )
352         finally:
353             os.unlink(tmp_file)
354         actual = result.output
355         # We check the contents of the diff in `test_expression_diff`. All
356         # we need to check here is that color codes exist in the result.
357         self.assertIn("\033[1;37m", actual)
358         self.assertIn("\033[36m", actual)
359         self.assertIn("\033[32m", actual)
360         self.assertIn("\033[31m", actual)
361         self.assertIn("\033[0m", actual)
362
363     @patch("black.dump_to_file", dump_to_stderr)
364     def test_pep_570(self) -> None:
365         source, expected = read_data("pep_570")
366         actual = fs(source)
367         self.assertFormatEqual(expected, actual)
368         black.assert_stable(source, actual, DEFAULT_MODE)
369         if sys.version_info >= (3, 8):
370             black.assert_equivalent(source, actual)
371
372     def test_detect_pos_only_arguments(self) -> None:
373         source, _ = read_data("pep_570")
374         root = black.lib2to3_parse(source)
375         features = black.get_features_used(root)
376         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
377         versions = black.detect_target_versions(root)
378         self.assertIn(black.TargetVersion.PY38, versions)
379
380     @patch("black.dump_to_file", dump_to_stderr)
381     def test_string_quotes(self) -> None:
382         source, expected = read_data("string_quotes")
383         mode = black.Mode(experimental_string_processing=True)
384         actual = fs(source, mode=mode)
385         self.assertFormatEqual(expected, actual)
386         black.assert_equivalent(source, actual)
387         black.assert_stable(source, actual, mode)
388         mode = replace(mode, string_normalization=False)
389         not_normalized = fs(source, mode=mode)
390         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
391         black.assert_equivalent(source, not_normalized)
392         black.assert_stable(source, not_normalized, mode=mode)
393
394     @patch("black.dump_to_file", dump_to_stderr)
395     def test_docstring_no_string_normalization(self) -> None:
396         """Like test_docstring but with string normalization off."""
397         source, expected = read_data("docstring_no_string_normalization")
398         mode = replace(DEFAULT_MODE, string_normalization=False)
399         actual = fs(source, mode=mode)
400         self.assertFormatEqual(expected, actual)
401         black.assert_equivalent(source, actual)
402         black.assert_stable(source, actual, mode)
403
404     def test_long_strings_flag_disabled(self) -> None:
405         """Tests for turning off the string processing logic."""
406         source, expected = read_data("long_strings_flag_disabled")
407         mode = replace(DEFAULT_MODE, experimental_string_processing=False)
408         actual = fs(source, mode=mode)
409         self.assertFormatEqual(expected, actual)
410         black.assert_stable(expected, actual, mode)
411
412     @patch("black.dump_to_file", dump_to_stderr)
413     def test_numeric_literals(self) -> None:
414         source, expected = read_data("numeric_literals")
415         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
416         actual = fs(source, mode=mode)
417         self.assertFormatEqual(expected, actual)
418         black.assert_equivalent(source, actual)
419         black.assert_stable(source, actual, mode)
420
421     @patch("black.dump_to_file", dump_to_stderr)
422     def test_numeric_literals_ignoring_underscores(self) -> None:
423         source, expected = read_data("numeric_literals_skip_underscores")
424         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
425         actual = fs(source, mode=mode)
426         self.assertFormatEqual(expected, actual)
427         black.assert_equivalent(source, actual)
428         black.assert_stable(source, actual, mode)
429
430     def test_skip_magic_trailing_comma(self) -> None:
431         source, _ = read_data("expression.py")
432         expected, _ = read_data("expression_skip_magic_trailing_comma.diff")
433         tmp_file = Path(black.dump_to_file(source))
434         diff_header = re.compile(
435             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
436             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
437         )
438         try:
439             result = BlackRunner().invoke(black.main, ["-C", "--diff", str(tmp_file)])
440             self.assertEqual(result.exit_code, 0)
441         finally:
442             os.unlink(tmp_file)
443         actual = result.output
444         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
445         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
446         if expected != actual:
447             dump = black.dump_to_file(actual)
448             msg = (
449                 "Expected diff isn't equal to the actual. If you made changes to"
450                 " expression.py and this is an anticipated difference, overwrite"
451                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
452             )
453             self.assertEqual(expected, actual, msg)
454
455     @pytest.mark.python2
456     @patch("black.dump_to_file", dump_to_stderr)
457     def test_python2_print_function(self) -> None:
458         source, expected = read_data("python2_print_function")
459         mode = replace(DEFAULT_MODE, target_versions={TargetVersion.PY27})
460         actual = fs(source, mode=mode)
461         self.assertFormatEqual(expected, actual)
462         black.assert_equivalent(source, actual)
463         black.assert_stable(source, actual, mode)
464
465     @patch("black.dump_to_file", dump_to_stderr)
466     def test_stub(self) -> None:
467         mode = replace(DEFAULT_MODE, is_pyi=True)
468         source, expected = read_data("stub.pyi")
469         actual = fs(source, mode=mode)
470         self.assertFormatEqual(expected, actual)
471         black.assert_stable(source, actual, mode)
472
473     @patch("black.dump_to_file", dump_to_stderr)
474     def test_async_as_identifier(self) -> None:
475         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
476         source, expected = read_data("async_as_identifier")
477         actual = fs(source)
478         self.assertFormatEqual(expected, actual)
479         major, minor = sys.version_info[:2]
480         if major < 3 or (major <= 3 and minor < 7):
481             black.assert_equivalent(source, actual)
482         black.assert_stable(source, actual, DEFAULT_MODE)
483         # ensure black can parse this when the target is 3.6
484         self.invokeBlack([str(source_path), "--target-version", "py36"])
485         # but not on 3.7, because async/await is no longer an identifier
486         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
487
488     @patch("black.dump_to_file", dump_to_stderr)
489     def test_python37(self) -> None:
490         source_path = (THIS_DIR / "data" / "python37.py").resolve()
491         source, expected = read_data("python37")
492         actual = fs(source)
493         self.assertFormatEqual(expected, actual)
494         major, minor = sys.version_info[:2]
495         if major > 3 or (major == 3 and minor >= 7):
496             black.assert_equivalent(source, actual)
497         black.assert_stable(source, actual, DEFAULT_MODE)
498         # ensure black can parse this when the target is 3.7
499         self.invokeBlack([str(source_path), "--target-version", "py37"])
500         # but not on 3.6, because we use async as a reserved keyword
501         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
502
503     @patch("black.dump_to_file", dump_to_stderr)
504     def test_python38(self) -> None:
505         source, expected = read_data("python38")
506         actual = fs(source)
507         self.assertFormatEqual(expected, actual)
508         major, minor = sys.version_info[:2]
509         if major > 3 or (major == 3 and minor >= 8):
510             black.assert_equivalent(source, actual)
511         black.assert_stable(source, actual, DEFAULT_MODE)
512
513     @patch("black.dump_to_file", dump_to_stderr)
514     def test_python39(self) -> None:
515         source, expected = read_data("python39")
516         actual = fs(source)
517         self.assertFormatEqual(expected, actual)
518         major, minor = sys.version_info[:2]
519         if major > 3 or (major == 3 and minor >= 9):
520             black.assert_equivalent(source, actual)
521         black.assert_stable(source, actual, DEFAULT_MODE)
522
523     def test_tab_comment_indentation(self) -> None:
524         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
525         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
526         self.assertFormatEqual(contents_spc, fs(contents_spc))
527         self.assertFormatEqual(contents_spc, fs(contents_tab))
528
529         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
530         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
531         self.assertFormatEqual(contents_spc, fs(contents_spc))
532         self.assertFormatEqual(contents_spc, fs(contents_tab))
533
534         # mixed tabs and spaces (valid Python 2 code)
535         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
536         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
537         self.assertFormatEqual(contents_spc, fs(contents_spc))
538         self.assertFormatEqual(contents_spc, fs(contents_tab))
539
540         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
541         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
542         self.assertFormatEqual(contents_spc, fs(contents_spc))
543         self.assertFormatEqual(contents_spc, fs(contents_tab))
544
    def test_report_verbose(self) -> None:
        """Drive a verbose ``Report`` through a sequence of outcomes.

        ``black.output._out`` and ``black.output._err`` are patched to capture
        messages. After each ``done``/``failed``/``path_ignored`` call, the test
        checks the message emitted (if any), the unstyled summary string, and
        ``return_code``.
        """
        report = Report(verbose=True)
        # Messages captured from the patched output functions.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            # Record instead of printing; extra keyword args are ignored.
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            # Record instead of printing; extra keyword args are ignored.
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            # In verbose mode, an unchanged file produces an out message.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file is counted as unchanged in the summary.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled, this state (one reformatted file, no
            # failures) yields return code 1; toggle it back off afterwards.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure goes to err and from here on the return code is 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # An ignored path is reported in verbose mode but does not change
            # the summary counts.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
646
    def test_report_quiet(self) -> None:
        """Drive a quiet ``Report`` through the same outcomes as the verbose test.

        ``black.output._out`` and ``black.output._err`` are patched to capture
        messages. In quiet mode, nothing is captured on out at any step, while
        failures still produce err messages. The summary strings and return
        codes match the verbose case.
        """
        report = Report(quiet=True)
        # Messages captured from the patched output functions.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            # Record instead of printing; extra keyword args are ignored.
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            # Record instead of printing; extra keyword args are ignored.
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file is counted as unchanged in the summary.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled, this state (one reformatted file, no
            # failures) yields return code 1; toggle it back off afterwards.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures are still reported on err even in quiet mode, and from
            # here on the return code is 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # An ignored path emits nothing in quiet mode and leaves the counts
            # unchanged.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
740
741     def test_report_normal(self) -> None:
742         report = black.Report()
743         out_lines = []
744         err_lines = []
745
746         def out(msg: str, **kwargs: Any) -> None:
747             out_lines.append(msg)
748
749         def err(msg: str, **kwargs: Any) -> None:
750             err_lines.append(msg)
751
752         with patch("black.output._out", out), patch("black.output._err", err):
753             report.done(Path("f1"), black.Changed.NO)
754             self.assertEqual(len(out_lines), 0)
755             self.assertEqual(len(err_lines), 0)
756             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
757             self.assertEqual(report.return_code, 0)
758             report.done(Path("f2"), black.Changed.YES)
759             self.assertEqual(len(out_lines), 1)
760             self.assertEqual(len(err_lines), 0)
761             self.assertEqual(out_lines[-1], "reformatted f2")
762             self.assertEqual(
763                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
764             )
765             report.done(Path("f3"), black.Changed.CACHED)
766             self.assertEqual(len(out_lines), 1)
767             self.assertEqual(len(err_lines), 0)
768             self.assertEqual(out_lines[-1], "reformatted f2")
769             self.assertEqual(
770                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
771             )
772             self.assertEqual(report.return_code, 0)
773             report.check = True
774             self.assertEqual(report.return_code, 1)
775             report.check = False
776             report.failed(Path("e1"), "boom")
777             self.assertEqual(len(out_lines), 1)
778             self.assertEqual(len(err_lines), 1)
779             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
780             self.assertEqual(
781                 unstyle(str(report)),
782                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
783                 " reformat.",
784             )
785             self.assertEqual(report.return_code, 123)
786             report.done(Path("f3"), black.Changed.YES)
787             self.assertEqual(len(out_lines), 2)
788             self.assertEqual(len(err_lines), 1)
789             self.assertEqual(out_lines[-1], "reformatted f3")
790             self.assertEqual(
791                 unstyle(str(report)),
792                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
793                 " reformat.",
794             )
795             self.assertEqual(report.return_code, 123)
796             report.failed(Path("e2"), "boom")
797             self.assertEqual(len(out_lines), 2)
798             self.assertEqual(len(err_lines), 2)
799             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
800             self.assertEqual(
801                 unstyle(str(report)),
802                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
803                 " reformat.",
804             )
805             self.assertEqual(report.return_code, 123)
806             report.path_ignored(Path("wat"), "no match")
807             self.assertEqual(len(out_lines), 2)
808             self.assertEqual(len(err_lines), 2)
809             self.assertEqual(
810                 unstyle(str(report)),
811                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
812                 " reformat.",
813             )
814             self.assertEqual(report.return_code, 123)
815             report.done(Path("f4"), black.Changed.NO)
816             self.assertEqual(len(out_lines), 2)
817             self.assertEqual(len(err_lines), 2)
818             self.assertEqual(
819                 unstyle(str(report)),
820                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
821                 " reformat.",
822             )
823             self.assertEqual(report.return_code, 123)
824             report.check = True
825             self.assertEqual(
826                 unstyle(str(report)),
827                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
828                 " would fail to reformat.",
829             )
830             report.check = False
831             report.diff = True
832             self.assertEqual(
833                 unstyle(str(report)),
834                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
835                 " would fail to reformat.",
836             )
837
838     def test_lib2to3_parse(self) -> None:
839         with self.assertRaises(black.InvalidInput):
840             black.lib2to3_parse("invalid syntax")
841
842         straddling = "x + y"
843         black.lib2to3_parse(straddling)
844         black.lib2to3_parse(straddling, {TargetVersion.PY27})
845         black.lib2to3_parse(straddling, {TargetVersion.PY36})
846         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
847
848         py2_only = "print x"
849         black.lib2to3_parse(py2_only)
850         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
851         with self.assertRaises(black.InvalidInput):
852             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
853         with self.assertRaises(black.InvalidInput):
854             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
855
856         py3_only = "exec(x, end=y)"
857         black.lib2to3_parse(py3_only)
858         with self.assertRaises(black.InvalidInput):
859             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
860         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
861         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
862
863     def test_get_features_used_decorator(self) -> None:
864         # Test the feature detection of new decorator syntax
865         # since this makes some test cases of test_get_features_used()
866         # fails if it fails, this is tested first so that a useful case
867         # is identified
868         simples, relaxed = read_data("decorators")
869         # skip explanation comments at the top of the file
870         for simple_test in simples.split("##")[1:]:
871             node = black.lib2to3_parse(simple_test)
872             decorator = str(node.children[0].children[0]).strip()
873             self.assertNotIn(
874                 Feature.RELAXED_DECORATORS,
875                 black.get_features_used(node),
876                 msg=(
877                     f"decorator '{decorator}' follows python<=3.8 syntax"
878                     "but is detected as 3.9+"
879                     # f"The full node is\n{node!r}"
880                 ),
881             )
882         # skip the '# output' comment at the top of the output part
883         for relaxed_test in relaxed.split("##")[1:]:
884             node = black.lib2to3_parse(relaxed_test)
885             decorator = str(node.children[0].children[0]).strip()
886             self.assertIn(
887                 Feature.RELAXED_DECORATORS,
888                 black.get_features_used(node),
889                 msg=(
890                     f"decorator '{decorator}' uses python3.9+ syntax"
891                     "but is detected as python<=3.8"
892                     # f"The full node is\n{node!r}"
893                 ),
894             )
895
896     def test_get_features_used(self) -> None:
897         node = black.lib2to3_parse("def f(*, arg): ...\n")
898         self.assertEqual(black.get_features_used(node), set())
899         node = black.lib2to3_parse("def f(*, arg,): ...\n")
900         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
901         node = black.lib2to3_parse("f(*arg,)\n")
902         self.assertEqual(
903             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
904         )
905         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
906         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
907         node = black.lib2to3_parse("123_456\n")
908         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
909         node = black.lib2to3_parse("123456\n")
910         self.assertEqual(black.get_features_used(node), set())
911         source, expected = read_data("function")
912         node = black.lib2to3_parse(source)
913         expected_features = {
914             Feature.TRAILING_COMMA_IN_CALL,
915             Feature.TRAILING_COMMA_IN_DEF,
916             Feature.F_STRINGS,
917         }
918         self.assertEqual(black.get_features_used(node), expected_features)
919         node = black.lib2to3_parse(expected)
920         self.assertEqual(black.get_features_used(node), expected_features)
921         source, expected = read_data("expression")
922         node = black.lib2to3_parse(source)
923         self.assertEqual(black.get_features_used(node), set())
924         node = black.lib2to3_parse(expected)
925         self.assertEqual(black.get_features_used(node), set())
926
927     def test_get_future_imports(self) -> None:
928         node = black.lib2to3_parse("\n")
929         self.assertEqual(set(), black.get_future_imports(node))
930         node = black.lib2to3_parse("from __future__ import black\n")
931         self.assertEqual({"black"}, black.get_future_imports(node))
932         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
933         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
934         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
935         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
936         node = black.lib2to3_parse(
937             "from __future__ import multiple\nfrom __future__ import imports\n"
938         )
939         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
940         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
941         self.assertEqual({"black"}, black.get_future_imports(node))
942         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
943         self.assertEqual({"black"}, black.get_future_imports(node))
944         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
945         self.assertEqual(set(), black.get_future_imports(node))
946         node = black.lib2to3_parse("from some.module import black\n")
947         self.assertEqual(set(), black.get_future_imports(node))
948         node = black.lib2to3_parse(
949             "from __future__ import unicode_literals as _unicode_literals"
950         )
951         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
952         node = black.lib2to3_parse(
953             "from __future__ import unicode_literals as _lol, print"
954         )
955         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
956
957     def test_debug_visitor(self) -> None:
958         source, _ = read_data("debug_visitor.py")
959         expected, _ = read_data("debug_visitor.out")
960         out_lines = []
961         err_lines = []
962
963         def out(msg: str, **kwargs: Any) -> None:
964             out_lines.append(msg)
965
966         def err(msg: str, **kwargs: Any) -> None:
967             err_lines.append(msg)
968
969         with patch("black.debug.out", out):
970             DebugVisitor.show(source)
971         actual = "\n".join(out_lines) + "\n"
972         log_name = ""
973         if expected != actual:
974             log_name = black.dump_to_file(*out_lines)
975         self.assertEqual(
976             expected,
977             actual,
978             f"AST print out is different. Actual version dumped to {log_name}",
979         )
980
981     def test_format_file_contents(self) -> None:
982         empty = ""
983         mode = DEFAULT_MODE
984         with self.assertRaises(black.NothingChanged):
985             black.format_file_contents(empty, mode=mode, fast=False)
986         just_nl = "\n"
987         with self.assertRaises(black.NothingChanged):
988             black.format_file_contents(just_nl, mode=mode, fast=False)
989         same = "j = [1, 2, 3]\n"
990         with self.assertRaises(black.NothingChanged):
991             black.format_file_contents(same, mode=mode, fast=False)
992         different = "j = [1,2,3]"
993         expected = same
994         actual = black.format_file_contents(different, mode=mode, fast=False)
995         self.assertEqual(expected, actual)
996         invalid = "return if you can"
997         with self.assertRaises(black.InvalidInput) as e:
998             black.format_file_contents(invalid, mode=mode, fast=False)
999         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1000
1001     def test_endmarker(self) -> None:
1002         n = black.lib2to3_parse("\n")
1003         self.assertEqual(n.type, black.syms.file_input)
1004         self.assertEqual(len(n.children), 1)
1005         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1006
1007     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1008     def test_assertFormatEqual(self) -> None:
1009         out_lines = []
1010         err_lines = []
1011
1012         def out(msg: str, **kwargs: Any) -> None:
1013             out_lines.append(msg)
1014
1015         def err(msg: str, **kwargs: Any) -> None:
1016             err_lines.append(msg)
1017
1018         with patch("black.output._out", out), patch("black.output._err", err):
1019             with self.assertRaises(AssertionError):
1020                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1021
1022         out_str = "".join(out_lines)
1023         self.assertTrue("Expected tree:" in out_str)
1024         self.assertTrue("Actual tree:" in out_str)
1025         self.assertEqual("".join(err_lines), "")
1026
1027     def test_cache_broken_file(self) -> None:
1028         mode = DEFAULT_MODE
1029         with cache_dir() as workspace:
1030             cache_file = get_cache_file(mode)
1031             with cache_file.open("w") as fobj:
1032                 fobj.write("this is not a pickle")
1033             self.assertEqual(black.read_cache(mode), {})
1034             src = (workspace / "test.py").resolve()
1035             with src.open("w") as fobj:
1036                 fobj.write("print('hello')")
1037             self.invokeBlack([str(src)])
1038             cache = black.read_cache(mode)
1039             self.assertIn(str(src), cache)
1040
1041     def test_cache_single_file_already_cached(self) -> None:
1042         mode = DEFAULT_MODE
1043         with cache_dir() as workspace:
1044             src = (workspace / "test.py").resolve()
1045             with src.open("w") as fobj:
1046                 fobj.write("print('hello')")
1047             black.write_cache({}, [src], mode)
1048             self.invokeBlack([str(src)])
1049             with src.open("r") as fobj:
1050                 self.assertEqual(fobj.read(), "print('hello')")
1051
1052     @event_loop()
1053     def test_cache_multiple_files(self) -> None:
1054         mode = DEFAULT_MODE
1055         with cache_dir() as workspace, patch(
1056             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1057         ):
1058             one = (workspace / "one.py").resolve()
1059             with one.open("w") as fobj:
1060                 fobj.write("print('hello')")
1061             two = (workspace / "two.py").resolve()
1062             with two.open("w") as fobj:
1063                 fobj.write("print('hello')")
1064             black.write_cache({}, [one], mode)
1065             self.invokeBlack([str(workspace)])
1066             with one.open("r") as fobj:
1067                 self.assertEqual(fobj.read(), "print('hello')")
1068             with two.open("r") as fobj:
1069                 self.assertEqual(fobj.read(), 'print("hello")\n')
1070             cache = black.read_cache(mode)
1071             self.assertIn(str(one), cache)
1072             self.assertIn(str(two), cache)
1073
1074     def test_no_cache_when_writeback_diff(self) -> None:
1075         mode = DEFAULT_MODE
1076         with cache_dir() as workspace:
1077             src = (workspace / "test.py").resolve()
1078             with src.open("w") as fobj:
1079                 fobj.write("print('hello')")
1080             with patch("black.read_cache") as read_cache, patch(
1081                 "black.write_cache"
1082             ) as write_cache:
1083                 self.invokeBlack([str(src), "--diff"])
1084                 cache_file = get_cache_file(mode)
1085                 self.assertFalse(cache_file.exists())
1086                 write_cache.assert_not_called()
1087                 read_cache.assert_not_called()
1088
1089     def test_no_cache_when_writeback_color_diff(self) -> None:
1090         mode = DEFAULT_MODE
1091         with cache_dir() as workspace:
1092             src = (workspace / "test.py").resolve()
1093             with src.open("w") as fobj:
1094                 fobj.write("print('hello')")
1095             with patch("black.read_cache") as read_cache, patch(
1096                 "black.write_cache"
1097             ) as write_cache:
1098                 self.invokeBlack([str(src), "--diff", "--color"])
1099                 cache_file = get_cache_file(mode)
1100                 self.assertFalse(cache_file.exists())
1101                 write_cache.assert_not_called()
1102                 read_cache.assert_not_called()
1103
1104     @event_loop()
1105     def test_output_locking_when_writeback_diff(self) -> None:
1106         with cache_dir() as workspace:
1107             for tag in range(0, 4):
1108                 src = (workspace / f"test{tag}.py").resolve()
1109                 with src.open("w") as fobj:
1110                     fobj.write("print('hello')")
1111             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1112                 self.invokeBlack(["--diff", str(workspace)], exit_code=0)
1113                 # this isn't quite doing what we want, but if it _isn't_
1114                 # called then we cannot be using the lock it provides
1115                 mgr.assert_called()
1116
1117     @event_loop()
1118     def test_output_locking_when_writeback_color_diff(self) -> None:
1119         with cache_dir() as workspace:
1120             for tag in range(0, 4):
1121                 src = (workspace / f"test{tag}.py").resolve()
1122                 with src.open("w") as fobj:
1123                     fobj.write("print('hello')")
1124             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1125                 self.invokeBlack(["--diff", "--color", str(workspace)], exit_code=0)
1126                 # this isn't quite doing what we want, but if it _isn't_
1127                 # called then we cannot be using the lock it provides
1128                 mgr.assert_called()
1129
1130     def test_no_cache_when_stdin(self) -> None:
1131         mode = DEFAULT_MODE
1132         with cache_dir():
1133             result = CliRunner().invoke(
1134                 black.main, ["-"], input=BytesIO(b"print('hello')")
1135             )
1136             self.assertEqual(result.exit_code, 0)
1137             cache_file = get_cache_file(mode)
1138             self.assertFalse(cache_file.exists())
1139
1140     def test_read_cache_no_cachefile(self) -> None:
1141         mode = DEFAULT_MODE
1142         with cache_dir():
1143             self.assertEqual(black.read_cache(mode), {})
1144
1145     def test_write_cache_read_cache(self) -> None:
1146         mode = DEFAULT_MODE
1147         with cache_dir() as workspace:
1148             src = (workspace / "test.py").resolve()
1149             src.touch()
1150             black.write_cache({}, [src], mode)
1151             cache = black.read_cache(mode)
1152             self.assertIn(str(src), cache)
1153             self.assertEqual(cache[str(src)], black.get_cache_info(src))
1154
1155     def test_filter_cached(self) -> None:
1156         with TemporaryDirectory() as workspace:
1157             path = Path(workspace)
1158             uncached = (path / "uncached").resolve()
1159             cached = (path / "cached").resolve()
1160             cached_but_changed = (path / "changed").resolve()
1161             uncached.touch()
1162             cached.touch()
1163             cached_but_changed.touch()
1164             cache = {
1165                 str(cached): black.get_cache_info(cached),
1166                 str(cached_but_changed): (0.0, 0),
1167             }
1168             todo, done = black.filter_cached(
1169                 cache, {uncached, cached, cached_but_changed}
1170             )
1171             self.assertEqual(todo, {uncached, cached_but_changed})
1172             self.assertEqual(done, {cached})
1173
1174     def test_write_cache_creates_directory_if_needed(self) -> None:
1175         mode = DEFAULT_MODE
1176         with cache_dir(exists=False) as workspace:
1177             self.assertFalse(workspace.exists())
1178             black.write_cache({}, [], mode)
1179             self.assertTrue(workspace.exists())
1180
1181     @event_loop()
1182     def test_failed_formatting_does_not_get_cached(self) -> None:
1183         mode = DEFAULT_MODE
1184         with cache_dir() as workspace, patch(
1185             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1186         ):
1187             failing = (workspace / "failing.py").resolve()
1188             with failing.open("w") as fobj:
1189                 fobj.write("not actually python")
1190             clean = (workspace / "clean.py").resolve()
1191             with clean.open("w") as fobj:
1192                 fobj.write('print("hello")\n')
1193             self.invokeBlack([str(workspace)], exit_code=123)
1194             cache = black.read_cache(mode)
1195             self.assertNotIn(str(failing), cache)
1196             self.assertIn(str(clean), cache)
1197
1198     def test_write_cache_write_fail(self) -> None:
1199         mode = DEFAULT_MODE
1200         with cache_dir(), patch.object(Path, "open") as mock:
1201             mock.side_effect = OSError
1202             black.write_cache({}, [], mode)
1203
1204     @event_loop()
1205     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
1206     def test_works_in_mono_process_only_environment(self) -> None:
1207         with cache_dir() as workspace:
1208             for f in [
1209                 (workspace / "one.py").resolve(),
1210                 (workspace / "two.py").resolve(),
1211             ]:
1212                 f.write_text('print("hello")\n')
1213             self.invokeBlack([str(workspace)])
1214
1215     @event_loop()
1216     def test_check_diff_use_together(self) -> None:
1217         with cache_dir():
1218             # Files which will be reformatted.
1219             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1220             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1221             # Files which will not be reformatted.
1222             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1223             self.invokeBlack([str(src2), "--diff", "--check"])
1224             # Multi file command.
1225             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1226
1227     def test_no_files(self) -> None:
1228         with cache_dir():
1229             # Without an argument, black exits with error code 0.
1230             self.invokeBlack([])
1231
1232     def test_broken_symlink(self) -> None:
1233         with cache_dir() as workspace:
1234             symlink = workspace / "broken_link.py"
1235             try:
1236                 symlink.symlink_to("nonexistent.py")
1237             except OSError as e:
1238                 self.skipTest(f"Can't create symlinks: {e}")
1239             self.invokeBlack([str(workspace.resolve())])
1240
1241     def test_read_cache_line_lengths(self) -> None:
1242         mode = DEFAULT_MODE
1243         short_mode = replace(DEFAULT_MODE, line_length=1)
1244         with cache_dir() as workspace:
1245             path = (workspace / "file.py").resolve()
1246             path.touch()
1247             black.write_cache({}, [path], mode)
1248             one = black.read_cache(mode)
1249             self.assertIn(str(path), one)
1250             two = black.read_cache(short_mode)
1251             self.assertNotIn(str(path), two)
1252
1253     def test_single_file_force_pyi(self) -> None:
1254         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1255         contents, expected = read_data("force_pyi")
1256         with cache_dir() as workspace:
1257             path = (workspace / "file.py").resolve()
1258             with open(path, "w") as fh:
1259                 fh.write(contents)
1260             self.invokeBlack([str(path), "--pyi"])
1261             with open(path, "r") as fh:
1262                 actual = fh.read()
1263             # verify cache with --pyi is separate
1264             pyi_cache = black.read_cache(pyi_mode)
1265             self.assertIn(str(path), pyi_cache)
1266             normal_cache = black.read_cache(DEFAULT_MODE)
1267             self.assertNotIn(str(path), normal_cache)
1268         self.assertFormatEqual(expected, actual)
1269         black.assert_equivalent(contents, actual)
1270         black.assert_stable(contents, actual, pyi_mode)
1271
1272     @event_loop()
1273     def test_multi_file_force_pyi(self) -> None:
1274         reg_mode = DEFAULT_MODE
1275         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1276         contents, expected = read_data("force_pyi")
1277         with cache_dir() as workspace:
1278             paths = [
1279                 (workspace / "file1.py").resolve(),
1280                 (workspace / "file2.py").resolve(),
1281             ]
1282             for path in paths:
1283                 with open(path, "w") as fh:
1284                     fh.write(contents)
1285             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1286             for path in paths:
1287                 with open(path, "r") as fh:
1288                     actual = fh.read()
1289                 self.assertEqual(actual, expected)
1290             # verify cache with --pyi is separate
1291             pyi_cache = black.read_cache(pyi_mode)
1292             normal_cache = black.read_cache(reg_mode)
1293             for path in paths:
1294                 self.assertIn(str(path), pyi_cache)
1295                 self.assertNotIn(str(path), normal_cache)
1296
1297     def test_pipe_force_pyi(self) -> None:
1298         source, expected = read_data("force_pyi")
1299         result = CliRunner().invoke(
1300             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1301         )
1302         self.assertEqual(result.exit_code, 0)
1303         actual = result.output
1304         self.assertFormatEqual(actual, expected)
1305
1306     def test_single_file_force_py36(self) -> None:
1307         reg_mode = DEFAULT_MODE
1308         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1309         source, expected = read_data("force_py36")
1310         with cache_dir() as workspace:
1311             path = (workspace / "file.py").resolve()
1312             with open(path, "w") as fh:
1313                 fh.write(source)
1314             self.invokeBlack([str(path), *PY36_ARGS])
1315             with open(path, "r") as fh:
1316                 actual = fh.read()
1317             # verify cache with --target-version is separate
1318             py36_cache = black.read_cache(py36_mode)
1319             self.assertIn(str(path), py36_cache)
1320             normal_cache = black.read_cache(reg_mode)
1321             self.assertNotIn(str(path), normal_cache)
1322         self.assertEqual(actual, expected)
1323
1324     @event_loop()
1325     def test_multi_file_force_py36(self) -> None:
1326         reg_mode = DEFAULT_MODE
1327         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1328         source, expected = read_data("force_py36")
1329         with cache_dir() as workspace:
1330             paths = [
1331                 (workspace / "file1.py").resolve(),
1332                 (workspace / "file2.py").resolve(),
1333             ]
1334             for path in paths:
1335                 with open(path, "w") as fh:
1336                     fh.write(source)
1337             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1338             for path in paths:
1339                 with open(path, "r") as fh:
1340                     actual = fh.read()
1341                 self.assertEqual(actual, expected)
1342             # verify cache with --target-version is separate
1343             pyi_cache = black.read_cache(py36_mode)
1344             normal_cache = black.read_cache(reg_mode)
1345             for path in paths:
1346                 self.assertIn(str(path), pyi_cache)
1347                 self.assertNotIn(str(path), normal_cache)
1348
1349     def test_pipe_force_py36(self) -> None:
1350         source, expected = read_data("force_py36")
1351         result = CliRunner().invoke(
1352             black.main,
1353             ["-", "-q", "--target-version=py36"],
1354             input=BytesIO(source.encode("utf8")),
1355         )
1356         self.assertEqual(result.exit_code, 0)
1357         actual = result.output
1358         self.assertFormatEqual(actual, expected)
1359
1360     def test_include_exclude(self) -> None:
1361         path = THIS_DIR / "data" / "include_exclude_tests"
1362         include = re.compile(r"\.pyi?$")
1363         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1364         report = black.Report()
1365         gitignore = PathSpec.from_lines("gitwildmatch", [])
1366         sources: List[Path] = []
1367         expected = [
1368             Path(path / "b/dont_exclude/a.py"),
1369             Path(path / "b/dont_exclude/a.pyi"),
1370         ]
1371         this_abs = THIS_DIR.resolve()
1372         sources.extend(
1373             black.gen_python_files(
1374                 path.iterdir(),
1375                 this_abs,
1376                 include,
1377                 exclude,
1378                 None,
1379                 None,
1380                 report,
1381                 gitignore,
1382                 verbose=False,
1383                 quiet=False,
1384             )
1385         )
1386         self.assertEqual(sorted(expected), sorted(sources))
1387
1388     def test_gitignore_used_as_default(self) -> None:
1389         path = Path(THIS_DIR / "data" / "include_exclude_tests")
1390         include = re.compile(r"\.pyi?$")
1391         extend_exclude = re.compile(r"/exclude/")
1392         src = str(path / "b/")
1393         report = black.Report()
1394         expected: List[Path] = [
1395             path / "b/.definitely_exclude/a.py",
1396             path / "b/.definitely_exclude/a.pyi",
1397         ]
1398         sources = list(
1399             black.get_sources(
1400                 ctx=FakeContext(),
1401                 src=(src,),
1402                 quiet=True,
1403                 verbose=False,
1404                 include=include,
1405                 exclude=None,
1406                 extend_exclude=extend_exclude,
1407                 force_exclude=None,
1408                 report=report,
1409                 stdin_filename=None,
1410             )
1411         )
1412         self.assertEqual(sorted(expected), sorted(sources))
1413
1414     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1415     def test_exclude_for_issue_1572(self) -> None:
1416         # Exclude shouldn't touch files that were explicitly given to Black through the
1417         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
1418         # https://github.com/psf/black/issues/1572
1419         path = THIS_DIR / "data" / "include_exclude_tests"
1420         include = ""
1421         exclude = r"/exclude/|a\.py"
1422         src = str(path / "b/exclude/a.py")
1423         report = black.Report()
1424         expected = [Path(path / "b/exclude/a.py")]
1425         sources = list(
1426             black.get_sources(
1427                 ctx=FakeContext(),
1428                 src=(src,),
1429                 quiet=True,
1430                 verbose=False,
1431                 include=re.compile(include),
1432                 exclude=re.compile(exclude),
1433                 extend_exclude=None,
1434                 force_exclude=None,
1435                 report=report,
1436                 stdin_filename=None,
1437             )
1438         )
1439         self.assertEqual(sorted(expected), sorted(sources))
1440
1441     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1442     def test_get_sources_with_stdin(self) -> None:
1443         include = ""
1444         exclude = r"/exclude/|a\.py"
1445         src = "-"
1446         report = black.Report()
1447         expected = [Path("-")]
1448         sources = list(
1449             black.get_sources(
1450                 ctx=FakeContext(),
1451                 src=(src,),
1452                 quiet=True,
1453                 verbose=False,
1454                 include=re.compile(include),
1455                 exclude=re.compile(exclude),
1456                 extend_exclude=None,
1457                 force_exclude=None,
1458                 report=report,
1459                 stdin_filename=None,
1460             )
1461         )
1462         self.assertEqual(sorted(expected), sorted(sources))
1463
1464     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1465     def test_get_sources_with_stdin_filename(self) -> None:
1466         include = ""
1467         exclude = r"/exclude/|a\.py"
1468         src = "-"
1469         report = black.Report()
1470         stdin_filename = str(THIS_DIR / "data/collections.py")
1471         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1472         sources = list(
1473             black.get_sources(
1474                 ctx=FakeContext(),
1475                 src=(src,),
1476                 quiet=True,
1477                 verbose=False,
1478                 include=re.compile(include),
1479                 exclude=re.compile(exclude),
1480                 extend_exclude=None,
1481                 force_exclude=None,
1482                 report=report,
1483                 stdin_filename=stdin_filename,
1484             )
1485         )
1486         self.assertEqual(sorted(expected), sorted(sources))
1487
1488     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1489     def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
1490         # Exclude shouldn't exclude stdin_filename since it is mimicking the
1491         # file being passed directly. This is the same as
1492         # test_exclude_for_issue_1572
1493         path = THIS_DIR / "data" / "include_exclude_tests"
1494         include = ""
1495         exclude = r"/exclude/|a\.py"
1496         src = "-"
1497         report = black.Report()
1498         stdin_filename = str(path / "b/exclude/a.py")
1499         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1500         sources = list(
1501             black.get_sources(
1502                 ctx=FakeContext(),
1503                 src=(src,),
1504                 quiet=True,
1505                 verbose=False,
1506                 include=re.compile(include),
1507                 exclude=re.compile(exclude),
1508                 extend_exclude=None,
1509                 force_exclude=None,
1510                 report=report,
1511                 stdin_filename=stdin_filename,
1512             )
1513         )
1514         self.assertEqual(sorted(expected), sorted(sources))
1515
1516     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1517     def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
1518         # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
1519         # file being passed directly. This is the same as
1520         # test_exclude_for_issue_1572
1521         path = THIS_DIR / "data" / "include_exclude_tests"
1522         include = ""
1523         extend_exclude = r"/exclude/|a\.py"
1524         src = "-"
1525         report = black.Report()
1526         stdin_filename = str(path / "b/exclude/a.py")
1527         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1528         sources = list(
1529             black.get_sources(
1530                 ctx=FakeContext(),
1531                 src=(src,),
1532                 quiet=True,
1533                 verbose=False,
1534                 include=re.compile(include),
1535                 exclude=re.compile(""),
1536                 extend_exclude=re.compile(extend_exclude),
1537                 force_exclude=None,
1538                 report=report,
1539                 stdin_filename=stdin_filename,
1540             )
1541         )
1542         self.assertEqual(sorted(expected), sorted(sources))
1543
1544     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1545     def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
1546         # Force exclude should exclude the file when passing it through
1547         # stdin_filename
1548         path = THIS_DIR / "data" / "include_exclude_tests"
1549         include = ""
1550         force_exclude = r"/exclude/|a\.py"
1551         src = "-"
1552         report = black.Report()
1553         stdin_filename = str(path / "b/exclude/a.py")
1554         sources = list(
1555             black.get_sources(
1556                 ctx=FakeContext(),
1557                 src=(src,),
1558                 quiet=True,
1559                 verbose=False,
1560                 include=re.compile(include),
1561                 exclude=re.compile(""),
1562                 extend_exclude=None,
1563                 force_exclude=re.compile(force_exclude),
1564                 report=report,
1565                 stdin_filename=stdin_filename,
1566             )
1567         )
1568         self.assertEqual([], sorted(sources))
1569
1570     def test_reformat_one_with_stdin(self) -> None:
1571         with patch(
1572             "black.format_stdin_to_stdout",
1573             return_value=lambda *args, **kwargs: black.Changed.YES,
1574         ) as fsts:
1575             report = MagicMock()
1576             path = Path("-")
1577             black.reformat_one(
1578                 path,
1579                 fast=True,
1580                 write_back=black.WriteBack.YES,
1581                 mode=DEFAULT_MODE,
1582                 report=report,
1583             )
1584             fsts.assert_called_once()
1585             report.done.assert_called_with(path, black.Changed.YES)
1586
1587     def test_reformat_one_with_stdin_filename(self) -> None:
1588         with patch(
1589             "black.format_stdin_to_stdout",
1590             return_value=lambda *args, **kwargs: black.Changed.YES,
1591         ) as fsts:
1592             report = MagicMock()
1593             p = "foo.py"
1594             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1595             expected = Path(p)
1596             black.reformat_one(
1597                 path,
1598                 fast=True,
1599                 write_back=black.WriteBack.YES,
1600                 mode=DEFAULT_MODE,
1601                 report=report,
1602             )
1603             fsts.assert_called_once_with(
1604                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1605             )
1606             # __BLACK_STDIN_FILENAME__ should have been stripped
1607             report.done.assert_called_with(expected, black.Changed.YES)
1608
1609     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1610         with patch(
1611             "black.format_stdin_to_stdout",
1612             return_value=lambda *args, **kwargs: black.Changed.YES,
1613         ) as fsts:
1614             report = MagicMock()
1615             p = "foo.pyi"
1616             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1617             expected = Path(p)
1618             black.reformat_one(
1619                 path,
1620                 fast=True,
1621                 write_back=black.WriteBack.YES,
1622                 mode=DEFAULT_MODE,
1623                 report=report,
1624             )
1625             fsts.assert_called_once_with(
1626                 fast=True,
1627                 write_back=black.WriteBack.YES,
1628                 mode=replace(DEFAULT_MODE, is_pyi=True),
1629             )
1630             # __BLACK_STDIN_FILENAME__ should have been stripped
1631             report.done.assert_called_with(expected, black.Changed.YES)
1632
1633     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1634         with patch(
1635             "black.format_stdin_to_stdout",
1636             return_value=lambda *args, **kwargs: black.Changed.YES,
1637         ) as fsts:
1638             report = MagicMock()
1639             # Even with an existing file, since we are forcing stdin, black
1640             # should output to stdout and not modify the file inplace
1641             p = Path(str(THIS_DIR / "data/collections.py"))
1642             # Make sure is_file actually returns True
1643             self.assertTrue(p.is_file())
1644             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1645             expected = Path(p)
1646             black.reformat_one(
1647                 path,
1648                 fast=True,
1649                 write_back=black.WriteBack.YES,
1650                 mode=DEFAULT_MODE,
1651                 report=report,
1652             )
1653             fsts.assert_called_once()
1654             # __BLACK_STDIN_FILENAME__ should have been stripped
1655             report.done.assert_called_with(expected, black.Changed.YES)
1656
1657     def test_reformat_one_with_stdin_empty(self) -> None:
1658         output = io.StringIO()
1659         with patch("io.TextIOWrapper", lambda *args, **kwargs: output):
1660             try:
1661                 black.format_stdin_to_stdout(
1662                     fast=True,
1663                     content="",
1664                     write_back=black.WriteBack.YES,
1665                     mode=DEFAULT_MODE,
1666                 )
1667             except io.UnsupportedOperation:
1668                 pass  # StringIO does not support detach
1669             assert output.getvalue() == ""
1670
1671     def test_gitignore_exclude(self) -> None:
1672         path = THIS_DIR / "data" / "include_exclude_tests"
1673         include = re.compile(r"\.pyi?$")
1674         exclude = re.compile(r"")
1675         report = black.Report()
1676         gitignore = PathSpec.from_lines(
1677             "gitwildmatch", ["exclude/", ".definitely_exclude"]
1678         )
1679         sources: List[Path] = []
1680         expected = [
1681             Path(path / "b/dont_exclude/a.py"),
1682             Path(path / "b/dont_exclude/a.pyi"),
1683         ]
1684         this_abs = THIS_DIR.resolve()
1685         sources.extend(
1686             black.gen_python_files(
1687                 path.iterdir(),
1688                 this_abs,
1689                 include,
1690                 exclude,
1691                 None,
1692                 None,
1693                 report,
1694                 gitignore,
1695                 verbose=False,
1696                 quiet=False,
1697             )
1698         )
1699         self.assertEqual(sorted(expected), sorted(sources))
1700
1701     def test_nested_gitignore(self) -> None:
1702         path = Path(THIS_DIR / "data" / "nested_gitignore_tests")
1703         include = re.compile(r"\.pyi?$")
1704         exclude = re.compile(r"")
1705         root_gitignore = black.files.get_gitignore(path)
1706         report = black.Report()
1707         expected: List[Path] = [
1708             Path(path / "x.py"),
1709             Path(path / "root/b.py"),
1710             Path(path / "root/c.py"),
1711             Path(path / "root/child/c.py"),
1712         ]
1713         this_abs = THIS_DIR.resolve()
1714         sources = list(
1715             black.gen_python_files(
1716                 path.iterdir(),
1717                 this_abs,
1718                 include,
1719                 exclude,
1720                 None,
1721                 None,
1722                 report,
1723                 root_gitignore,
1724                 verbose=False,
1725                 quiet=False,
1726             )
1727         )
1728         self.assertEqual(sorted(expected), sorted(sources))
1729
1730     def test_invalid_gitignore(self) -> None:
1731         path = THIS_DIR / "data" / "invalid_gitignore_tests"
1732         empty_config = path / "pyproject.toml"
1733         result = BlackRunner().invoke(
1734             black.main, ["--verbose", "--config", str(empty_config), str(path)]
1735         )
1736         assert result.exit_code == 1
1737         assert result.stderr_bytes is not None
1738
1739         gitignore = path / ".gitignore"
1740         assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()
1741
1742     def test_invalid_nested_gitignore(self) -> None:
1743         path = THIS_DIR / "data" / "invalid_nested_gitignore_tests"
1744         empty_config = path / "pyproject.toml"
1745         result = BlackRunner().invoke(
1746             black.main, ["--verbose", "--config", str(empty_config), str(path)]
1747         )
1748         assert result.exit_code == 1
1749         assert result.stderr_bytes is not None
1750
1751         gitignore = path / "a" / ".gitignore"
1752         assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()
1753
1754     def test_empty_include(self) -> None:
1755         path = THIS_DIR / "data" / "include_exclude_tests"
1756         report = black.Report()
1757         gitignore = PathSpec.from_lines("gitwildmatch", [])
1758         empty = re.compile(r"")
1759         sources: List[Path] = []
1760         expected = [
1761             Path(path / "b/exclude/a.pie"),
1762             Path(path / "b/exclude/a.py"),
1763             Path(path / "b/exclude/a.pyi"),
1764             Path(path / "b/dont_exclude/a.pie"),
1765             Path(path / "b/dont_exclude/a.py"),
1766             Path(path / "b/dont_exclude/a.pyi"),
1767             Path(path / "b/.definitely_exclude/a.pie"),
1768             Path(path / "b/.definitely_exclude/a.py"),
1769             Path(path / "b/.definitely_exclude/a.pyi"),
1770             Path(path / ".gitignore"),
1771             Path(path / "pyproject.toml"),
1772         ]
1773         this_abs = THIS_DIR.resolve()
1774         sources.extend(
1775             black.gen_python_files(
1776                 path.iterdir(),
1777                 this_abs,
1778                 empty,
1779                 re.compile(black.DEFAULT_EXCLUDES),
1780                 None,
1781                 None,
1782                 report,
1783                 gitignore,
1784                 verbose=False,
1785                 quiet=False,
1786             )
1787         )
1788         self.assertEqual(sorted(expected), sorted(sources))
1789
1790     def test_extend_exclude(self) -> None:
1791         path = THIS_DIR / "data" / "include_exclude_tests"
1792         report = black.Report()
1793         gitignore = PathSpec.from_lines("gitwildmatch", [])
1794         sources: List[Path] = []
1795         expected = [
1796             Path(path / "b/exclude/a.py"),
1797             Path(path / "b/dont_exclude/a.py"),
1798         ]
1799         this_abs = THIS_DIR.resolve()
1800         sources.extend(
1801             black.gen_python_files(
1802                 path.iterdir(),
1803                 this_abs,
1804                 re.compile(black.DEFAULT_INCLUDES),
1805                 re.compile(r"\.pyi$"),
1806                 re.compile(r"\.definitely_exclude"),
1807                 None,
1808                 report,
1809                 gitignore,
1810                 verbose=False,
1811                 quiet=False,
1812             )
1813         )
1814         self.assertEqual(sorted(expected), sorted(sources))
1815
1816     def test_invalid_cli_regex(self) -> None:
1817         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1818             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1819
1820     def test_required_version_matches_version(self) -> None:
1821         self.invokeBlack(
1822             ["--required-version", black.__version__], exit_code=0, ignore_config=True
1823         )
1824
1825     def test_required_version_does_not_match_version(self) -> None:
1826         self.invokeBlack(
1827             ["--required-version", "20.99b"], exit_code=1, ignore_config=True
1828         )
1829
1830     def test_preserves_line_endings(self) -> None:
1831         with TemporaryDirectory() as workspace:
1832             test_file = Path(workspace) / "test.py"
1833             for nl in ["\n", "\r\n"]:
1834                 contents = nl.join(["def f(  ):", "    pass"])
1835                 test_file.write_bytes(contents.encode())
1836                 ff(test_file, write_back=black.WriteBack.YES)
1837                 updated_contents: bytes = test_file.read_bytes()
1838                 self.assertIn(nl.encode(), updated_contents)
1839                 if nl == "\n":
1840                     self.assertNotIn(b"\r\n", updated_contents)
1841
1842     def test_preserves_line_endings_via_stdin(self) -> None:
1843         for nl in ["\n", "\r\n"]:
1844             contents = nl.join(["def f(  ):", "    pass"])
1845             runner = BlackRunner()
1846             result = runner.invoke(
1847                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1848             )
1849             self.assertEqual(result.exit_code, 0)
1850             output = result.stdout_bytes
1851             self.assertIn(nl.encode("utf8"), output)
1852             if nl == "\n":
1853                 self.assertNotIn(b"\r\n", output)
1854
1855     def test_assert_equivalent_different_asts(self) -> None:
1856         with self.assertRaises(AssertionError):
1857             black.assert_equivalent("{}", "None")
1858
1859     def test_symlink_out_of_root_directory(self) -> None:
1860         path = MagicMock()
1861         root = THIS_DIR.resolve()
1862         child = MagicMock()
1863         include = re.compile(black.DEFAULT_INCLUDES)
1864         exclude = re.compile(black.DEFAULT_EXCLUDES)
1865         report = black.Report()
1866         gitignore = PathSpec.from_lines("gitwildmatch", [])
1867         # `child` should behave like a symlink which resolved path is clearly
1868         # outside of the `root` directory.
1869         path.iterdir.return_value = [child]
1870         child.resolve.return_value = Path("/a/b/c")
1871         child.as_posix.return_value = "/a/b/c"
1872         child.is_symlink.return_value = True
1873         try:
1874             list(
1875                 black.gen_python_files(
1876                     path.iterdir(),
1877                     root,
1878                     include,
1879                     exclude,
1880                     None,
1881                     None,
1882                     report,
1883                     gitignore,
1884                     verbose=False,
1885                     quiet=False,
1886                 )
1887             )
1888         except ValueError as ve:
1889             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1890         path.iterdir.assert_called_once()
1891         child.resolve.assert_called_once()
1892         child.is_symlink.assert_called_once()
1893         # `child` should behave like a strange file which resolved path is clearly
1894         # outside of the `root` directory.
1895         child.is_symlink.return_value = False
1896         with self.assertRaises(ValueError):
1897             list(
1898                 black.gen_python_files(
1899                     path.iterdir(),
1900                     root,
1901                     include,
1902                     exclude,
1903                     None,
1904                     None,
1905                     report,
1906                     gitignore,
1907                     verbose=False,
1908                     quiet=False,
1909                 )
1910             )
1911         path.iterdir.assert_called()
1912         self.assertEqual(path.iterdir.call_count, 2)
1913         child.resolve.assert_called()
1914         self.assertEqual(child.resolve.call_count, 2)
1915         child.is_symlink.assert_called()
1916         self.assertEqual(child.is_symlink.call_count, 2)
1917
1918     def test_shhh_click(self) -> None:
1919         try:
1920             from click import _unicodefun
1921         except ModuleNotFoundError:
1922             self.skipTest("Incompatible Click version")
1923         if not hasattr(_unicodefun, "_verify_python3_env"):
1924             self.skipTest("Incompatible Click version")
1925         # First, let's see if Click is crashing with a preferred ASCII charset.
1926         with patch("locale.getpreferredencoding") as gpe:
1927             gpe.return_value = "ASCII"
1928             with self.assertRaises(RuntimeError):
1929                 _unicodefun._verify_python3_env()  # type: ignore
1930         # Now, let's silence Click...
1931         black.patch_click()
1932         # ...and confirm it's silent.
1933         with patch("locale.getpreferredencoding") as gpe:
1934             gpe.return_value = "ASCII"
1935             try:
1936                 _unicodefun._verify_python3_env()  # type: ignore
1937             except RuntimeError as re:
1938                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1939
1940     def test_root_logger_not_used_directly(self) -> None:
1941         def fail(*args: Any, **kwargs: Any) -> None:
1942             self.fail("Record created with root logger")
1943
1944         with patch.multiple(
1945             logging.root,
1946             debug=fail,
1947             info=fail,
1948             warning=fail,
1949             error=fail,
1950             critical=fail,
1951             log=fail,
1952         ):
1953             ff(THIS_DIR / "util.py")
1954
1955     def test_invalid_config_return_code(self) -> None:
1956         tmp_file = Path(black.dump_to_file())
1957         try:
1958             tmp_config = Path(black.dump_to_file())
1959             tmp_config.unlink()
1960             args = ["--config", str(tmp_config), str(tmp_file)]
1961             self.invokeBlack(args, exit_code=2, ignore_config=False)
1962         finally:
1963             tmp_file.unlink()
1964
1965     def test_parse_pyproject_toml(self) -> None:
1966         test_toml_file = THIS_DIR / "test.toml"
1967         config = black.parse_pyproject_toml(str(test_toml_file))
1968         self.assertEqual(config["verbose"], 1)
1969         self.assertEqual(config["check"], "no")
1970         self.assertEqual(config["diff"], "y")
1971         self.assertEqual(config["color"], True)
1972         self.assertEqual(config["line_length"], 79)
1973         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1974         self.assertEqual(config["exclude"], r"\.pyi?$")
1975         self.assertEqual(config["include"], r"\.py?$")
1976
1977     def test_read_pyproject_toml(self) -> None:
1978         test_toml_file = THIS_DIR / "test.toml"
1979         fake_ctx = FakeContext()
1980         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1981         config = fake_ctx.default_map
1982         self.assertEqual(config["verbose"], "1")
1983         self.assertEqual(config["check"], "no")
1984         self.assertEqual(config["diff"], "y")
1985         self.assertEqual(config["color"], "True")
1986         self.assertEqual(config["line_length"], "79")
1987         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1988         self.assertEqual(config["exclude"], r"\.pyi?$")
1989         self.assertEqual(config["include"], r"\.py?$")
1990
1991     def test_find_project_root(self) -> None:
1992         with TemporaryDirectory() as workspace:
1993             root = Path(workspace)
1994             test_dir = root / "test"
1995             test_dir.mkdir()
1996
1997             src_dir = root / "src"
1998             src_dir.mkdir()
1999
2000             root_pyproject = root / "pyproject.toml"
2001             root_pyproject.touch()
2002             src_pyproject = src_dir / "pyproject.toml"
2003             src_pyproject.touch()
2004             src_python = src_dir / "foo.py"
2005             src_python.touch()
2006
2007             self.assertEqual(
2008                 black.find_project_root((src_dir, test_dir)), root.resolve()
2009             )
2010             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
2011             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
2012
2013     @patch(
2014         "black.files.find_user_pyproject_toml",
2015         black.files.find_user_pyproject_toml.__wrapped__,
2016     )
2017     def test_find_user_pyproject_toml_linux(self) -> None:
2018         if system() == "Windows":
2019             return
2020
2021         # Test if XDG_CONFIG_HOME is checked
2022         with TemporaryDirectory() as workspace:
2023             tmp_user_config = Path(workspace) / "black"
2024             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
2025                 self.assertEqual(
2026                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
2027                 )
2028
2029         # Test fallback for XDG_CONFIG_HOME
2030         with patch.dict("os.environ"):
2031             os.environ.pop("XDG_CONFIG_HOME", None)
2032             fallback_user_config = Path("~/.config").expanduser() / "black"
2033             self.assertEqual(
2034                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
2035             )
2036
2037     def test_find_user_pyproject_toml_windows(self) -> None:
2038         if system() != "Windows":
2039             return
2040
2041         user_config_path = Path.home() / ".black"
2042         self.assertEqual(
2043             black.files.find_user_pyproject_toml(), user_config_path.resolve()
2044         )
2045
2046     def test_bpo_33660_workaround(self) -> None:
2047         if system() == "Windows":
2048             return
2049
2050         # https://bugs.python.org/issue33660
2051         root = Path("/")
2052         with change_directory(root):
2053             path = Path("workspace") / "project"
2054             report = black.Report(verbose=True)
2055             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
2056             self.assertEqual(normalized_path, "workspace/project")
2057
2058     def test_newline_comment_interaction(self) -> None:
2059         source = "class A:\\\r\n# type: ignore\n pass\n"
2060         output = black.format_str(source, mode=DEFAULT_MODE)
2061         black.assert_stable(source, output, mode=DEFAULT_MODE)
2062
2063     def test_bpo_2142_workaround(self) -> None:
2064
2065         # https://bugs.python.org/issue2142
2066
2067         source, _ = read_data("missing_final_newline.py")
2068         # read_data adds a trailing newline
2069         source = source.rstrip()
2070         expected, _ = read_data("missing_final_newline.diff")
2071         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
2072         diff_header = re.compile(
2073             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
2074             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
2075         )
2076         try:
2077             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
2078             self.assertEqual(result.exit_code, 0)
2079         finally:
2080             os.unlink(tmp_file)
2081         actual = result.output
2082         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
2083         self.assertEqual(actual, expected)
2084
2085     @pytest.mark.python2
2086     def test_docstring_reformat_for_py27(self) -> None:
2087         """
2088         Check that stripping trailing whitespace from Python 2 docstrings
2089         doesn't trigger a "not equivalent to source" error
2090         """
2091         source = (
2092             b'def foo():\r\n    """Testing\r\n    Testing """\r\n    print "Foo"\r\n'
2093         )
2094         expected = 'def foo():\n    """Testing\n    Testing"""\n    print "Foo"\n'
2095
2096         result = CliRunner().invoke(
2097             black.main,
2098             ["-", "-q", "--target-version=py27"],
2099             input=BytesIO(source),
2100         )
2101
2102         self.assertEqual(result.exit_code, 0)
2103         actual = result.output
2104         self.assertFormatEqual(actual, expected)
2105
2106     @staticmethod
2107     def compare_results(
2108         result: click.testing.Result, expected_value: str, expected_exit_code: int
2109     ) -> None:
2110         """Helper method to test the value and exit code of a click Result."""
2111         assert (
2112             result.output == expected_value
2113         ), "The output did not match the expected value."
2114         assert result.exit_code == expected_exit_code, "The exit code is incorrect."
2115
2116     def test_code_option(self) -> None:
2117         """Test the code option with no changes."""
2118         code = 'print("Hello world")\n'
2119         args = ["--code", code]
2120         result = CliRunner().invoke(black.main, args)
2121
2122         self.compare_results(result, code, 0)
2123
2124     def test_code_option_changed(self) -> None:
2125         """Test the code option when changes are required."""
2126         code = "print('hello world')"
2127         formatted = black.format_str(code, mode=DEFAULT_MODE)
2128
2129         args = ["--code", code]
2130         result = CliRunner().invoke(black.main, args)
2131
2132         self.compare_results(result, formatted, 0)
2133
2134     def test_code_option_check(self) -> None:
2135         """Test the code option when check is passed."""
2136         args = ["--check", "--code", 'print("Hello world")\n']
2137         result = CliRunner().invoke(black.main, args)
2138         self.compare_results(result, "", 0)
2139
2140     def test_code_option_check_changed(self) -> None:
2141         """Test the code option when changes are required, and check is passed."""
2142         args = ["--check", "--code", "print('hello world')"]
2143         result = CliRunner().invoke(black.main, args)
2144         self.compare_results(result, "", 1)
2145
2146     def test_code_option_diff(self) -> None:
2147         """Test the code option when diff is passed."""
2148         code = "print('hello world')"
2149         formatted = black.format_str(code, mode=DEFAULT_MODE)
2150         result_diff = diff(code, formatted, "STDIN", "STDOUT")
2151
2152         args = ["--diff", "--code", code]
2153         result = CliRunner().invoke(black.main, args)
2154
2155         # Remove time from diff
2156         output = DIFF_TIME.sub("", result.output)
2157
2158         assert output == result_diff, "The output did not match the expected value."
2159         assert result.exit_code == 0, "The exit code is incorrect."
2160
2161     def test_code_option_color_diff(self) -> None:
2162         """Test the code option when color and diff are passed."""
2163         code = "print('hello world')"
2164         formatted = black.format_str(code, mode=DEFAULT_MODE)
2165
2166         result_diff = diff(code, formatted, "STDIN", "STDOUT")
2167         result_diff = color_diff(result_diff)
2168
2169         args = ["--diff", "--color", "--code", code]
2170         result = CliRunner().invoke(black.main, args)
2171
2172         # Remove time from diff
2173         output = DIFF_TIME.sub("", result.output)
2174
2175         assert output == result_diff, "The output did not match the expected value."
2176         assert result.exit_code == 0, "The exit code is incorrect."
2177
2178     def test_code_option_safe(self) -> None:
2179         """Test that the code option throws an error when the sanity checks fail."""
2180         # Patch black.assert_equivalent to ensure the sanity checks fail
2181         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
2182             code = 'print("Hello world")'
2183             error_msg = f"{code}\nerror: cannot format <string>: \n"
2184
2185             args = ["--safe", "--code", code]
2186             result = CliRunner().invoke(black.main, args)
2187
2188             self.compare_results(result, error_msg, 123)
2189
2190     def test_code_option_fast(self) -> None:
2191         """Test that the code option ignores errors when the sanity checks fail."""
2192         # Patch black.assert_equivalent to ensure the sanity checks fail
2193         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
2194             code = 'print("Hello world")'
2195             formatted = black.format_str(code, mode=DEFAULT_MODE)
2196
2197             args = ["--fast", "--code", code]
2198             result = CliRunner().invoke(black.main, args)
2199
2200             self.compare_results(result, formatted, 0)
2201
2202     def test_code_option_config(self) -> None:
2203         """
2204         Test that the code option finds the pyproject.toml in the current directory.
2205         """
2206         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
2207             args = ["--code", "print"]
2208             CliRunner().invoke(black.main, args)
2209
2210             pyproject_path = Path(Path().cwd(), "pyproject.toml").resolve()
2211             assert (
2212                 len(parse.mock_calls) >= 1
2213             ), "Expected config parse to be called with the current directory."
2214
2215             _, call_args, _ = parse.mock_calls[0]
2216             assert (
2217                 call_args[0].lower() == str(pyproject_path).lower()
2218             ), "Incorrect config loaded."
2219
2220     def test_code_option_parent_config(self) -> None:
2221         """
2222         Test that the code option finds the pyproject.toml in the parent directory.
2223         """
2224         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
2225             with change_directory(Path("tests")):
2226                 args = ["--code", "print"]
2227                 CliRunner().invoke(black.main, args)
2228
2229                 pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
2230                 assert (
2231                     len(parse.mock_calls) >= 1
2232                 ), "Expected config parse to be called with the current directory."
2233
2234                 _, call_args, _ = parse.mock_calls[0]
2235                 assert (
2236                     call_args[0].lower() == str(pyproject_path).lower()
2237                 ), "Incorrect config loaded."
2238
2239
# Source lines of black's main module (newlines kept), used by `tracefunc`
# below to print the `def` line of each traced call.
with open(black.__file__, encoding="utf-8") as _bf:
    black_source_lines = list(_bf)
2242
2243
def tracefunc(frame: types.FrameType, event: str, arg: Any) -> Callable:
    """Show function calls from `black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    Each "call" event in ``black/__init__.py`` prints
    ``<lineno>:<def line>``, indented according to the interpreter stack
    depth. Calls in any other file are ignored.

    Always returns itself so tracing continues in the newly entered frame.
    """
    if event != "call":
        return tracefunc

    filename = frame.f_code.co_filename
    # Check the file before indexing `black_source_lines`. Frames from other
    # modules can have line numbers past its end, and an exception raised
    # here makes the interpreter unset the trace function. `as_posix()`
    # lets the match work with Windows path separators too.
    if "black/__init__.py" not in Path(filename).as_posix():
        return tracefunc

    # NOTE(review): 19 looks like a hand-tuned count of frames added by the
    # test runner itself; adjust if the indentation looks off.
    # context=0 avoids reading source lines for every frame on the stack.
    stack = (len(inspect.stack(0)) - 19) * 2
    lineno = frame.f_lineno
    # `lineno` is 1-based; skip decorator lines to reach the `def` itself.
    func_sig_lineno = lineno - 1
    funcname = black_source_lines[func_sig_lineno].strip()
    while funcname.startswith("@"):
        func_sig_lineno += 1
        funcname = black_source_lines[func_sig_lineno].strip()
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc
2264
2265
if __name__ == "__main__":
    # Run this file's tests through unittest when executed directly.
    # NOTE(review): module="test_black" makes unittest import the file again
    # under that name, which requires the tests/ directory on sys.path —
    # confirm this is still the intended invocation.
    unittest.main(module="test_black")