git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

isort docs have changed urls (#2390)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import multiprocessing
3 import asyncio
4 import logging
5 from concurrent.futures import ThreadPoolExecutor
6 from contextlib import contextmanager
7 from dataclasses import replace
8 import inspect
9 import io
10 from io import BytesIO
11 import os
12 from pathlib import Path
13 from platform import system
14 import regex as re
15 import sys
16 from tempfile import TemporaryDirectory
17 import types
18 from typing import (
19     Any,
20     Callable,
21     Dict,
22     List,
23     Iterator,
24     TypeVar,
25 )
26 import pytest
27 import unittest
28 from unittest.mock import patch, MagicMock
29
30 import click
31 from click import unstyle
32 from click.testing import CliRunner
33
34 import black
35 from black import Feature, TargetVersion
36 from black.cache import get_cache_file
37 from black.debug import DebugVisitor
38 from black.output import diff, color_diff
39 from black.report import Report
40 import black.files
41
42 from pathspec import PathSpec
43
44 # Import other test classes
45 from tests.util import (
46     THIS_DIR,
47     change_directory,
48     read_data,
49     DETERMINISTIC_HEADER,
50     BlackBaseTestCase,
51     DEFAULT_MODE,
52     fs,
53     ff,
54     dump_to_stderr,
55 )
56
57
THIS_FILE = Path(__file__)

# Every target version from Python 3.6 onwards.
PY36_VERSIONS = {
    TargetVersion.PY36,
    TargetVersion.PY37,
    TargetVersion.PY38,
    TargetVersion.PY39,
}
# One ``--target-version=<name>`` command-line flag per entry in PY36_VERSIONS.
PY36_ARGS = [
    "--target-version=" + target.name.lower() for target in PY36_VERSIONS
]

# Module-level type variables.
T = TypeVar("T")
R = TypeVar("R")

# Match the time output in a diff, but nothing else
DIFF_TIME = re.compile(r"\t[\d-:+\. ]+")
71
72
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point ``black.cache.CACHE_DIR`` at a throwaway temporary directory.

    Args:
        exists: when True, the yielded path is the temporary directory itself.
            When False, it is a ``new`` subdirectory of it that this function
            does not create.

    Yields:
        The path that ``black.cache.CACHE_DIR`` is patched to for the duration
        of the ``with`` block. The temporary directory is removed afterwards.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
81
82
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a brand-new asyncio event loop for the body of a ``with`` block.

    The loop comes from the current event loop policy and is registered as the
    current loop. It is closed when the block exits, even on error. Note that
    the closed loop is not unregistered afterwards.
    """
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        fresh_loop.close()
93
94
class FakeContext(click.Context):
    """Stand-in click Context for calling functions that require one.

    ``click.Context.__init__`` is deliberately not called. The only state set
    up is an empty ``default_map``.
    """

    def __init__(self) -> None:
        self.default_map: Dict[str, Any] = dict()
100
101
class FakeParameter(click.Parameter):
    """Stand-in click Parameter for calling functions that require one."""

    def __init__(self) -> None:
        """Deliberately skip ``click.Parameter.__init__``; no state is set up."""
107
108
class BlackRunner(CliRunner):
    """CliRunner that keeps Black's STDOUT and STDERR separate.

    Constructed with ``mix_stderr=False`` so that callers can inspect
    ``result.stdout_bytes`` and ``result.stderr_bytes`` independently.
    """

    def __init__(self) -> None:
        CliRunner.__init__(self, mix_stderr=False)
114
115
116 class BlackTestCase(BlackBaseTestCase):
117     def invokeBlack(
118         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
119     ) -> None:
120         runner = BlackRunner()
121         if ignore_config:
122             args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
123         result = runner.invoke(black.main, args)
124         assert result.stdout_bytes is not None
125         assert result.stderr_bytes is not None
126         self.assertEqual(
127             result.exit_code,
128             exit_code,
129             msg=(
130                 f"Failed with args: {args}\n"
131                 f"stdout: {result.stdout_bytes.decode()!r}\n"
132                 f"stderr: {result.stderr_bytes.decode()!r}\n"
133                 f"exception: {result.exception}"
134             ),
135         )
136
137     @patch("black.dump_to_file", dump_to_stderr)
138     def test_empty(self) -> None:
139         source = expected = ""
140         actual = fs(source)
141         self.assertFormatEqual(expected, actual)
142         black.assert_equivalent(source, actual)
143         black.assert_stable(source, actual, DEFAULT_MODE)
144
145     def test_empty_ff(self) -> None:
146         expected = ""
147         tmp_file = Path(black.dump_to_file())
148         try:
149             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
150             with open(tmp_file, encoding="utf8") as f:
151                 actual = f.read()
152         finally:
153             os.unlink(tmp_file)
154         self.assertFormatEqual(expected, actual)
155
156     def test_piping(self) -> None:
157         source, expected = read_data("src/black/__init__", data=False)
158         result = BlackRunner().invoke(
159             black.main,
160             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
161             input=BytesIO(source.encode("utf8")),
162         )
163         self.assertEqual(result.exit_code, 0)
164         self.assertFormatEqual(expected, result.output)
165         if source != result.output:
166             black.assert_equivalent(source, result.output)
167             black.assert_stable(source, result.output, DEFAULT_MODE)
168
169     def test_piping_diff(self) -> None:
170         diff_header = re.compile(
171             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
172             r"\+\d\d\d\d"
173         )
174         source, _ = read_data("expression.py")
175         expected, _ = read_data("expression.diff")
176         config = THIS_DIR / "data" / "empty_pyproject.toml"
177         args = [
178             "-",
179             "--fast",
180             f"--line-length={black.DEFAULT_LINE_LENGTH}",
181             "--diff",
182             f"--config={config}",
183         ]
184         result = BlackRunner().invoke(
185             black.main, args, input=BytesIO(source.encode("utf8"))
186         )
187         self.assertEqual(result.exit_code, 0)
188         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
189         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
190         self.assertEqual(expected, actual)
191
192     def test_piping_diff_with_color(self) -> None:
193         source, _ = read_data("expression.py")
194         config = THIS_DIR / "data" / "empty_pyproject.toml"
195         args = [
196             "-",
197             "--fast",
198             f"--line-length={black.DEFAULT_LINE_LENGTH}",
199             "--diff",
200             "--color",
201             f"--config={config}",
202         ]
203         result = BlackRunner().invoke(
204             black.main, args, input=BytesIO(source.encode("utf8"))
205         )
206         actual = result.output
207         # Again, the contents are checked in a different test, so only look for colors.
208         self.assertIn("\033[1;37m", actual)
209         self.assertIn("\033[36m", actual)
210         self.assertIn("\033[32m", actual)
211         self.assertIn("\033[31m", actual)
212         self.assertIn("\033[0m", actual)
213
214     @patch("black.dump_to_file", dump_to_stderr)
215     def _test_wip(self) -> None:
216         source, expected = read_data("wip")
217         sys.settrace(tracefunc)
218         mode = replace(
219             DEFAULT_MODE,
220             experimental_string_processing=False,
221             target_versions={black.TargetVersion.PY38},
222         )
223         actual = fs(source, mode=mode)
224         sys.settrace(None)
225         self.assertFormatEqual(expected, actual)
226         black.assert_equivalent(source, actual)
227         black.assert_stable(source, actual, black.FileMode())
228
229     @unittest.expectedFailure
230     @patch("black.dump_to_file", dump_to_stderr)
231     def test_trailing_comma_optional_parens_stability1(self) -> None:
232         source, _expected = read_data("trailing_comma_optional_parens1")
233         actual = fs(source)
234         black.assert_stable(source, actual, DEFAULT_MODE)
235
236     @unittest.expectedFailure
237     @patch("black.dump_to_file", dump_to_stderr)
238     def test_trailing_comma_optional_parens_stability2(self) -> None:
239         source, _expected = read_data("trailing_comma_optional_parens2")
240         actual = fs(source)
241         black.assert_stable(source, actual, DEFAULT_MODE)
242
243     @unittest.expectedFailure
244     @patch("black.dump_to_file", dump_to_stderr)
245     def test_trailing_comma_optional_parens_stability3(self) -> None:
246         source, _expected = read_data("trailing_comma_optional_parens3")
247         actual = fs(source)
248         black.assert_stable(source, actual, DEFAULT_MODE)
249
250     @patch("black.dump_to_file", dump_to_stderr)
251     def test_trailing_comma_optional_parens_stability1_pass2(self) -> None:
252         source, _expected = read_data("trailing_comma_optional_parens1")
253         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
254         black.assert_stable(source, actual, DEFAULT_MODE)
255
256     @patch("black.dump_to_file", dump_to_stderr)
257     def test_trailing_comma_optional_parens_stability2_pass2(self) -> None:
258         source, _expected = read_data("trailing_comma_optional_parens2")
259         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
260         black.assert_stable(source, actual, DEFAULT_MODE)
261
262     @patch("black.dump_to_file", dump_to_stderr)
263     def test_trailing_comma_optional_parens_stability3_pass2(self) -> None:
264         source, _expected = read_data("trailing_comma_optional_parens3")
265         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
266         black.assert_stable(source, actual, DEFAULT_MODE)
267
268     @patch("black.dump_to_file", dump_to_stderr)
269     def test_pep_572(self) -> None:
270         source, expected = read_data("pep_572")
271         actual = fs(source)
272         self.assertFormatEqual(expected, actual)
273         black.assert_stable(source, actual, DEFAULT_MODE)
274         if sys.version_info >= (3, 8):
275             black.assert_equivalent(source, actual)
276
277     @patch("black.dump_to_file", dump_to_stderr)
278     def test_pep_572_remove_parens(self) -> None:
279         source, expected = read_data("pep_572_remove_parens")
280         actual = fs(source)
281         self.assertFormatEqual(expected, actual)
282         black.assert_stable(source, actual, DEFAULT_MODE)
283         if sys.version_info >= (3, 8):
284             black.assert_equivalent(source, actual)
285
286     @patch("black.dump_to_file", dump_to_stderr)
287     def test_pep_572_do_not_remove_parens(self) -> None:
288         source, expected = read_data("pep_572_do_not_remove_parens")
289         # the AST safety checks will fail, but that's expected, just make sure no
290         # parentheses are touched
291         actual = black.format_str(source, mode=DEFAULT_MODE)
292         self.assertFormatEqual(expected, actual)
293
294     def test_pep_572_version_detection(self) -> None:
295         source, _ = read_data("pep_572")
296         root = black.lib2to3_parse(source)
297         features = black.get_features_used(root)
298         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
299         versions = black.detect_target_versions(root)
300         self.assertIn(black.TargetVersion.PY38, versions)
301
302     def test_expression_ff(self) -> None:
303         source, expected = read_data("expression")
304         tmp_file = Path(black.dump_to_file(source))
305         try:
306             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
307             with open(tmp_file, encoding="utf8") as f:
308                 actual = f.read()
309         finally:
310             os.unlink(tmp_file)
311         self.assertFormatEqual(expected, actual)
312         with patch("black.dump_to_file", dump_to_stderr):
313             black.assert_equivalent(source, actual)
314             black.assert_stable(source, actual, DEFAULT_MODE)
315
316     def test_expression_diff(self) -> None:
317         source, _ = read_data("expression.py")
318         config = THIS_DIR / "data" / "empty_pyproject.toml"
319         expected, _ = read_data("expression.diff")
320         tmp_file = Path(black.dump_to_file(source))
321         diff_header = re.compile(
322             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
323             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
324         )
325         try:
326             result = BlackRunner().invoke(
327                 black.main, ["--diff", str(tmp_file), f"--config={config}"]
328             )
329             self.assertEqual(result.exit_code, 0)
330         finally:
331             os.unlink(tmp_file)
332         actual = result.output
333         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
334         if expected != actual:
335             dump = black.dump_to_file(actual)
336             msg = (
337                 "Expected diff isn't equal to the actual. If you made changes to"
338                 " expression.py and this is an anticipated difference, overwrite"
339                 f" tests/data/expression.diff with {dump}"
340             )
341             self.assertEqual(expected, actual, msg)
342
343     def test_expression_diff_with_color(self) -> None:
344         source, _ = read_data("expression.py")
345         config = THIS_DIR / "data" / "empty_pyproject.toml"
346         expected, _ = read_data("expression.diff")
347         tmp_file = Path(black.dump_to_file(source))
348         try:
349             result = BlackRunner().invoke(
350                 black.main, ["--diff", "--color", str(tmp_file), f"--config={config}"]
351             )
352         finally:
353             os.unlink(tmp_file)
354         actual = result.output
355         # We check the contents of the diff in `test_expression_diff`. All
356         # we need to check here is that color codes exist in the result.
357         self.assertIn("\033[1;37m", actual)
358         self.assertIn("\033[36m", actual)
359         self.assertIn("\033[32m", actual)
360         self.assertIn("\033[31m", actual)
361         self.assertIn("\033[0m", actual)
362
363     @patch("black.dump_to_file", dump_to_stderr)
364     def test_pep_570(self) -> None:
365         source, expected = read_data("pep_570")
366         actual = fs(source)
367         self.assertFormatEqual(expected, actual)
368         black.assert_stable(source, actual, DEFAULT_MODE)
369         if sys.version_info >= (3, 8):
370             black.assert_equivalent(source, actual)
371
372     def test_detect_pos_only_arguments(self) -> None:
373         source, _ = read_data("pep_570")
374         root = black.lib2to3_parse(source)
375         features = black.get_features_used(root)
376         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
377         versions = black.detect_target_versions(root)
378         self.assertIn(black.TargetVersion.PY38, versions)
379
380     @patch("black.dump_to_file", dump_to_stderr)
381     def test_string_quotes(self) -> None:
382         source, expected = read_data("string_quotes")
383         mode = black.Mode(experimental_string_processing=True)
384         actual = fs(source, mode=mode)
385         self.assertFormatEqual(expected, actual)
386         black.assert_equivalent(source, actual)
387         black.assert_stable(source, actual, mode)
388         mode = replace(mode, string_normalization=False)
389         not_normalized = fs(source, mode=mode)
390         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
391         black.assert_equivalent(source, not_normalized)
392         black.assert_stable(source, not_normalized, mode=mode)
393
394     @patch("black.dump_to_file", dump_to_stderr)
395     def test_docstring_no_string_normalization(self) -> None:
396         """Like test_docstring but with string normalization off."""
397         source, expected = read_data("docstring_no_string_normalization")
398         mode = replace(DEFAULT_MODE, string_normalization=False)
399         actual = fs(source, mode=mode)
400         self.assertFormatEqual(expected, actual)
401         black.assert_equivalent(source, actual)
402         black.assert_stable(source, actual, mode)
403
404     def test_long_strings_flag_disabled(self) -> None:
405         """Tests for turning off the string processing logic."""
406         source, expected = read_data("long_strings_flag_disabled")
407         mode = replace(DEFAULT_MODE, experimental_string_processing=False)
408         actual = fs(source, mode=mode)
409         self.assertFormatEqual(expected, actual)
410         black.assert_stable(expected, actual, mode)
411
412     @patch("black.dump_to_file", dump_to_stderr)
413     def test_numeric_literals(self) -> None:
414         source, expected = read_data("numeric_literals")
415         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
416         actual = fs(source, mode=mode)
417         self.assertFormatEqual(expected, actual)
418         black.assert_equivalent(source, actual)
419         black.assert_stable(source, actual, mode)
420
421     @patch("black.dump_to_file", dump_to_stderr)
422     def test_numeric_literals_ignoring_underscores(self) -> None:
423         source, expected = read_data("numeric_literals_skip_underscores")
424         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
425         actual = fs(source, mode=mode)
426         self.assertFormatEqual(expected, actual)
427         black.assert_equivalent(source, actual)
428         black.assert_stable(source, actual, mode)
429
430     def test_skip_magic_trailing_comma(self) -> None:
431         source, _ = read_data("expression.py")
432         expected, _ = read_data("expression_skip_magic_trailing_comma.diff")
433         tmp_file = Path(black.dump_to_file(source))
434         diff_header = re.compile(
435             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
436             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
437         )
438         try:
439             result = BlackRunner().invoke(black.main, ["-C", "--diff", str(tmp_file)])
440             self.assertEqual(result.exit_code, 0)
441         finally:
442             os.unlink(tmp_file)
443         actual = result.output
444         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
445         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
446         if expected != actual:
447             dump = black.dump_to_file(actual)
448             msg = (
449                 "Expected diff isn't equal to the actual. If you made changes to"
450                 " expression.py and this is an anticipated difference, overwrite"
451                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
452             )
453             self.assertEqual(expected, actual, msg)
454
455     @pytest.mark.python2
456     @patch("black.dump_to_file", dump_to_stderr)
457     def test_python2_print_function(self) -> None:
458         source, expected = read_data("python2_print_function")
459         mode = replace(DEFAULT_MODE, target_versions={TargetVersion.PY27})
460         actual = fs(source, mode=mode)
461         self.assertFormatEqual(expected, actual)
462         black.assert_equivalent(source, actual)
463         black.assert_stable(source, actual, mode)
464
465     @patch("black.dump_to_file", dump_to_stderr)
466     def test_stub(self) -> None:
467         mode = replace(DEFAULT_MODE, is_pyi=True)
468         source, expected = read_data("stub.pyi")
469         actual = fs(source, mode=mode)
470         self.assertFormatEqual(expected, actual)
471         black.assert_stable(source, actual, mode)
472
473     @patch("black.dump_to_file", dump_to_stderr)
474     def test_async_as_identifier(self) -> None:
475         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
476         source, expected = read_data("async_as_identifier")
477         actual = fs(source)
478         self.assertFormatEqual(expected, actual)
479         major, minor = sys.version_info[:2]
480         if major < 3 or (major <= 3 and minor < 7):
481             black.assert_equivalent(source, actual)
482         black.assert_stable(source, actual, DEFAULT_MODE)
483         # ensure black can parse this when the target is 3.6
484         self.invokeBlack([str(source_path), "--target-version", "py36"])
485         # but not on 3.7, because async/await is no longer an identifier
486         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
487
488     @patch("black.dump_to_file", dump_to_stderr)
489     def test_python37(self) -> None:
490         source_path = (THIS_DIR / "data" / "python37.py").resolve()
491         source, expected = read_data("python37")
492         actual = fs(source)
493         self.assertFormatEqual(expected, actual)
494         major, minor = sys.version_info[:2]
495         if major > 3 or (major == 3 and minor >= 7):
496             black.assert_equivalent(source, actual)
497         black.assert_stable(source, actual, DEFAULT_MODE)
498         # ensure black can parse this when the target is 3.7
499         self.invokeBlack([str(source_path), "--target-version", "py37"])
500         # but not on 3.6, because we use async as a reserved keyword
501         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
502
503     @patch("black.dump_to_file", dump_to_stderr)
504     def test_python38(self) -> None:
505         source, expected = read_data("python38")
506         actual = fs(source)
507         self.assertFormatEqual(expected, actual)
508         major, minor = sys.version_info[:2]
509         if major > 3 or (major == 3 and minor >= 8):
510             black.assert_equivalent(source, actual)
511         black.assert_stable(source, actual, DEFAULT_MODE)
512
513     @patch("black.dump_to_file", dump_to_stderr)
514     def test_python39(self) -> None:
515         source, expected = read_data("python39")
516         actual = fs(source)
517         self.assertFormatEqual(expected, actual)
518         major, minor = sys.version_info[:2]
519         if major > 3 or (major == 3 and minor >= 9):
520             black.assert_equivalent(source, actual)
521         black.assert_stable(source, actual, DEFAULT_MODE)
522
523     def test_tab_comment_indentation(self) -> None:
524         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
525         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
526         self.assertFormatEqual(contents_spc, fs(contents_spc))
527         self.assertFormatEqual(contents_spc, fs(contents_tab))
528
529         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
530         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
531         self.assertFormatEqual(contents_spc, fs(contents_spc))
532         self.assertFormatEqual(contents_spc, fs(contents_tab))
533
534         # mixed tabs and spaces (valid Python 2 code)
535         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
536         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
537         self.assertFormatEqual(contents_spc, fs(contents_spc))
538         self.assertFormatEqual(contents_spc, fs(contents_tab))
539
540         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
541         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
542         self.assertFormatEqual(contents_spc, fs(contents_spc))
543         self.assertFormatEqual(contents_spc, fs(contents_tab))
544
    def test_report_verbose(self) -> None:
        """Check Report's messages, summary text and return codes with verbose=True.

        Every ``done``/``failed``/``path_ignored`` call is followed by checks on
        the captured output lines, the unstyled summary and ``return_code``.
        """
        report = Report(verbose=True)
        # Messages captured from the patched black.output._out / _err.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            # Changed.NO: in verbose mode a "well formatted" line is emitted.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # Changed.CACHED is counted as "left unchanged" in the summary.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, a reformatted file makes the return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures go to the error stream and force return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are logged but do not change the summary counts.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check mode and diff mode switch the summary to "would be".
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
646
    def test_report_quiet(self) -> None:
        """Check Report with quiet=True.

        Nothing is written through ``_out`` at any point, while failures still
        go through ``_err``. Summaries and return codes match the verbose case.
        """
        report = Report(quiet=True)
        # Messages captured from the patched black.output._out / _err.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, a reformatted file makes the return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Errors are still reported in quiet mode.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths produce no output and do not change the counts.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check mode and diff mode switch the summary to "would be".
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
740
    def test_report_normal(self) -> None:
        """Check Report tallies, messages and return code for a default Report.

        black.output._out and black.output._err are patched to capture every
        message. After each recorded event the test asserts how many messages
        went to each stream, the unstyled summary string and, in most steps,
        report.return_code.
        """
        report = black.Report()
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            # f1 unchanged: tallied, but no message is emitted.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # f2 reformatted: one "reformatted <path>" message on the out stream.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # f3 reported as CACHED: tallied as unchanged, no new message.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode the same tallies yield return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure is written to the err stream and forces return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Reporting f3 again, now as changed, is tallied as a new result.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # An ignored path emits nothing here and leaves every tally unchanged.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # With check set, or with diff set, the summary uses "would" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
837
838     def test_lib2to3_parse(self) -> None:
839         with self.assertRaises(black.InvalidInput):
840             black.lib2to3_parse("invalid syntax")
841
842         straddling = "x + y"
843         black.lib2to3_parse(straddling)
844         black.lib2to3_parse(straddling, {TargetVersion.PY27})
845         black.lib2to3_parse(straddling, {TargetVersion.PY36})
846         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
847
848         py2_only = "print x"
849         black.lib2to3_parse(py2_only)
850         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
851         with self.assertRaises(black.InvalidInput):
852             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
853         with self.assertRaises(black.InvalidInput):
854             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
855
856         py3_only = "exec(x, end=y)"
857         black.lib2to3_parse(py3_only)
858         with self.assertRaises(black.InvalidInput):
859             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
860         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
861         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
862
863     def test_get_features_used_decorator(self) -> None:
864         # Test the feature detection of new decorator syntax
865         # since this makes some test cases of test_get_features_used()
866         # fails if it fails, this is tested first so that a useful case
867         # is identified
868         simples, relaxed = read_data("decorators")
869         # skip explanation comments at the top of the file
870         for simple_test in simples.split("##")[1:]:
871             node = black.lib2to3_parse(simple_test)
872             decorator = str(node.children[0].children[0]).strip()
873             self.assertNotIn(
874                 Feature.RELAXED_DECORATORS,
875                 black.get_features_used(node),
876                 msg=(
877                     f"decorator '{decorator}' follows python<=3.8 syntax"
878                     "but is detected as 3.9+"
879                     # f"The full node is\n{node!r}"
880                 ),
881             )
882         # skip the '# output' comment at the top of the output part
883         for relaxed_test in relaxed.split("##")[1:]:
884             node = black.lib2to3_parse(relaxed_test)
885             decorator = str(node.children[0].children[0]).strip()
886             self.assertIn(
887                 Feature.RELAXED_DECORATORS,
888                 black.get_features_used(node),
889                 msg=(
890                     f"decorator '{decorator}' uses python3.9+ syntax"
891                     "but is detected as python<=3.8"
892                     # f"The full node is\n{node!r}"
893                 ),
894             )
895
896     def test_get_features_used(self) -> None:
897         node = black.lib2to3_parse("def f(*, arg): ...\n")
898         self.assertEqual(black.get_features_used(node), set())
899         node = black.lib2to3_parse("def f(*, arg,): ...\n")
900         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
901         node = black.lib2to3_parse("f(*arg,)\n")
902         self.assertEqual(
903             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
904         )
905         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
906         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
907         node = black.lib2to3_parse("123_456\n")
908         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
909         node = black.lib2to3_parse("123456\n")
910         self.assertEqual(black.get_features_used(node), set())
911         source, expected = read_data("function")
912         node = black.lib2to3_parse(source)
913         expected_features = {
914             Feature.TRAILING_COMMA_IN_CALL,
915             Feature.TRAILING_COMMA_IN_DEF,
916             Feature.F_STRINGS,
917         }
918         self.assertEqual(black.get_features_used(node), expected_features)
919         node = black.lib2to3_parse(expected)
920         self.assertEqual(black.get_features_used(node), expected_features)
921         source, expected = read_data("expression")
922         node = black.lib2to3_parse(source)
923         self.assertEqual(black.get_features_used(node), set())
924         node = black.lib2to3_parse(expected)
925         self.assertEqual(black.get_features_used(node), set())
926
927     def test_get_future_imports(self) -> None:
928         node = black.lib2to3_parse("\n")
929         self.assertEqual(set(), black.get_future_imports(node))
930         node = black.lib2to3_parse("from __future__ import black\n")
931         self.assertEqual({"black"}, black.get_future_imports(node))
932         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
933         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
934         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
935         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
936         node = black.lib2to3_parse(
937             "from __future__ import multiple\nfrom __future__ import imports\n"
938         )
939         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
940         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
941         self.assertEqual({"black"}, black.get_future_imports(node))
942         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
943         self.assertEqual({"black"}, black.get_future_imports(node))
944         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
945         self.assertEqual(set(), black.get_future_imports(node))
946         node = black.lib2to3_parse("from some.module import black\n")
947         self.assertEqual(set(), black.get_future_imports(node))
948         node = black.lib2to3_parse(
949             "from __future__ import unicode_literals as _unicode_literals"
950         )
951         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
952         node = black.lib2to3_parse(
953             "from __future__ import unicode_literals as _lol, print"
954         )
955         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
956
957     def test_debug_visitor(self) -> None:
958         source, _ = read_data("debug_visitor.py")
959         expected, _ = read_data("debug_visitor.out")
960         out_lines = []
961         err_lines = []
962
963         def out(msg: str, **kwargs: Any) -> None:
964             out_lines.append(msg)
965
966         def err(msg: str, **kwargs: Any) -> None:
967             err_lines.append(msg)
968
969         with patch("black.debug.out", out):
970             DebugVisitor.show(source)
971         actual = "\n".join(out_lines) + "\n"
972         log_name = ""
973         if expected != actual:
974             log_name = black.dump_to_file(*out_lines)
975         self.assertEqual(
976             expected,
977             actual,
978             f"AST print out is different. Actual version dumped to {log_name}",
979         )
980
981     def test_format_file_contents(self) -> None:
982         empty = ""
983         mode = DEFAULT_MODE
984         with self.assertRaises(black.NothingChanged):
985             black.format_file_contents(empty, mode=mode, fast=False)
986         just_nl = "\n"
987         with self.assertRaises(black.NothingChanged):
988             black.format_file_contents(just_nl, mode=mode, fast=False)
989         same = "j = [1, 2, 3]\n"
990         with self.assertRaises(black.NothingChanged):
991             black.format_file_contents(same, mode=mode, fast=False)
992         different = "j = [1,2,3]"
993         expected = same
994         actual = black.format_file_contents(different, mode=mode, fast=False)
995         self.assertEqual(expected, actual)
996         invalid = "return if you can"
997         with self.assertRaises(black.InvalidInput) as e:
998             black.format_file_contents(invalid, mode=mode, fast=False)
999         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1000
1001     def test_endmarker(self) -> None:
1002         n = black.lib2to3_parse("\n")
1003         self.assertEqual(n.type, black.syms.file_input)
1004         self.assertEqual(len(n.children), 1)
1005         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1006
1007     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1008     def test_assertFormatEqual(self) -> None:
1009         out_lines = []
1010         err_lines = []
1011
1012         def out(msg: str, **kwargs: Any) -> None:
1013             out_lines.append(msg)
1014
1015         def err(msg: str, **kwargs: Any) -> None:
1016             err_lines.append(msg)
1017
1018         with patch("black.output._out", out), patch("black.output._err", err):
1019             with self.assertRaises(AssertionError):
1020                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1021
1022         out_str = "".join(out_lines)
1023         self.assertTrue("Expected tree:" in out_str)
1024         self.assertTrue("Actual tree:" in out_str)
1025         self.assertEqual("".join(err_lines), "")
1026
1027     def test_cache_broken_file(self) -> None:
1028         mode = DEFAULT_MODE
1029         with cache_dir() as workspace:
1030             cache_file = get_cache_file(mode)
1031             with cache_file.open("w") as fobj:
1032                 fobj.write("this is not a pickle")
1033             self.assertEqual(black.read_cache(mode), {})
1034             src = (workspace / "test.py").resolve()
1035             with src.open("w") as fobj:
1036                 fobj.write("print('hello')")
1037             self.invokeBlack([str(src)])
1038             cache = black.read_cache(mode)
1039             self.assertIn(str(src), cache)
1040
1041     def test_cache_single_file_already_cached(self) -> None:
1042         mode = DEFAULT_MODE
1043         with cache_dir() as workspace:
1044             src = (workspace / "test.py").resolve()
1045             with src.open("w") as fobj:
1046                 fobj.write("print('hello')")
1047             black.write_cache({}, [src], mode)
1048             self.invokeBlack([str(src)])
1049             with src.open("r") as fobj:
1050                 self.assertEqual(fobj.read(), "print('hello')")
1051
1052     @event_loop()
1053     def test_cache_multiple_files(self) -> None:
1054         mode = DEFAULT_MODE
1055         with cache_dir() as workspace, patch(
1056             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1057         ):
1058             one = (workspace / "one.py").resolve()
1059             with one.open("w") as fobj:
1060                 fobj.write("print('hello')")
1061             two = (workspace / "two.py").resolve()
1062             with two.open("w") as fobj:
1063                 fobj.write("print('hello')")
1064             black.write_cache({}, [one], mode)
1065             self.invokeBlack([str(workspace)])
1066             with one.open("r") as fobj:
1067                 self.assertEqual(fobj.read(), "print('hello')")
1068             with two.open("r") as fobj:
1069                 self.assertEqual(fobj.read(), 'print("hello")\n')
1070             cache = black.read_cache(mode)
1071             self.assertIn(str(one), cache)
1072             self.assertIn(str(two), cache)
1073
1074     def test_no_cache_when_writeback_diff(self) -> None:
1075         mode = DEFAULT_MODE
1076         with cache_dir() as workspace:
1077             src = (workspace / "test.py").resolve()
1078             with src.open("w") as fobj:
1079                 fobj.write("print('hello')")
1080             with patch("black.read_cache") as read_cache, patch(
1081                 "black.write_cache"
1082             ) as write_cache:
1083                 self.invokeBlack([str(src), "--diff"])
1084                 cache_file = get_cache_file(mode)
1085                 self.assertFalse(cache_file.exists())
1086                 write_cache.assert_not_called()
1087                 read_cache.assert_not_called()
1088
1089     def test_no_cache_when_writeback_color_diff(self) -> None:
1090         mode = DEFAULT_MODE
1091         with cache_dir() as workspace:
1092             src = (workspace / "test.py").resolve()
1093             with src.open("w") as fobj:
1094                 fobj.write("print('hello')")
1095             with patch("black.read_cache") as read_cache, patch(
1096                 "black.write_cache"
1097             ) as write_cache:
1098                 self.invokeBlack([str(src), "--diff", "--color"])
1099                 cache_file = get_cache_file(mode)
1100                 self.assertFalse(cache_file.exists())
1101                 write_cache.assert_not_called()
1102                 read_cache.assert_not_called()
1103
1104     @event_loop()
1105     def test_output_locking_when_writeback_diff(self) -> None:
1106         with cache_dir() as workspace:
1107             for tag in range(0, 4):
1108                 src = (workspace / f"test{tag}.py").resolve()
1109                 with src.open("w") as fobj:
1110                     fobj.write("print('hello')")
1111             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1112                 self.invokeBlack(["--diff", str(workspace)], exit_code=0)
1113                 # this isn't quite doing what we want, but if it _isn't_
1114                 # called then we cannot be using the lock it provides
1115                 mgr.assert_called()
1116
1117     @event_loop()
1118     def test_output_locking_when_writeback_color_diff(self) -> None:
1119         with cache_dir() as workspace:
1120             for tag in range(0, 4):
1121                 src = (workspace / f"test{tag}.py").resolve()
1122                 with src.open("w") as fobj:
1123                     fobj.write("print('hello')")
1124             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1125                 self.invokeBlack(["--diff", "--color", str(workspace)], exit_code=0)
1126                 # this isn't quite doing what we want, but if it _isn't_
1127                 # called then we cannot be using the lock it provides
1128                 mgr.assert_called()
1129
1130     def test_no_cache_when_stdin(self) -> None:
1131         mode = DEFAULT_MODE
1132         with cache_dir():
1133             result = CliRunner().invoke(
1134                 black.main, ["-"], input=BytesIO(b"print('hello')")
1135             )
1136             self.assertEqual(result.exit_code, 0)
1137             cache_file = get_cache_file(mode)
1138             self.assertFalse(cache_file.exists())
1139
1140     def test_read_cache_no_cachefile(self) -> None:
1141         mode = DEFAULT_MODE
1142         with cache_dir():
1143             self.assertEqual(black.read_cache(mode), {})
1144
1145     def test_write_cache_read_cache(self) -> None:
1146         mode = DEFAULT_MODE
1147         with cache_dir() as workspace:
1148             src = (workspace / "test.py").resolve()
1149             src.touch()
1150             black.write_cache({}, [src], mode)
1151             cache = black.read_cache(mode)
1152             self.assertIn(str(src), cache)
1153             self.assertEqual(cache[str(src)], black.get_cache_info(src))
1154
1155     def test_filter_cached(self) -> None:
1156         with TemporaryDirectory() as workspace:
1157             path = Path(workspace)
1158             uncached = (path / "uncached").resolve()
1159             cached = (path / "cached").resolve()
1160             cached_but_changed = (path / "changed").resolve()
1161             uncached.touch()
1162             cached.touch()
1163             cached_but_changed.touch()
1164             cache = {
1165                 str(cached): black.get_cache_info(cached),
1166                 str(cached_but_changed): (0.0, 0),
1167             }
1168             todo, done = black.filter_cached(
1169                 cache, {uncached, cached, cached_but_changed}
1170             )
1171             self.assertEqual(todo, {uncached, cached_but_changed})
1172             self.assertEqual(done, {cached})
1173
1174     def test_write_cache_creates_directory_if_needed(self) -> None:
1175         mode = DEFAULT_MODE
1176         with cache_dir(exists=False) as workspace:
1177             self.assertFalse(workspace.exists())
1178             black.write_cache({}, [], mode)
1179             self.assertTrue(workspace.exists())
1180
1181     @event_loop()
1182     def test_failed_formatting_does_not_get_cached(self) -> None:
1183         mode = DEFAULT_MODE
1184         with cache_dir() as workspace, patch(
1185             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1186         ):
1187             failing = (workspace / "failing.py").resolve()
1188             with failing.open("w") as fobj:
1189                 fobj.write("not actually python")
1190             clean = (workspace / "clean.py").resolve()
1191             with clean.open("w") as fobj:
1192                 fobj.write('print("hello")\n')
1193             self.invokeBlack([str(workspace)], exit_code=123)
1194             cache = black.read_cache(mode)
1195             self.assertNotIn(str(failing), cache)
1196             self.assertIn(str(clean), cache)
1197
1198     def test_write_cache_write_fail(self) -> None:
1199         mode = DEFAULT_MODE
1200         with cache_dir(), patch.object(Path, "open") as mock:
1201             mock.side_effect = OSError
1202             black.write_cache({}, [], mode)
1203
1204     @event_loop()
1205     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
1206     def test_works_in_mono_process_only_environment(self) -> None:
1207         with cache_dir() as workspace:
1208             for f in [
1209                 (workspace / "one.py").resolve(),
1210                 (workspace / "two.py").resolve(),
1211             ]:
1212                 f.write_text('print("hello")\n')
1213             self.invokeBlack([str(workspace)])
1214
1215     @event_loop()
1216     def test_check_diff_use_together(self) -> None:
1217         with cache_dir():
1218             # Files which will be reformatted.
1219             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1220             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1221             # Files which will not be reformatted.
1222             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1223             self.invokeBlack([str(src2), "--diff", "--check"])
1224             # Multi file command.
1225             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1226
1227     def test_no_files(self) -> None:
1228         with cache_dir():
1229             # Without an argument, black exits with error code 0.
1230             self.invokeBlack([])
1231
1232     def test_broken_symlink(self) -> None:
1233         with cache_dir() as workspace:
1234             symlink = workspace / "broken_link.py"
1235             try:
1236                 symlink.symlink_to("nonexistent.py")
1237             except OSError as e:
1238                 self.skipTest(f"Can't create symlinks: {e}")
1239             self.invokeBlack([str(workspace.resolve())])
1240
1241     def test_read_cache_line_lengths(self) -> None:
1242         mode = DEFAULT_MODE
1243         short_mode = replace(DEFAULT_MODE, line_length=1)
1244         with cache_dir() as workspace:
1245             path = (workspace / "file.py").resolve()
1246             path.touch()
1247             black.write_cache({}, [path], mode)
1248             one = black.read_cache(mode)
1249             self.assertIn(str(path), one)
1250             two = black.read_cache(short_mode)
1251             self.assertNotIn(str(path), two)
1252
1253     def test_single_file_force_pyi(self) -> None:
1254         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1255         contents, expected = read_data("force_pyi")
1256         with cache_dir() as workspace:
1257             path = (workspace / "file.py").resolve()
1258             with open(path, "w") as fh:
1259                 fh.write(contents)
1260             self.invokeBlack([str(path), "--pyi"])
1261             with open(path, "r") as fh:
1262                 actual = fh.read()
1263             # verify cache with --pyi is separate
1264             pyi_cache = black.read_cache(pyi_mode)
1265             self.assertIn(str(path), pyi_cache)
1266             normal_cache = black.read_cache(DEFAULT_MODE)
1267             self.assertNotIn(str(path), normal_cache)
1268         self.assertFormatEqual(expected, actual)
1269         black.assert_equivalent(contents, actual)
1270         black.assert_stable(contents, actual, pyi_mode)
1271
1272     @event_loop()
1273     def test_multi_file_force_pyi(self) -> None:
1274         reg_mode = DEFAULT_MODE
1275         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1276         contents, expected = read_data("force_pyi")
1277         with cache_dir() as workspace:
1278             paths = [
1279                 (workspace / "file1.py").resolve(),
1280                 (workspace / "file2.py").resolve(),
1281             ]
1282             for path in paths:
1283                 with open(path, "w") as fh:
1284                     fh.write(contents)
1285             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1286             for path in paths:
1287                 with open(path, "r") as fh:
1288                     actual = fh.read()
1289                 self.assertEqual(actual, expected)
1290             # verify cache with --pyi is separate
1291             pyi_cache = black.read_cache(pyi_mode)
1292             normal_cache = black.read_cache(reg_mode)
1293             for path in paths:
1294                 self.assertIn(str(path), pyi_cache)
1295                 self.assertNotIn(str(path), normal_cache)
1296
1297     def test_pipe_force_pyi(self) -> None:
1298         source, expected = read_data("force_pyi")
1299         result = CliRunner().invoke(
1300             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1301         )
1302         self.assertEqual(result.exit_code, 0)
1303         actual = result.output
1304         self.assertFormatEqual(actual, expected)
1305
1306     def test_single_file_force_py36(self) -> None:
1307         reg_mode = DEFAULT_MODE
1308         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1309         source, expected = read_data("force_py36")
1310         with cache_dir() as workspace:
1311             path = (workspace / "file.py").resolve()
1312             with open(path, "w") as fh:
1313                 fh.write(source)
1314             self.invokeBlack([str(path), *PY36_ARGS])
1315             with open(path, "r") as fh:
1316                 actual = fh.read()
1317             # verify cache with --target-version is separate
1318             py36_cache = black.read_cache(py36_mode)
1319             self.assertIn(str(path), py36_cache)
1320             normal_cache = black.read_cache(reg_mode)
1321             self.assertNotIn(str(path), normal_cache)
1322         self.assertEqual(actual, expected)
1323
1324     @event_loop()
1325     def test_multi_file_force_py36(self) -> None:
1326         reg_mode = DEFAULT_MODE
1327         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1328         source, expected = read_data("force_py36")
1329         with cache_dir() as workspace:
1330             paths = [
1331                 (workspace / "file1.py").resolve(),
1332                 (workspace / "file2.py").resolve(),
1333             ]
1334             for path in paths:
1335                 with open(path, "w") as fh:
1336                     fh.write(source)
1337             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1338             for path in paths:
1339                 with open(path, "r") as fh:
1340                     actual = fh.read()
1341                 self.assertEqual(actual, expected)
1342             # verify cache with --target-version is separate
1343             pyi_cache = black.read_cache(py36_mode)
1344             normal_cache = black.read_cache(reg_mode)
1345             for path in paths:
1346                 self.assertIn(str(path), pyi_cache)
1347                 self.assertNotIn(str(path), normal_cache)
1348
1349     def test_pipe_force_py36(self) -> None:
1350         source, expected = read_data("force_py36")
1351         result = CliRunner().invoke(
1352             black.main,
1353             ["-", "-q", "--target-version=py36"],
1354             input=BytesIO(source.encode("utf8")),
1355         )
1356         self.assertEqual(result.exit_code, 0)
1357         actual = result.output
1358         self.assertFormatEqual(actual, expected)
1359
1360     def test_include_exclude(self) -> None:
1361         path = THIS_DIR / "data" / "include_exclude_tests"
1362         include = re.compile(r"\.pyi?$")
1363         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1364         report = black.Report()
1365         gitignore = PathSpec.from_lines("gitwildmatch", [])
1366         sources: List[Path] = []
1367         expected = [
1368             Path(path / "b/dont_exclude/a.py"),
1369             Path(path / "b/dont_exclude/a.pyi"),
1370         ]
1371         this_abs = THIS_DIR.resolve()
1372         sources.extend(
1373             black.gen_python_files(
1374                 path.iterdir(),
1375                 this_abs,
1376                 include,
1377                 exclude,
1378                 None,
1379                 None,
1380                 report,
1381                 gitignore,
1382             )
1383         )
1384         self.assertEqual(sorted(expected), sorted(sources))
1385
1386     def test_gitignore_used_as_default(self) -> None:
1387         path = Path(THIS_DIR / "data" / "include_exclude_tests")
1388         include = re.compile(r"\.pyi?$")
1389         extend_exclude = re.compile(r"/exclude/")
1390         src = str(path / "b/")
1391         report = black.Report()
1392         expected: List[Path] = [
1393             path / "b/.definitely_exclude/a.py",
1394             path / "b/.definitely_exclude/a.pyi",
1395         ]
1396         sources = list(
1397             black.get_sources(
1398                 ctx=FakeContext(),
1399                 src=(src,),
1400                 quiet=True,
1401                 verbose=False,
1402                 include=include,
1403                 exclude=None,
1404                 extend_exclude=extend_exclude,
1405                 force_exclude=None,
1406                 report=report,
1407                 stdin_filename=None,
1408             )
1409         )
1410         self.assertEqual(sorted(expected), sorted(sources))
1411
1412     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1413     def test_exclude_for_issue_1572(self) -> None:
1414         # Exclude shouldn't touch files that were explicitly given to Black through the
1415         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
1416         # https://github.com/psf/black/issues/1572
1417         path = THIS_DIR / "data" / "include_exclude_tests"
1418         include = ""
1419         exclude = r"/exclude/|a\.py"
1420         src = str(path / "b/exclude/a.py")
1421         report = black.Report()
1422         expected = [Path(path / "b/exclude/a.py")]
1423         sources = list(
1424             black.get_sources(
1425                 ctx=FakeContext(),
1426                 src=(src,),
1427                 quiet=True,
1428                 verbose=False,
1429                 include=re.compile(include),
1430                 exclude=re.compile(exclude),
1431                 extend_exclude=None,
1432                 force_exclude=None,
1433                 report=report,
1434                 stdin_filename=None,
1435             )
1436         )
1437         self.assertEqual(sorted(expected), sorted(sources))
1438
1439     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1440     def test_get_sources_with_stdin(self) -> None:
1441         include = ""
1442         exclude = r"/exclude/|a\.py"
1443         src = "-"
1444         report = black.Report()
1445         expected = [Path("-")]
1446         sources = list(
1447             black.get_sources(
1448                 ctx=FakeContext(),
1449                 src=(src,),
1450                 quiet=True,
1451                 verbose=False,
1452                 include=re.compile(include),
1453                 exclude=re.compile(exclude),
1454                 extend_exclude=None,
1455                 force_exclude=None,
1456                 report=report,
1457                 stdin_filename=None,
1458             )
1459         )
1460         self.assertEqual(sorted(expected), sorted(sources))
1461
1462     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1463     def test_get_sources_with_stdin_filename(self) -> None:
1464         include = ""
1465         exclude = r"/exclude/|a\.py"
1466         src = "-"
1467         report = black.Report()
1468         stdin_filename = str(THIS_DIR / "data/collections.py")
1469         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1470         sources = list(
1471             black.get_sources(
1472                 ctx=FakeContext(),
1473                 src=(src,),
1474                 quiet=True,
1475                 verbose=False,
1476                 include=re.compile(include),
1477                 exclude=re.compile(exclude),
1478                 extend_exclude=None,
1479                 force_exclude=None,
1480                 report=report,
1481                 stdin_filename=stdin_filename,
1482             )
1483         )
1484         self.assertEqual(sorted(expected), sorted(sources))
1485
1486     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1487     def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
1488         # Exclude shouldn't exclude stdin_filename since it is mimicking the
1489         # file being passed directly. This is the same as
1490         # test_exclude_for_issue_1572
1491         path = THIS_DIR / "data" / "include_exclude_tests"
1492         include = ""
1493         exclude = r"/exclude/|a\.py"
1494         src = "-"
1495         report = black.Report()
1496         stdin_filename = str(path / "b/exclude/a.py")
1497         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1498         sources = list(
1499             black.get_sources(
1500                 ctx=FakeContext(),
1501                 src=(src,),
1502                 quiet=True,
1503                 verbose=False,
1504                 include=re.compile(include),
1505                 exclude=re.compile(exclude),
1506                 extend_exclude=None,
1507                 force_exclude=None,
1508                 report=report,
1509                 stdin_filename=stdin_filename,
1510             )
1511         )
1512         self.assertEqual(sorted(expected), sorted(sources))
1513
1514     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1515     def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
1516         # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
1517         # file being passed directly. This is the same as
1518         # test_exclude_for_issue_1572
1519         path = THIS_DIR / "data" / "include_exclude_tests"
1520         include = ""
1521         extend_exclude = r"/exclude/|a\.py"
1522         src = "-"
1523         report = black.Report()
1524         stdin_filename = str(path / "b/exclude/a.py")
1525         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1526         sources = list(
1527             black.get_sources(
1528                 ctx=FakeContext(),
1529                 src=(src,),
1530                 quiet=True,
1531                 verbose=False,
1532                 include=re.compile(include),
1533                 exclude=re.compile(""),
1534                 extend_exclude=re.compile(extend_exclude),
1535                 force_exclude=None,
1536                 report=report,
1537                 stdin_filename=stdin_filename,
1538             )
1539         )
1540         self.assertEqual(sorted(expected), sorted(sources))
1541
1542     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1543     def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
1544         # Force exclude should exclude the file when passing it through
1545         # stdin_filename
1546         path = THIS_DIR / "data" / "include_exclude_tests"
1547         include = ""
1548         force_exclude = r"/exclude/|a\.py"
1549         src = "-"
1550         report = black.Report()
1551         stdin_filename = str(path / "b/exclude/a.py")
1552         sources = list(
1553             black.get_sources(
1554                 ctx=FakeContext(),
1555                 src=(src,),
1556                 quiet=True,
1557                 verbose=False,
1558                 include=re.compile(include),
1559                 exclude=re.compile(""),
1560                 extend_exclude=None,
1561                 force_exclude=re.compile(force_exclude),
1562                 report=report,
1563                 stdin_filename=stdin_filename,
1564             )
1565         )
1566         self.assertEqual([], sorted(sources))
1567
1568     def test_reformat_one_with_stdin(self) -> None:
1569         with patch(
1570             "black.format_stdin_to_stdout",
1571             return_value=lambda *args, **kwargs: black.Changed.YES,
1572         ) as fsts:
1573             report = MagicMock()
1574             path = Path("-")
1575             black.reformat_one(
1576                 path,
1577                 fast=True,
1578                 write_back=black.WriteBack.YES,
1579                 mode=DEFAULT_MODE,
1580                 report=report,
1581             )
1582             fsts.assert_called_once()
1583             report.done.assert_called_with(path, black.Changed.YES)
1584
1585     def test_reformat_one_with_stdin_filename(self) -> None:
1586         with patch(
1587             "black.format_stdin_to_stdout",
1588             return_value=lambda *args, **kwargs: black.Changed.YES,
1589         ) as fsts:
1590             report = MagicMock()
1591             p = "foo.py"
1592             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1593             expected = Path(p)
1594             black.reformat_one(
1595                 path,
1596                 fast=True,
1597                 write_back=black.WriteBack.YES,
1598                 mode=DEFAULT_MODE,
1599                 report=report,
1600             )
1601             fsts.assert_called_once_with(
1602                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1603             )
1604             # __BLACK_STDIN_FILENAME__ should have been stripped
1605             report.done.assert_called_with(expected, black.Changed.YES)
1606
1607     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1608         with patch(
1609             "black.format_stdin_to_stdout",
1610             return_value=lambda *args, **kwargs: black.Changed.YES,
1611         ) as fsts:
1612             report = MagicMock()
1613             p = "foo.pyi"
1614             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1615             expected = Path(p)
1616             black.reformat_one(
1617                 path,
1618                 fast=True,
1619                 write_back=black.WriteBack.YES,
1620                 mode=DEFAULT_MODE,
1621                 report=report,
1622             )
1623             fsts.assert_called_once_with(
1624                 fast=True,
1625                 write_back=black.WriteBack.YES,
1626                 mode=replace(DEFAULT_MODE, is_pyi=True),
1627             )
1628             # __BLACK_STDIN_FILENAME__ should have been stripped
1629             report.done.assert_called_with(expected, black.Changed.YES)
1630
1631     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1632         with patch(
1633             "black.format_stdin_to_stdout",
1634             return_value=lambda *args, **kwargs: black.Changed.YES,
1635         ) as fsts:
1636             report = MagicMock()
1637             # Even with an existing file, since we are forcing stdin, black
1638             # should output to stdout and not modify the file inplace
1639             p = Path(str(THIS_DIR / "data/collections.py"))
1640             # Make sure is_file actually returns True
1641             self.assertTrue(p.is_file())
1642             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1643             expected = Path(p)
1644             black.reformat_one(
1645                 path,
1646                 fast=True,
1647                 write_back=black.WriteBack.YES,
1648                 mode=DEFAULT_MODE,
1649                 report=report,
1650             )
1651             fsts.assert_called_once()
1652             # __BLACK_STDIN_FILENAME__ should have been stripped
1653             report.done.assert_called_with(expected, black.Changed.YES)
1654
1655     def test_reformat_one_with_stdin_empty(self) -> None:
1656         output = io.StringIO()
1657         with patch("io.TextIOWrapper", lambda *args, **kwargs: output):
1658             try:
1659                 black.format_stdin_to_stdout(
1660                     fast=True,
1661                     content="",
1662                     write_back=black.WriteBack.YES,
1663                     mode=DEFAULT_MODE,
1664                 )
1665             except io.UnsupportedOperation:
1666                 pass  # StringIO does not support detach
1667             assert output.getvalue() == ""
1668
1669     def test_gitignore_exclude(self) -> None:
1670         path = THIS_DIR / "data" / "include_exclude_tests"
1671         include = re.compile(r"\.pyi?$")
1672         exclude = re.compile(r"")
1673         report = black.Report()
1674         gitignore = PathSpec.from_lines(
1675             "gitwildmatch", ["exclude/", ".definitely_exclude"]
1676         )
1677         sources: List[Path] = []
1678         expected = [
1679             Path(path / "b/dont_exclude/a.py"),
1680             Path(path / "b/dont_exclude/a.pyi"),
1681         ]
1682         this_abs = THIS_DIR.resolve()
1683         sources.extend(
1684             black.gen_python_files(
1685                 path.iterdir(),
1686                 this_abs,
1687                 include,
1688                 exclude,
1689                 None,
1690                 None,
1691                 report,
1692                 gitignore,
1693             )
1694         )
1695         self.assertEqual(sorted(expected), sorted(sources))
1696
1697     def test_nested_gitignore(self) -> None:
1698         path = Path(THIS_DIR / "data" / "nested_gitignore_tests")
1699         include = re.compile(r"\.pyi?$")
1700         exclude = re.compile(r"")
1701         root_gitignore = black.files.get_gitignore(path)
1702         report = black.Report()
1703         expected: List[Path] = [
1704             Path(path / "x.py"),
1705             Path(path / "root/b.py"),
1706             Path(path / "root/c.py"),
1707             Path(path / "root/child/c.py"),
1708         ]
1709         this_abs = THIS_DIR.resolve()
1710         sources = list(
1711             black.gen_python_files(
1712                 path.iterdir(),
1713                 this_abs,
1714                 include,
1715                 exclude,
1716                 None,
1717                 None,
1718                 report,
1719                 root_gitignore,
1720             )
1721         )
1722         self.assertEqual(sorted(expected), sorted(sources))
1723
1724     def test_empty_include(self) -> None:
1725         path = THIS_DIR / "data" / "include_exclude_tests"
1726         report = black.Report()
1727         gitignore = PathSpec.from_lines("gitwildmatch", [])
1728         empty = re.compile(r"")
1729         sources: List[Path] = []
1730         expected = [
1731             Path(path / "b/exclude/a.pie"),
1732             Path(path / "b/exclude/a.py"),
1733             Path(path / "b/exclude/a.pyi"),
1734             Path(path / "b/dont_exclude/a.pie"),
1735             Path(path / "b/dont_exclude/a.py"),
1736             Path(path / "b/dont_exclude/a.pyi"),
1737             Path(path / "b/.definitely_exclude/a.pie"),
1738             Path(path / "b/.definitely_exclude/a.py"),
1739             Path(path / "b/.definitely_exclude/a.pyi"),
1740             Path(path / ".gitignore"),
1741             Path(path / "pyproject.toml"),
1742         ]
1743         this_abs = THIS_DIR.resolve()
1744         sources.extend(
1745             black.gen_python_files(
1746                 path.iterdir(),
1747                 this_abs,
1748                 empty,
1749                 re.compile(black.DEFAULT_EXCLUDES),
1750                 None,
1751                 None,
1752                 report,
1753                 gitignore,
1754             )
1755         )
1756         self.assertEqual(sorted(expected), sorted(sources))
1757
1758     def test_extend_exclude(self) -> None:
1759         path = THIS_DIR / "data" / "include_exclude_tests"
1760         report = black.Report()
1761         gitignore = PathSpec.from_lines("gitwildmatch", [])
1762         sources: List[Path] = []
1763         expected = [
1764             Path(path / "b/exclude/a.py"),
1765             Path(path / "b/dont_exclude/a.py"),
1766         ]
1767         this_abs = THIS_DIR.resolve()
1768         sources.extend(
1769             black.gen_python_files(
1770                 path.iterdir(),
1771                 this_abs,
1772                 re.compile(black.DEFAULT_INCLUDES),
1773                 re.compile(r"\.pyi$"),
1774                 re.compile(r"\.definitely_exclude"),
1775                 None,
1776                 report,
1777                 gitignore,
1778             )
1779         )
1780         self.assertEqual(sorted(expected), sorted(sources))
1781
1782     def test_invalid_cli_regex(self) -> None:
1783         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1784             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1785
1786     def test_required_version_matches_version(self) -> None:
1787         self.invokeBlack(
1788             ["--required-version", black.__version__], exit_code=0, ignore_config=True
1789         )
1790
1791     def test_required_version_does_not_match_version(self) -> None:
1792         self.invokeBlack(
1793             ["--required-version", "20.99b"], exit_code=1, ignore_config=True
1794         )
1795
1796     def test_preserves_line_endings(self) -> None:
1797         with TemporaryDirectory() as workspace:
1798             test_file = Path(workspace) / "test.py"
1799             for nl in ["\n", "\r\n"]:
1800                 contents = nl.join(["def f(  ):", "    pass"])
1801                 test_file.write_bytes(contents.encode())
1802                 ff(test_file, write_back=black.WriteBack.YES)
1803                 updated_contents: bytes = test_file.read_bytes()
1804                 self.assertIn(nl.encode(), updated_contents)
1805                 if nl == "\n":
1806                     self.assertNotIn(b"\r\n", updated_contents)
1807
1808     def test_preserves_line_endings_via_stdin(self) -> None:
1809         for nl in ["\n", "\r\n"]:
1810             contents = nl.join(["def f(  ):", "    pass"])
1811             runner = BlackRunner()
1812             result = runner.invoke(
1813                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1814             )
1815             self.assertEqual(result.exit_code, 0)
1816             output = result.stdout_bytes
1817             self.assertIn(nl.encode("utf8"), output)
1818             if nl == "\n":
1819                 self.assertNotIn(b"\r\n", output)
1820
1821     def test_assert_equivalent_different_asts(self) -> None:
1822         with self.assertRaises(AssertionError):
1823             black.assert_equivalent("{}", "None")
1824
1825     def test_symlink_out_of_root_directory(self) -> None:
1826         path = MagicMock()
1827         root = THIS_DIR.resolve()
1828         child = MagicMock()
1829         include = re.compile(black.DEFAULT_INCLUDES)
1830         exclude = re.compile(black.DEFAULT_EXCLUDES)
1831         report = black.Report()
1832         gitignore = PathSpec.from_lines("gitwildmatch", [])
1833         # `child` should behave like a symlink which resolved path is clearly
1834         # outside of the `root` directory.
1835         path.iterdir.return_value = [child]
1836         child.resolve.return_value = Path("/a/b/c")
1837         child.as_posix.return_value = "/a/b/c"
1838         child.is_symlink.return_value = True
1839         try:
1840             list(
1841                 black.gen_python_files(
1842                     path.iterdir(),
1843                     root,
1844                     include,
1845                     exclude,
1846                     None,
1847                     None,
1848                     report,
1849                     gitignore,
1850                 )
1851             )
1852         except ValueError as ve:
1853             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1854         path.iterdir.assert_called_once()
1855         child.resolve.assert_called_once()
1856         child.is_symlink.assert_called_once()
1857         # `child` should behave like a strange file which resolved path is clearly
1858         # outside of the `root` directory.
1859         child.is_symlink.return_value = False
1860         with self.assertRaises(ValueError):
1861             list(
1862                 black.gen_python_files(
1863                     path.iterdir(),
1864                     root,
1865                     include,
1866                     exclude,
1867                     None,
1868                     None,
1869                     report,
1870                     gitignore,
1871                 )
1872             )
1873         path.iterdir.assert_called()
1874         self.assertEqual(path.iterdir.call_count, 2)
1875         child.resolve.assert_called()
1876         self.assertEqual(child.resolve.call_count, 2)
1877         child.is_symlink.assert_called()
1878         self.assertEqual(child.is_symlink.call_count, 2)
1879
1880     def test_shhh_click(self) -> None:
1881         try:
1882             from click import _unicodefun
1883         except ModuleNotFoundError:
1884             self.skipTest("Incompatible Click version")
1885         if not hasattr(_unicodefun, "_verify_python3_env"):
1886             self.skipTest("Incompatible Click version")
1887         # First, let's see if Click is crashing with a preferred ASCII charset.
1888         with patch("locale.getpreferredencoding") as gpe:
1889             gpe.return_value = "ASCII"
1890             with self.assertRaises(RuntimeError):
1891                 _unicodefun._verify_python3_env()  # type: ignore
1892         # Now, let's silence Click...
1893         black.patch_click()
1894         # ...and confirm it's silent.
1895         with patch("locale.getpreferredencoding") as gpe:
1896             gpe.return_value = "ASCII"
1897             try:
1898                 _unicodefun._verify_python3_env()  # type: ignore
1899             except RuntimeError as re:
1900                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1901
1902     def test_root_logger_not_used_directly(self) -> None:
1903         def fail(*args: Any, **kwargs: Any) -> None:
1904             self.fail("Record created with root logger")
1905
1906         with patch.multiple(
1907             logging.root,
1908             debug=fail,
1909             info=fail,
1910             warning=fail,
1911             error=fail,
1912             critical=fail,
1913             log=fail,
1914         ):
1915             ff(THIS_DIR / "util.py")
1916
1917     def test_invalid_config_return_code(self) -> None:
1918         tmp_file = Path(black.dump_to_file())
1919         try:
1920             tmp_config = Path(black.dump_to_file())
1921             tmp_config.unlink()
1922             args = ["--config", str(tmp_config), str(tmp_file)]
1923             self.invokeBlack(args, exit_code=2, ignore_config=False)
1924         finally:
1925             tmp_file.unlink()
1926
1927     def test_parse_pyproject_toml(self) -> None:
1928         test_toml_file = THIS_DIR / "test.toml"
1929         config = black.parse_pyproject_toml(str(test_toml_file))
1930         self.assertEqual(config["verbose"], 1)
1931         self.assertEqual(config["check"], "no")
1932         self.assertEqual(config["diff"], "y")
1933         self.assertEqual(config["color"], True)
1934         self.assertEqual(config["line_length"], 79)
1935         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1936         self.assertEqual(config["exclude"], r"\.pyi?$")
1937         self.assertEqual(config["include"], r"\.py?$")
1938
1939     def test_read_pyproject_toml(self) -> None:
1940         test_toml_file = THIS_DIR / "test.toml"
1941         fake_ctx = FakeContext()
1942         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1943         config = fake_ctx.default_map
1944         self.assertEqual(config["verbose"], "1")
1945         self.assertEqual(config["check"], "no")
1946         self.assertEqual(config["diff"], "y")
1947         self.assertEqual(config["color"], "True")
1948         self.assertEqual(config["line_length"], "79")
1949         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1950         self.assertEqual(config["exclude"], r"\.pyi?$")
1951         self.assertEqual(config["include"], r"\.py?$")
1952
1953     def test_find_project_root(self) -> None:
1954         with TemporaryDirectory() as workspace:
1955             root = Path(workspace)
1956             test_dir = root / "test"
1957             test_dir.mkdir()
1958
1959             src_dir = root / "src"
1960             src_dir.mkdir()
1961
1962             root_pyproject = root / "pyproject.toml"
1963             root_pyproject.touch()
1964             src_pyproject = src_dir / "pyproject.toml"
1965             src_pyproject.touch()
1966             src_python = src_dir / "foo.py"
1967             src_python.touch()
1968
1969             self.assertEqual(
1970                 black.find_project_root((src_dir, test_dir)), root.resolve()
1971             )
1972             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
1973             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
1974
1975     @patch(
1976         "black.files.find_user_pyproject_toml",
1977         black.files.find_user_pyproject_toml.__wrapped__,
1978     )
1979     def test_find_user_pyproject_toml_linux(self) -> None:
1980         if system() == "Windows":
1981             return
1982
1983         # Test if XDG_CONFIG_HOME is checked
1984         with TemporaryDirectory() as workspace:
1985             tmp_user_config = Path(workspace) / "black"
1986             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1987                 self.assertEqual(
1988                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
1989                 )
1990
1991         # Test fallback for XDG_CONFIG_HOME
1992         with patch.dict("os.environ"):
1993             os.environ.pop("XDG_CONFIG_HOME", None)
1994             fallback_user_config = Path("~/.config").expanduser() / "black"
1995             self.assertEqual(
1996                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
1997             )
1998
1999     def test_find_user_pyproject_toml_windows(self) -> None:
2000         if system() != "Windows":
2001             return
2002
2003         user_config_path = Path.home() / ".black"
2004         self.assertEqual(
2005             black.files.find_user_pyproject_toml(), user_config_path.resolve()
2006         )
2007
2008     def test_bpo_33660_workaround(self) -> None:
2009         if system() == "Windows":
2010             return
2011
2012         # https://bugs.python.org/issue33660
2013         root = Path("/")
2014         with change_directory(root):
2015             path = Path("workspace") / "project"
2016             report = black.Report(verbose=True)
2017             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
2018             self.assertEqual(normalized_path, "workspace/project")
2019
2020     def test_newline_comment_interaction(self) -> None:
2021         source = "class A:\\\r\n# type: ignore\n pass\n"
2022         output = black.format_str(source, mode=DEFAULT_MODE)
2023         black.assert_stable(source, output, mode=DEFAULT_MODE)
2024
2025     def test_bpo_2142_workaround(self) -> None:
2026
2027         # https://bugs.python.org/issue2142
2028
2029         source, _ = read_data("missing_final_newline.py")
2030         # read_data adds a trailing newline
2031         source = source.rstrip()
2032         expected, _ = read_data("missing_final_newline.diff")
2033         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
2034         diff_header = re.compile(
2035             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
2036             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
2037         )
2038         try:
2039             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
2040             self.assertEqual(result.exit_code, 0)
2041         finally:
2042             os.unlink(tmp_file)
2043         actual = result.output
2044         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
2045         self.assertEqual(actual, expected)
2046
2047     @pytest.mark.python2
2048     def test_docstring_reformat_for_py27(self) -> None:
2049         """
2050         Check that stripping trailing whitespace from Python 2 docstrings
2051         doesn't trigger a "not equivalent to source" error
2052         """
2053         source = (
2054             b'def foo():\r\n    """Testing\r\n    Testing """\r\n    print "Foo"\r\n'
2055         )
2056         expected = 'def foo():\n    """Testing\n    Testing"""\n    print "Foo"\n'
2057
2058         result = CliRunner().invoke(
2059             black.main,
2060             ["-", "-q", "--target-version=py27"],
2061             input=BytesIO(source),
2062         )
2063
2064         self.assertEqual(result.exit_code, 0)
2065         actual = result.output
2066         self.assertFormatEqual(actual, expected)
2067
2068     @staticmethod
2069     def compare_results(
2070         result: click.testing.Result, expected_value: str, expected_exit_code: int
2071     ) -> None:
2072         """Helper method to test the value and exit code of a click Result."""
2073         assert (
2074             result.output == expected_value
2075         ), "The output did not match the expected value."
2076         assert result.exit_code == expected_exit_code, "The exit code is incorrect."
2077
2078     def test_code_option(self) -> None:
2079         """Test the code option with no changes."""
2080         code = 'print("Hello world")\n'
2081         args = ["--code", code]
2082         result = CliRunner().invoke(black.main, args)
2083
2084         self.compare_results(result, code, 0)
2085
2086     def test_code_option_changed(self) -> None:
2087         """Test the code option when changes are required."""
2088         code = "print('hello world')"
2089         formatted = black.format_str(code, mode=DEFAULT_MODE)
2090
2091         args = ["--code", code]
2092         result = CliRunner().invoke(black.main, args)
2093
2094         self.compare_results(result, formatted, 0)
2095
2096     def test_code_option_check(self) -> None:
2097         """Test the code option when check is passed."""
2098         args = ["--check", "--code", 'print("Hello world")\n']
2099         result = CliRunner().invoke(black.main, args)
2100         self.compare_results(result, "", 0)
2101
2102     def test_code_option_check_changed(self) -> None:
2103         """Test the code option when changes are required, and check is passed."""
2104         args = ["--check", "--code", "print('hello world')"]
2105         result = CliRunner().invoke(black.main, args)
2106         self.compare_results(result, "", 1)
2107
2108     def test_code_option_diff(self) -> None:
2109         """Test the code option when diff is passed."""
2110         code = "print('hello world')"
2111         formatted = black.format_str(code, mode=DEFAULT_MODE)
2112         result_diff = diff(code, formatted, "STDIN", "STDOUT")
2113
2114         args = ["--diff", "--code", code]
2115         result = CliRunner().invoke(black.main, args)
2116
2117         # Remove time from diff
2118         output = DIFF_TIME.sub("", result.output)
2119
2120         assert output == result_diff, "The output did not match the expected value."
2121         assert result.exit_code == 0, "The exit code is incorrect."
2122
2123     def test_code_option_color_diff(self) -> None:
2124         """Test the code option when color and diff are passed."""
2125         code = "print('hello world')"
2126         formatted = black.format_str(code, mode=DEFAULT_MODE)
2127
2128         result_diff = diff(code, formatted, "STDIN", "STDOUT")
2129         result_diff = color_diff(result_diff)
2130
2131         args = ["--diff", "--color", "--code", code]
2132         result = CliRunner().invoke(black.main, args)
2133
2134         # Remove time from diff
2135         output = DIFF_TIME.sub("", result.output)
2136
2137         assert output == result_diff, "The output did not match the expected value."
2138         assert result.exit_code == 0, "The exit code is incorrect."
2139
2140     def test_code_option_safe(self) -> None:
2141         """Test that the code option throws an error when the sanity checks fail."""
2142         # Patch black.assert_equivalent to ensure the sanity checks fail
2143         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
2144             code = 'print("Hello world")'
2145             error_msg = f"{code}\nerror: cannot format <string>: \n"
2146
2147             args = ["--safe", "--code", code]
2148             result = CliRunner().invoke(black.main, args)
2149
2150             self.compare_results(result, error_msg, 123)
2151
2152     def test_code_option_fast(self) -> None:
2153         """Test that the code option ignores errors when the sanity checks fail."""
2154         # Patch black.assert_equivalent to ensure the sanity checks fail
2155         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
2156             code = 'print("Hello world")'
2157             formatted = black.format_str(code, mode=DEFAULT_MODE)
2158
2159             args = ["--fast", "--code", code]
2160             result = CliRunner().invoke(black.main, args)
2161
2162             self.compare_results(result, formatted, 0)
2163
2164     def test_code_option_config(self) -> None:
2165         """
2166         Test that the code option finds the pyproject.toml in the current directory.
2167         """
2168         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
2169             args = ["--code", "print"]
2170             CliRunner().invoke(black.main, args)
2171
2172             pyproject_path = Path(Path().cwd(), "pyproject.toml").resolve()
2173             assert (
2174                 len(parse.mock_calls) >= 1
2175             ), "Expected config parse to be called with the current directory."
2176
2177             _, call_args, _ = parse.mock_calls[0]
2178             assert (
2179                 call_args[0].lower() == str(pyproject_path).lower()
2180             ), "Incorrect config loaded."
2181
2182     def test_code_option_parent_config(self) -> None:
2183         """
2184         Test that the code option finds the pyproject.toml in the parent directory.
2185         """
2186         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
2187             with change_directory(Path("tests")):
2188                 args = ["--code", "print"]
2189                 CliRunner().invoke(black.main, args)
2190
2191                 pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
2192                 assert (
2193                     len(parse.mock_calls) >= 1
2194                 ), "Expected config parse to be called with the current directory."
2195
2196                 _, call_args, _ = parse.mock_calls[0]
2197                 assert (
2198                     call_args[0].lower() == str(pyproject_path).lower()
2199                 ), "Incorrect config loaded."
2200
2201
# Keep the text of black's main module (black.__file__) in memory, one entry
# per line with line endings preserved, for tracefunc's line lookups.
with open(black.__file__, "r", encoding="utf-8") as _black_init:
    black_source_lines = list(_black_init)
2204
2205
def tracefunc(frame: types.FrameType, event: str, arg: Any) -> Callable:
    """Show function calls from `black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging. Every
    "call" event whose code lives in `black/__init__.py` is printed as
    `<lineno>:<source line>`, indented two spaces per stack level above a
    fixed baseline depth.

    Args:
        frame: the frame supplied by the tracing machinery.
        event: the trace event name; only "call" events are reported.
        arg: unused; part of the `sys.settrace()` callback signature.

    Returns:
        `tracefunc` itself, so it stays installed as the trace function.
    """
    if event != "call":
        return tracefunc

    filename = frame.f_code.co_filename
    # Filter before doing any work. Line numbers from other files must not be
    # used to index black_source_lines (they can run past its end and raise
    # IndexError), and walking the stack is costly. Separators are normalized
    # so the check also matches Windows paths.
    if "black/__init__.py" not in filename.replace("\\", "/"):
        return tracefunc

    # 19 is a fixed baseline frame count subtracted from the depth; presumably
    # tuned to the test runner's own frames. Adjust if output is misaligned.
    # context=0 avoids reading source context lines for every stack frame.
    depth = (len(inspect.stack(0)) - 19) * 2
    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1
    funcname = black_source_lines[func_sig_lineno].strip()
    # The reported line may be a decorator; advance to the first non-"@" line.
    while funcname.startswith("@"):
        func_sig_lineno += 1
        funcname = black_source_lines[func_sig_lineno].strip()
    print(f"{' ' * depth}{lineno}:{funcname}")
    return tracefunc
2226
2227
if __name__ == "__main__":
    # unittest.main is an alias of unittest.TestProgram: load the tests from
    # the module named "test_black" and run them.
    unittest.TestProgram(module="test_black")