git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you would read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Use platformdirs over appdirs (#2375)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import multiprocessing
3 import asyncio
4 import logging
5 from concurrent.futures import ThreadPoolExecutor
6 from contextlib import contextmanager
7 from dataclasses import replace
8 import inspect
9 import io
10 from io import BytesIO
11 import os
12 from pathlib import Path
13 from platform import system
14 import regex as re
15 import sys
16 from tempfile import TemporaryDirectory
17 import types
18 from typing import (
19     Any,
20     Callable,
21     Dict,
22     List,
23     Iterator,
24     TypeVar,
25 )
26 import pytest
27 import unittest
28 from unittest.mock import patch, MagicMock
29
30 import click
31 from click import unstyle
32 from click.testing import CliRunner
33
34 import black
35 from black import Feature, TargetVersion
36 from black.cache import get_cache_file
37 from black.debug import DebugVisitor
38 from black.output import diff, color_diff
39 from black.report import Report
40 import black.files
41
42 from pathspec import PathSpec
43
44 # Import other test classes
45 from tests.util import (
46     THIS_DIR,
47     read_data,
48     DETERMINISTIC_HEADER,
49     BlackBaseTestCase,
50     DEFAULT_MODE,
51     fs,
52     ff,
53     dump_to_stderr,
54 )
55
56
THIS_FILE = Path(__file__)
# Target versions used by the tests that exercise 3.6+ syntax.
PY36_VERSIONS = {
    TargetVersion.PY36,
    TargetVersion.PY37,
    TargetVersion.PY38,
    TargetVersion.PY39,
}
# Sorted so the generated CLI arguments come out in a stable order; set iteration
# order is not guaranteed to be the same from one run to the next.
PY36_ARGS = [
    f"--target-version={version.name.lower()}"
    for version in sorted(PY36_VERSIONS, key=lambda v: v.name)
]
T = TypeVar("T")
R = TypeVar("R")

# Match the time output in a diff, but nothing else. The hyphen is escaped so the
# character class is unambiguous (stdlib ``re`` rejects ``\d-:`` as a bad range).
DIFF_TIME = re.compile(r"\t[\d\-:+\. ]+")
70
71
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point ``black.cache.CACHE_DIR`` at a temporary directory for the block.

    When ``exists`` is False, the yielded path is a ``new`` subdirectory of the
    temporary directory that has not been created.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
80
81
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a fresh asyncio event loop as the current loop for the block.

    On exit the loop is uninstalled and then closed, the same order
    ``asyncio.run`` uses. Later ``asyncio.get_event_loop()`` calls therefore
    never return this closed loop.
    """
    policy = asyncio.get_event_loop_policy()
    loop = policy.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        yield

    finally:
        # Uninstall first so the closed loop is never left as the current loop.
        asyncio.set_event_loop(None)
        loop.close()
92
93
class FakeContext(click.Context):
    """Stand-in click Context for calling helpers that expect one.

    ``click.Context.__init__`` is intentionally not called; the only state set
    up is an empty ``default_map``.
    """

    def __init__(self) -> None:
        defaults: Dict[str, Any] = dict()
        self.default_map = defaults
99
100
class FakeParameter(click.Parameter):
    """Stand-in click Parameter for calling helpers that expect one."""

    def __init__(self) -> None:
        """Deliberately skip ``click.Parameter.__init__`` and set no attributes."""
106
107
class BlackRunner(CliRunner):
    """CliRunner that captures Black's STDOUT and STDERR separately."""

    def __init__(self) -> None:
        # mix_stderr=False keeps stderr out of the captured stdout.
        CliRunner.__init__(self, mix_stderr=False)
113
114
115 class BlackTestCase(BlackBaseTestCase):
116     def invokeBlack(
117         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
118     ) -> None:
119         runner = BlackRunner()
120         if ignore_config:
121             args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
122         result = runner.invoke(black.main, args)
123         assert result.stdout_bytes is not None
124         assert result.stderr_bytes is not None
125         self.assertEqual(
126             result.exit_code,
127             exit_code,
128             msg=(
129                 f"Failed with args: {args}\n"
130                 f"stdout: {result.stdout_bytes.decode()!r}\n"
131                 f"stderr: {result.stderr_bytes.decode()!r}\n"
132                 f"exception: {result.exception}"
133             ),
134         )
135
136     @patch("black.dump_to_file", dump_to_stderr)
137     def test_empty(self) -> None:
138         source = expected = ""
139         actual = fs(source)
140         self.assertFormatEqual(expected, actual)
141         black.assert_equivalent(source, actual)
142         black.assert_stable(source, actual, DEFAULT_MODE)
143
144     def test_empty_ff(self) -> None:
145         expected = ""
146         tmp_file = Path(black.dump_to_file())
147         try:
148             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
149             with open(tmp_file, encoding="utf8") as f:
150                 actual = f.read()
151         finally:
152             os.unlink(tmp_file)
153         self.assertFormatEqual(expected, actual)
154
155     def test_piping(self) -> None:
156         source, expected = read_data("src/black/__init__", data=False)
157         result = BlackRunner().invoke(
158             black.main,
159             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
160             input=BytesIO(source.encode("utf8")),
161         )
162         self.assertEqual(result.exit_code, 0)
163         self.assertFormatEqual(expected, result.output)
164         if source != result.output:
165             black.assert_equivalent(source, result.output)
166             black.assert_stable(source, result.output, DEFAULT_MODE)
167
168     def test_piping_diff(self) -> None:
169         diff_header = re.compile(
170             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
171             r"\+\d\d\d\d"
172         )
173         source, _ = read_data("expression.py")
174         expected, _ = read_data("expression.diff")
175         config = THIS_DIR / "data" / "empty_pyproject.toml"
176         args = [
177             "-",
178             "--fast",
179             f"--line-length={black.DEFAULT_LINE_LENGTH}",
180             "--diff",
181             f"--config={config}",
182         ]
183         result = BlackRunner().invoke(
184             black.main, args, input=BytesIO(source.encode("utf8"))
185         )
186         self.assertEqual(result.exit_code, 0)
187         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
188         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
189         self.assertEqual(expected, actual)
190
191     def test_piping_diff_with_color(self) -> None:
192         source, _ = read_data("expression.py")
193         config = THIS_DIR / "data" / "empty_pyproject.toml"
194         args = [
195             "-",
196             "--fast",
197             f"--line-length={black.DEFAULT_LINE_LENGTH}",
198             "--diff",
199             "--color",
200             f"--config={config}",
201         ]
202         result = BlackRunner().invoke(
203             black.main, args, input=BytesIO(source.encode("utf8"))
204         )
205         actual = result.output
206         # Again, the contents are checked in a different test, so only look for colors.
207         self.assertIn("\033[1;37m", actual)
208         self.assertIn("\033[36m", actual)
209         self.assertIn("\033[32m", actual)
210         self.assertIn("\033[31m", actual)
211         self.assertIn("\033[0m", actual)
212
213     @patch("black.dump_to_file", dump_to_stderr)
214     def _test_wip(self) -> None:
215         source, expected = read_data("wip")
216         sys.settrace(tracefunc)
217         mode = replace(
218             DEFAULT_MODE,
219             experimental_string_processing=False,
220             target_versions={black.TargetVersion.PY38},
221         )
222         actual = fs(source, mode=mode)
223         sys.settrace(None)
224         self.assertFormatEqual(expected, actual)
225         black.assert_equivalent(source, actual)
226         black.assert_stable(source, actual, black.FileMode())
227
228     @unittest.expectedFailure
229     @patch("black.dump_to_file", dump_to_stderr)
230     def test_trailing_comma_optional_parens_stability1(self) -> None:
231         source, _expected = read_data("trailing_comma_optional_parens1")
232         actual = fs(source)
233         black.assert_stable(source, actual, DEFAULT_MODE)
234
235     @unittest.expectedFailure
236     @patch("black.dump_to_file", dump_to_stderr)
237     def test_trailing_comma_optional_parens_stability2(self) -> None:
238         source, _expected = read_data("trailing_comma_optional_parens2")
239         actual = fs(source)
240         black.assert_stable(source, actual, DEFAULT_MODE)
241
242     @unittest.expectedFailure
243     @patch("black.dump_to_file", dump_to_stderr)
244     def test_trailing_comma_optional_parens_stability3(self) -> None:
245         source, _expected = read_data("trailing_comma_optional_parens3")
246         actual = fs(source)
247         black.assert_stable(source, actual, DEFAULT_MODE)
248
249     @patch("black.dump_to_file", dump_to_stderr)
250     def test_trailing_comma_optional_parens_stability1_pass2(self) -> None:
251         source, _expected = read_data("trailing_comma_optional_parens1")
252         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
253         black.assert_stable(source, actual, DEFAULT_MODE)
254
255     @patch("black.dump_to_file", dump_to_stderr)
256     def test_trailing_comma_optional_parens_stability2_pass2(self) -> None:
257         source, _expected = read_data("trailing_comma_optional_parens2")
258         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
259         black.assert_stable(source, actual, DEFAULT_MODE)
260
261     @patch("black.dump_to_file", dump_to_stderr)
262     def test_trailing_comma_optional_parens_stability3_pass2(self) -> None:
263         source, _expected = read_data("trailing_comma_optional_parens3")
264         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
265         black.assert_stable(source, actual, DEFAULT_MODE)
266
267     @patch("black.dump_to_file", dump_to_stderr)
268     def test_pep_572(self) -> None:
269         source, expected = read_data("pep_572")
270         actual = fs(source)
271         self.assertFormatEqual(expected, actual)
272         black.assert_stable(source, actual, DEFAULT_MODE)
273         if sys.version_info >= (3, 8):
274             black.assert_equivalent(source, actual)
275
276     @patch("black.dump_to_file", dump_to_stderr)
277     def test_pep_572_remove_parens(self) -> None:
278         source, expected = read_data("pep_572_remove_parens")
279         actual = fs(source)
280         self.assertFormatEqual(expected, actual)
281         black.assert_stable(source, actual, DEFAULT_MODE)
282         if sys.version_info >= (3, 8):
283             black.assert_equivalent(source, actual)
284
285     @patch("black.dump_to_file", dump_to_stderr)
286     def test_pep_572_do_not_remove_parens(self) -> None:
287         source, expected = read_data("pep_572_do_not_remove_parens")
288         # the AST safety checks will fail, but that's expected, just make sure no
289         # parentheses are touched
290         actual = black.format_str(source, mode=DEFAULT_MODE)
291         self.assertFormatEqual(expected, actual)
292
293     def test_pep_572_version_detection(self) -> None:
294         source, _ = read_data("pep_572")
295         root = black.lib2to3_parse(source)
296         features = black.get_features_used(root)
297         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
298         versions = black.detect_target_versions(root)
299         self.assertIn(black.TargetVersion.PY38, versions)
300
301     def test_expression_ff(self) -> None:
302         source, expected = read_data("expression")
303         tmp_file = Path(black.dump_to_file(source))
304         try:
305             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
306             with open(tmp_file, encoding="utf8") as f:
307                 actual = f.read()
308         finally:
309             os.unlink(tmp_file)
310         self.assertFormatEqual(expected, actual)
311         with patch("black.dump_to_file", dump_to_stderr):
312             black.assert_equivalent(source, actual)
313             black.assert_stable(source, actual, DEFAULT_MODE)
314
315     def test_expression_diff(self) -> None:
316         source, _ = read_data("expression.py")
317         config = THIS_DIR / "data" / "empty_pyproject.toml"
318         expected, _ = read_data("expression.diff")
319         tmp_file = Path(black.dump_to_file(source))
320         diff_header = re.compile(
321             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
322             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
323         )
324         try:
325             result = BlackRunner().invoke(
326                 black.main, ["--diff", str(tmp_file), f"--config={config}"]
327             )
328             self.assertEqual(result.exit_code, 0)
329         finally:
330             os.unlink(tmp_file)
331         actual = result.output
332         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
333         if expected != actual:
334             dump = black.dump_to_file(actual)
335             msg = (
336                 "Expected diff isn't equal to the actual. If you made changes to"
337                 " expression.py and this is an anticipated difference, overwrite"
338                 f" tests/data/expression.diff with {dump}"
339             )
340             self.assertEqual(expected, actual, msg)
341
342     def test_expression_diff_with_color(self) -> None:
343         source, _ = read_data("expression.py")
344         config = THIS_DIR / "data" / "empty_pyproject.toml"
345         expected, _ = read_data("expression.diff")
346         tmp_file = Path(black.dump_to_file(source))
347         try:
348             result = BlackRunner().invoke(
349                 black.main, ["--diff", "--color", str(tmp_file), f"--config={config}"]
350             )
351         finally:
352             os.unlink(tmp_file)
353         actual = result.output
354         # We check the contents of the diff in `test_expression_diff`. All
355         # we need to check here is that color codes exist in the result.
356         self.assertIn("\033[1;37m", actual)
357         self.assertIn("\033[36m", actual)
358         self.assertIn("\033[32m", actual)
359         self.assertIn("\033[31m", actual)
360         self.assertIn("\033[0m", actual)
361
362     @patch("black.dump_to_file", dump_to_stderr)
363     def test_pep_570(self) -> None:
364         source, expected = read_data("pep_570")
365         actual = fs(source)
366         self.assertFormatEqual(expected, actual)
367         black.assert_stable(source, actual, DEFAULT_MODE)
368         if sys.version_info >= (3, 8):
369             black.assert_equivalent(source, actual)
370
371     def test_detect_pos_only_arguments(self) -> None:
372         source, _ = read_data("pep_570")
373         root = black.lib2to3_parse(source)
374         features = black.get_features_used(root)
375         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
376         versions = black.detect_target_versions(root)
377         self.assertIn(black.TargetVersion.PY38, versions)
378
379     @patch("black.dump_to_file", dump_to_stderr)
380     def test_string_quotes(self) -> None:
381         source, expected = read_data("string_quotes")
382         mode = black.Mode(experimental_string_processing=True)
383         actual = fs(source, mode=mode)
384         self.assertFormatEqual(expected, actual)
385         black.assert_equivalent(source, actual)
386         black.assert_stable(source, actual, mode)
387         mode = replace(mode, string_normalization=False)
388         not_normalized = fs(source, mode=mode)
389         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
390         black.assert_equivalent(source, not_normalized)
391         black.assert_stable(source, not_normalized, mode=mode)
392
393     @patch("black.dump_to_file", dump_to_stderr)
394     def test_docstring_no_string_normalization(self) -> None:
395         """Like test_docstring but with string normalization off."""
396         source, expected = read_data("docstring_no_string_normalization")
397         mode = replace(DEFAULT_MODE, string_normalization=False)
398         actual = fs(source, mode=mode)
399         self.assertFormatEqual(expected, actual)
400         black.assert_equivalent(source, actual)
401         black.assert_stable(source, actual, mode)
402
403     def test_long_strings_flag_disabled(self) -> None:
404         """Tests for turning off the string processing logic."""
405         source, expected = read_data("long_strings_flag_disabled")
406         mode = replace(DEFAULT_MODE, experimental_string_processing=False)
407         actual = fs(source, mode=mode)
408         self.assertFormatEqual(expected, actual)
409         black.assert_stable(expected, actual, mode)
410
411     @patch("black.dump_to_file", dump_to_stderr)
412     def test_numeric_literals(self) -> None:
413         source, expected = read_data("numeric_literals")
414         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
415         actual = fs(source, mode=mode)
416         self.assertFormatEqual(expected, actual)
417         black.assert_equivalent(source, actual)
418         black.assert_stable(source, actual, mode)
419
420     @patch("black.dump_to_file", dump_to_stderr)
421     def test_numeric_literals_ignoring_underscores(self) -> None:
422         source, expected = read_data("numeric_literals_skip_underscores")
423         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
424         actual = fs(source, mode=mode)
425         self.assertFormatEqual(expected, actual)
426         black.assert_equivalent(source, actual)
427         black.assert_stable(source, actual, mode)
428
429     def test_skip_magic_trailing_comma(self) -> None:
430         source, _ = read_data("expression.py")
431         expected, _ = read_data("expression_skip_magic_trailing_comma.diff")
432         tmp_file = Path(black.dump_to_file(source))
433         diff_header = re.compile(
434             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
435             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
436         )
437         try:
438             result = BlackRunner().invoke(black.main, ["-C", "--diff", str(tmp_file)])
439             self.assertEqual(result.exit_code, 0)
440         finally:
441             os.unlink(tmp_file)
442         actual = result.output
443         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
444         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
445         if expected != actual:
446             dump = black.dump_to_file(actual)
447             msg = (
448                 "Expected diff isn't equal to the actual. If you made changes to"
449                 " expression.py and this is an anticipated difference, overwrite"
450                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
451             )
452             self.assertEqual(expected, actual, msg)
453
454     @pytest.mark.python2
455     @patch("black.dump_to_file", dump_to_stderr)
456     def test_python2_print_function(self) -> None:
457         source, expected = read_data("python2_print_function")
458         mode = replace(DEFAULT_MODE, target_versions={TargetVersion.PY27})
459         actual = fs(source, mode=mode)
460         self.assertFormatEqual(expected, actual)
461         black.assert_equivalent(source, actual)
462         black.assert_stable(source, actual, mode)
463
464     @patch("black.dump_to_file", dump_to_stderr)
465     def test_stub(self) -> None:
466         mode = replace(DEFAULT_MODE, is_pyi=True)
467         source, expected = read_data("stub.pyi")
468         actual = fs(source, mode=mode)
469         self.assertFormatEqual(expected, actual)
470         black.assert_stable(source, actual, mode)
471
472     @patch("black.dump_to_file", dump_to_stderr)
473     def test_async_as_identifier(self) -> None:
474         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
475         source, expected = read_data("async_as_identifier")
476         actual = fs(source)
477         self.assertFormatEqual(expected, actual)
478         major, minor = sys.version_info[:2]
479         if major < 3 or (major <= 3 and minor < 7):
480             black.assert_equivalent(source, actual)
481         black.assert_stable(source, actual, DEFAULT_MODE)
482         # ensure black can parse this when the target is 3.6
483         self.invokeBlack([str(source_path), "--target-version", "py36"])
484         # but not on 3.7, because async/await is no longer an identifier
485         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
486
487     @patch("black.dump_to_file", dump_to_stderr)
488     def test_python37(self) -> None:
489         source_path = (THIS_DIR / "data" / "python37.py").resolve()
490         source, expected = read_data("python37")
491         actual = fs(source)
492         self.assertFormatEqual(expected, actual)
493         major, minor = sys.version_info[:2]
494         if major > 3 or (major == 3 and minor >= 7):
495             black.assert_equivalent(source, actual)
496         black.assert_stable(source, actual, DEFAULT_MODE)
497         # ensure black can parse this when the target is 3.7
498         self.invokeBlack([str(source_path), "--target-version", "py37"])
499         # but not on 3.6, because we use async as a reserved keyword
500         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
501
502     @patch("black.dump_to_file", dump_to_stderr)
503     def test_python38(self) -> None:
504         source, expected = read_data("python38")
505         actual = fs(source)
506         self.assertFormatEqual(expected, actual)
507         major, minor = sys.version_info[:2]
508         if major > 3 or (major == 3 and minor >= 8):
509             black.assert_equivalent(source, actual)
510         black.assert_stable(source, actual, DEFAULT_MODE)
511
512     @patch("black.dump_to_file", dump_to_stderr)
513     def test_python39(self) -> None:
514         source, expected = read_data("python39")
515         actual = fs(source)
516         self.assertFormatEqual(expected, actual)
517         major, minor = sys.version_info[:2]
518         if major > 3 or (major == 3 and minor >= 9):
519             black.assert_equivalent(source, actual)
520         black.assert_stable(source, actual, DEFAULT_MODE)
521
522     def test_tab_comment_indentation(self) -> None:
523         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
524         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
525         self.assertFormatEqual(contents_spc, fs(contents_spc))
526         self.assertFormatEqual(contents_spc, fs(contents_tab))
527
528         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
529         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
530         self.assertFormatEqual(contents_spc, fs(contents_spc))
531         self.assertFormatEqual(contents_spc, fs(contents_tab))
532
533         # mixed tabs and spaces (valid Python 2 code)
534         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
535         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
536         self.assertFormatEqual(contents_spc, fs(contents_spc))
537         self.assertFormatEqual(contents_spc, fs(contents_tab))
538
539         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
540         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
541         self.assertFormatEqual(contents_spc, fs(contents_spc))
542         self.assertFormatEqual(contents_spc, fs(contents_tab))
543
    def test_report_verbose(self) -> None:
        """Step a verbose Report through a sequence of events, checking each step.

        ``black.output._out`` / ``_err`` are patched to record messages. After
        every event the test checks what was logged, the summary text, and the
        return code.
        """
        report = Report(verbose=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        # Recorders standing in for black.output._out/_err; styling kwargs ignored.
        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            # Unchanged file: one line to out, summary counts it, exit code 0.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # Reformatted file.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # Cached file: logged separately but counted as unchanged.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode the same tallies give return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure goes to err and switches the return code to 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Reporting f3 again as reformatted bumps the reformatted count and
            # leaves the unchanged count as it was.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are logged in verbose mode but not counted in the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check mode and diff mode switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
645
    def test_report_quiet(self) -> None:
        """Step a quiet Report through the same events as the verbose test.

        In quiet mode nothing is written to out and only failures reach err.
        The summary text and return codes must match the verbose case.
        """
        report = Report(quiet=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        # Recorders standing in for black.output._out/_err; styling kwargs ignored.
        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            # Unchanged, reformatted and cached files: nothing logged, only counted.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode the same tallies give return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures are still written to err even in quiet mode; return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are neither logged nor counted.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check mode and diff mode switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
739
    def test_report_normal(self) -> None:
        """Check Report's console output and summary at default verbosity.

        Each step feeds one more result into the report, then checks how many
        messages reached the patched out/err sinks, the last message, the
        summary text and ``return_code``.
        """
        report = black.Report()
        # Collected messages that black.output._out / _err would have printed.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            # An unchanged file is counted but produces no message.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # A reformatted file is announced on the out sink.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file is counted as unchanged and produces no message.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, having reformatted files turns the return code into 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure goes to the err sink and the return code becomes 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Reporting f3 again (now changed) adds a reformat; the earlier
            # CACHED count for f3 is not withdrawn ("2 files left unchanged").
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # An ignored path produces no message here and leaves the summary as is.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check mode and diff mode switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
836
837     def test_lib2to3_parse(self) -> None:
838         with self.assertRaises(black.InvalidInput):
839             black.lib2to3_parse("invalid syntax")
840
841         straddling = "x + y"
842         black.lib2to3_parse(straddling)
843         black.lib2to3_parse(straddling, {TargetVersion.PY27})
844         black.lib2to3_parse(straddling, {TargetVersion.PY36})
845         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
846
847         py2_only = "print x"
848         black.lib2to3_parse(py2_only)
849         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
850         with self.assertRaises(black.InvalidInput):
851             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
852         with self.assertRaises(black.InvalidInput):
853             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
854
855         py3_only = "exec(x, end=y)"
856         black.lib2to3_parse(py3_only)
857         with self.assertRaises(black.InvalidInput):
858             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
859         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
860         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
861
862     def test_get_features_used_decorator(self) -> None:
863         # Test the feature detection of new decorator syntax
864         # since this makes some test cases of test_get_features_used()
865         # fails if it fails, this is tested first so that a useful case
866         # is identified
867         simples, relaxed = read_data("decorators")
868         # skip explanation comments at the top of the file
869         for simple_test in simples.split("##")[1:]:
870             node = black.lib2to3_parse(simple_test)
871             decorator = str(node.children[0].children[0]).strip()
872             self.assertNotIn(
873                 Feature.RELAXED_DECORATORS,
874                 black.get_features_used(node),
875                 msg=(
876                     f"decorator '{decorator}' follows python<=3.8 syntax"
877                     "but is detected as 3.9+"
878                     # f"The full node is\n{node!r}"
879                 ),
880             )
881         # skip the '# output' comment at the top of the output part
882         for relaxed_test in relaxed.split("##")[1:]:
883             node = black.lib2to3_parse(relaxed_test)
884             decorator = str(node.children[0].children[0]).strip()
885             self.assertIn(
886                 Feature.RELAXED_DECORATORS,
887                 black.get_features_used(node),
888                 msg=(
889                     f"decorator '{decorator}' uses python3.9+ syntax"
890                     "but is detected as python<=3.8"
891                     # f"The full node is\n{node!r}"
892                 ),
893             )
894
895     def test_get_features_used(self) -> None:
896         node = black.lib2to3_parse("def f(*, arg): ...\n")
897         self.assertEqual(black.get_features_used(node), set())
898         node = black.lib2to3_parse("def f(*, arg,): ...\n")
899         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
900         node = black.lib2to3_parse("f(*arg,)\n")
901         self.assertEqual(
902             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
903         )
904         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
905         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
906         node = black.lib2to3_parse("123_456\n")
907         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
908         node = black.lib2to3_parse("123456\n")
909         self.assertEqual(black.get_features_used(node), set())
910         source, expected = read_data("function")
911         node = black.lib2to3_parse(source)
912         expected_features = {
913             Feature.TRAILING_COMMA_IN_CALL,
914             Feature.TRAILING_COMMA_IN_DEF,
915             Feature.F_STRINGS,
916         }
917         self.assertEqual(black.get_features_used(node), expected_features)
918         node = black.lib2to3_parse(expected)
919         self.assertEqual(black.get_features_used(node), expected_features)
920         source, expected = read_data("expression")
921         node = black.lib2to3_parse(source)
922         self.assertEqual(black.get_features_used(node), set())
923         node = black.lib2to3_parse(expected)
924         self.assertEqual(black.get_features_used(node), set())
925
926     def test_get_future_imports(self) -> None:
927         node = black.lib2to3_parse("\n")
928         self.assertEqual(set(), black.get_future_imports(node))
929         node = black.lib2to3_parse("from __future__ import black\n")
930         self.assertEqual({"black"}, black.get_future_imports(node))
931         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
932         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
933         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
934         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
935         node = black.lib2to3_parse(
936             "from __future__ import multiple\nfrom __future__ import imports\n"
937         )
938         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
939         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
940         self.assertEqual({"black"}, black.get_future_imports(node))
941         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
942         self.assertEqual({"black"}, black.get_future_imports(node))
943         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
944         self.assertEqual(set(), black.get_future_imports(node))
945         node = black.lib2to3_parse("from some.module import black\n")
946         self.assertEqual(set(), black.get_future_imports(node))
947         node = black.lib2to3_parse(
948             "from __future__ import unicode_literals as _unicode_literals"
949         )
950         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
951         node = black.lib2to3_parse(
952             "from __future__ import unicode_literals as _lol, print"
953         )
954         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
955
956     def test_debug_visitor(self) -> None:
957         source, _ = read_data("debug_visitor.py")
958         expected, _ = read_data("debug_visitor.out")
959         out_lines = []
960         err_lines = []
961
962         def out(msg: str, **kwargs: Any) -> None:
963             out_lines.append(msg)
964
965         def err(msg: str, **kwargs: Any) -> None:
966             err_lines.append(msg)
967
968         with patch("black.debug.out", out):
969             DebugVisitor.show(source)
970         actual = "\n".join(out_lines) + "\n"
971         log_name = ""
972         if expected != actual:
973             log_name = black.dump_to_file(*out_lines)
974         self.assertEqual(
975             expected,
976             actual,
977             f"AST print out is different. Actual version dumped to {log_name}",
978         )
979
980     def test_format_file_contents(self) -> None:
981         empty = ""
982         mode = DEFAULT_MODE
983         with self.assertRaises(black.NothingChanged):
984             black.format_file_contents(empty, mode=mode, fast=False)
985         just_nl = "\n"
986         with self.assertRaises(black.NothingChanged):
987             black.format_file_contents(just_nl, mode=mode, fast=False)
988         same = "j = [1, 2, 3]\n"
989         with self.assertRaises(black.NothingChanged):
990             black.format_file_contents(same, mode=mode, fast=False)
991         different = "j = [1,2,3]"
992         expected = same
993         actual = black.format_file_contents(different, mode=mode, fast=False)
994         self.assertEqual(expected, actual)
995         invalid = "return if you can"
996         with self.assertRaises(black.InvalidInput) as e:
997             black.format_file_contents(invalid, mode=mode, fast=False)
998         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
999
1000     def test_endmarker(self) -> None:
1001         n = black.lib2to3_parse("\n")
1002         self.assertEqual(n.type, black.syms.file_input)
1003         self.assertEqual(len(n.children), 1)
1004         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1005
1006     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1007     def test_assertFormatEqual(self) -> None:
1008         out_lines = []
1009         err_lines = []
1010
1011         def out(msg: str, **kwargs: Any) -> None:
1012             out_lines.append(msg)
1013
1014         def err(msg: str, **kwargs: Any) -> None:
1015             err_lines.append(msg)
1016
1017         with patch("black.output._out", out), patch("black.output._err", err):
1018             with self.assertRaises(AssertionError):
1019                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1020
1021         out_str = "".join(out_lines)
1022         self.assertTrue("Expected tree:" in out_str)
1023         self.assertTrue("Actual tree:" in out_str)
1024         self.assertEqual("".join(err_lines), "")
1025
1026     def test_cache_broken_file(self) -> None:
1027         mode = DEFAULT_MODE
1028         with cache_dir() as workspace:
1029             cache_file = get_cache_file(mode)
1030             with cache_file.open("w") as fobj:
1031                 fobj.write("this is not a pickle")
1032             self.assertEqual(black.read_cache(mode), {})
1033             src = (workspace / "test.py").resolve()
1034             with src.open("w") as fobj:
1035                 fobj.write("print('hello')")
1036             self.invokeBlack([str(src)])
1037             cache = black.read_cache(mode)
1038             self.assertIn(str(src), cache)
1039
1040     def test_cache_single_file_already_cached(self) -> None:
1041         mode = DEFAULT_MODE
1042         with cache_dir() as workspace:
1043             src = (workspace / "test.py").resolve()
1044             with src.open("w") as fobj:
1045                 fobj.write("print('hello')")
1046             black.write_cache({}, [src], mode)
1047             self.invokeBlack([str(src)])
1048             with src.open("r") as fobj:
1049                 self.assertEqual(fobj.read(), "print('hello')")
1050
1051     @event_loop()
1052     def test_cache_multiple_files(self) -> None:
1053         mode = DEFAULT_MODE
1054         with cache_dir() as workspace, patch(
1055             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1056         ):
1057             one = (workspace / "one.py").resolve()
1058             with one.open("w") as fobj:
1059                 fobj.write("print('hello')")
1060             two = (workspace / "two.py").resolve()
1061             with two.open("w") as fobj:
1062                 fobj.write("print('hello')")
1063             black.write_cache({}, [one], mode)
1064             self.invokeBlack([str(workspace)])
1065             with one.open("r") as fobj:
1066                 self.assertEqual(fobj.read(), "print('hello')")
1067             with two.open("r") as fobj:
1068                 self.assertEqual(fobj.read(), 'print("hello")\n')
1069             cache = black.read_cache(mode)
1070             self.assertIn(str(one), cache)
1071             self.assertIn(str(two), cache)
1072
1073     def test_no_cache_when_writeback_diff(self) -> None:
1074         mode = DEFAULT_MODE
1075         with cache_dir() as workspace:
1076             src = (workspace / "test.py").resolve()
1077             with src.open("w") as fobj:
1078                 fobj.write("print('hello')")
1079             with patch("black.read_cache") as read_cache, patch(
1080                 "black.write_cache"
1081             ) as write_cache:
1082                 self.invokeBlack([str(src), "--diff"])
1083                 cache_file = get_cache_file(mode)
1084                 self.assertFalse(cache_file.exists())
1085                 write_cache.assert_not_called()
1086                 read_cache.assert_not_called()
1087
1088     def test_no_cache_when_writeback_color_diff(self) -> None:
1089         mode = DEFAULT_MODE
1090         with cache_dir() as workspace:
1091             src = (workspace / "test.py").resolve()
1092             with src.open("w") as fobj:
1093                 fobj.write("print('hello')")
1094             with patch("black.read_cache") as read_cache, patch(
1095                 "black.write_cache"
1096             ) as write_cache:
1097                 self.invokeBlack([str(src), "--diff", "--color"])
1098                 cache_file = get_cache_file(mode)
1099                 self.assertFalse(cache_file.exists())
1100                 write_cache.assert_not_called()
1101                 read_cache.assert_not_called()
1102
1103     @event_loop()
1104     def test_output_locking_when_writeback_diff(self) -> None:
1105         with cache_dir() as workspace:
1106             for tag in range(0, 4):
1107                 src = (workspace / f"test{tag}.py").resolve()
1108                 with src.open("w") as fobj:
1109                     fobj.write("print('hello')")
1110             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1111                 self.invokeBlack(["--diff", str(workspace)], exit_code=0)
1112                 # this isn't quite doing what we want, but if it _isn't_
1113                 # called then we cannot be using the lock it provides
1114                 mgr.assert_called()
1115
1116     @event_loop()
1117     def test_output_locking_when_writeback_color_diff(self) -> None:
1118         with cache_dir() as workspace:
1119             for tag in range(0, 4):
1120                 src = (workspace / f"test{tag}.py").resolve()
1121                 with src.open("w") as fobj:
1122                     fobj.write("print('hello')")
1123             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1124                 self.invokeBlack(["--diff", "--color", str(workspace)], exit_code=0)
1125                 # this isn't quite doing what we want, but if it _isn't_
1126                 # called then we cannot be using the lock it provides
1127                 mgr.assert_called()
1128
1129     def test_no_cache_when_stdin(self) -> None:
1130         mode = DEFAULT_MODE
1131         with cache_dir():
1132             result = CliRunner().invoke(
1133                 black.main, ["-"], input=BytesIO(b"print('hello')")
1134             )
1135             self.assertEqual(result.exit_code, 0)
1136             cache_file = get_cache_file(mode)
1137             self.assertFalse(cache_file.exists())
1138
1139     def test_read_cache_no_cachefile(self) -> None:
1140         mode = DEFAULT_MODE
1141         with cache_dir():
1142             self.assertEqual(black.read_cache(mode), {})
1143
1144     def test_write_cache_read_cache(self) -> None:
1145         mode = DEFAULT_MODE
1146         with cache_dir() as workspace:
1147             src = (workspace / "test.py").resolve()
1148             src.touch()
1149             black.write_cache({}, [src], mode)
1150             cache = black.read_cache(mode)
1151             self.assertIn(str(src), cache)
1152             self.assertEqual(cache[str(src)], black.get_cache_info(src))
1153
1154     def test_filter_cached(self) -> None:
1155         with TemporaryDirectory() as workspace:
1156             path = Path(workspace)
1157             uncached = (path / "uncached").resolve()
1158             cached = (path / "cached").resolve()
1159             cached_but_changed = (path / "changed").resolve()
1160             uncached.touch()
1161             cached.touch()
1162             cached_but_changed.touch()
1163             cache = {
1164                 str(cached): black.get_cache_info(cached),
1165                 str(cached_but_changed): (0.0, 0),
1166             }
1167             todo, done = black.filter_cached(
1168                 cache, {uncached, cached, cached_but_changed}
1169             )
1170             self.assertEqual(todo, {uncached, cached_but_changed})
1171             self.assertEqual(done, {cached})
1172
1173     def test_write_cache_creates_directory_if_needed(self) -> None:
1174         mode = DEFAULT_MODE
1175         with cache_dir(exists=False) as workspace:
1176             self.assertFalse(workspace.exists())
1177             black.write_cache({}, [], mode)
1178             self.assertTrue(workspace.exists())
1179
1180     @event_loop()
1181     def test_failed_formatting_does_not_get_cached(self) -> None:
1182         mode = DEFAULT_MODE
1183         with cache_dir() as workspace, patch(
1184             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1185         ):
1186             failing = (workspace / "failing.py").resolve()
1187             with failing.open("w") as fobj:
1188                 fobj.write("not actually python")
1189             clean = (workspace / "clean.py").resolve()
1190             with clean.open("w") as fobj:
1191                 fobj.write('print("hello")\n')
1192             self.invokeBlack([str(workspace)], exit_code=123)
1193             cache = black.read_cache(mode)
1194             self.assertNotIn(str(failing), cache)
1195             self.assertIn(str(clean), cache)
1196
1197     def test_write_cache_write_fail(self) -> None:
1198         mode = DEFAULT_MODE
1199         with cache_dir(), patch.object(Path, "open") as mock:
1200             mock.side_effect = OSError
1201             black.write_cache({}, [], mode)
1202
1203     @event_loop()
1204     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
1205     def test_works_in_mono_process_only_environment(self) -> None:
1206         with cache_dir() as workspace:
1207             for f in [
1208                 (workspace / "one.py").resolve(),
1209                 (workspace / "two.py").resolve(),
1210             ]:
1211                 f.write_text('print("hello")\n')
1212             self.invokeBlack([str(workspace)])
1213
1214     @event_loop()
1215     def test_check_diff_use_together(self) -> None:
1216         with cache_dir():
1217             # Files which will be reformatted.
1218             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1219             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1220             # Files which will not be reformatted.
1221             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1222             self.invokeBlack([str(src2), "--diff", "--check"])
1223             # Multi file command.
1224             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1225
1226     def test_no_files(self) -> None:
1227         with cache_dir():
1228             # Without an argument, black exits with error code 0.
1229             self.invokeBlack([])
1230
1231     def test_broken_symlink(self) -> None:
1232         with cache_dir() as workspace:
1233             symlink = workspace / "broken_link.py"
1234             try:
1235                 symlink.symlink_to("nonexistent.py")
1236             except OSError as e:
1237                 self.skipTest(f"Can't create symlinks: {e}")
1238             self.invokeBlack([str(workspace.resolve())])
1239
1240     def test_read_cache_line_lengths(self) -> None:
1241         mode = DEFAULT_MODE
1242         short_mode = replace(DEFAULT_MODE, line_length=1)
1243         with cache_dir() as workspace:
1244             path = (workspace / "file.py").resolve()
1245             path.touch()
1246             black.write_cache({}, [path], mode)
1247             one = black.read_cache(mode)
1248             self.assertIn(str(path), one)
1249             two = black.read_cache(short_mode)
1250             self.assertNotIn(str(path), two)
1251
1252     def test_single_file_force_pyi(self) -> None:
1253         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1254         contents, expected = read_data("force_pyi")
1255         with cache_dir() as workspace:
1256             path = (workspace / "file.py").resolve()
1257             with open(path, "w") as fh:
1258                 fh.write(contents)
1259             self.invokeBlack([str(path), "--pyi"])
1260             with open(path, "r") as fh:
1261                 actual = fh.read()
1262             # verify cache with --pyi is separate
1263             pyi_cache = black.read_cache(pyi_mode)
1264             self.assertIn(str(path), pyi_cache)
1265             normal_cache = black.read_cache(DEFAULT_MODE)
1266             self.assertNotIn(str(path), normal_cache)
1267         self.assertFormatEqual(expected, actual)
1268         black.assert_equivalent(contents, actual)
1269         black.assert_stable(contents, actual, pyi_mode)
1270
1271     @event_loop()
1272     def test_multi_file_force_pyi(self) -> None:
1273         reg_mode = DEFAULT_MODE
1274         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1275         contents, expected = read_data("force_pyi")
1276         with cache_dir() as workspace:
1277             paths = [
1278                 (workspace / "file1.py").resolve(),
1279                 (workspace / "file2.py").resolve(),
1280             ]
1281             for path in paths:
1282                 with open(path, "w") as fh:
1283                     fh.write(contents)
1284             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1285             for path in paths:
1286                 with open(path, "r") as fh:
1287                     actual = fh.read()
1288                 self.assertEqual(actual, expected)
1289             # verify cache with --pyi is separate
1290             pyi_cache = black.read_cache(pyi_mode)
1291             normal_cache = black.read_cache(reg_mode)
1292             for path in paths:
1293                 self.assertIn(str(path), pyi_cache)
1294                 self.assertNotIn(str(path), normal_cache)
1295
1296     def test_pipe_force_pyi(self) -> None:
1297         source, expected = read_data("force_pyi")
1298         result = CliRunner().invoke(
1299             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1300         )
1301         self.assertEqual(result.exit_code, 0)
1302         actual = result.output
1303         self.assertFormatEqual(actual, expected)
1304
1305     def test_single_file_force_py36(self) -> None:
1306         reg_mode = DEFAULT_MODE
1307         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1308         source, expected = read_data("force_py36")
1309         with cache_dir() as workspace:
1310             path = (workspace / "file.py").resolve()
1311             with open(path, "w") as fh:
1312                 fh.write(source)
1313             self.invokeBlack([str(path), *PY36_ARGS])
1314             with open(path, "r") as fh:
1315                 actual = fh.read()
1316             # verify cache with --target-version is separate
1317             py36_cache = black.read_cache(py36_mode)
1318             self.assertIn(str(path), py36_cache)
1319             normal_cache = black.read_cache(reg_mode)
1320             self.assertNotIn(str(path), normal_cache)
1321         self.assertEqual(actual, expected)
1322
1323     @event_loop()
1324     def test_multi_file_force_py36(self) -> None:
1325         reg_mode = DEFAULT_MODE
1326         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1327         source, expected = read_data("force_py36")
1328         with cache_dir() as workspace:
1329             paths = [
1330                 (workspace / "file1.py").resolve(),
1331                 (workspace / "file2.py").resolve(),
1332             ]
1333             for path in paths:
1334                 with open(path, "w") as fh:
1335                     fh.write(source)
1336             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1337             for path in paths:
1338                 with open(path, "r") as fh:
1339                     actual = fh.read()
1340                 self.assertEqual(actual, expected)
1341             # verify cache with --target-version is separate
1342             pyi_cache = black.read_cache(py36_mode)
1343             normal_cache = black.read_cache(reg_mode)
1344             for path in paths:
1345                 self.assertIn(str(path), pyi_cache)
1346                 self.assertNotIn(str(path), normal_cache)
1347
1348     def test_pipe_force_py36(self) -> None:
1349         source, expected = read_data("force_py36")
1350         result = CliRunner().invoke(
1351             black.main,
1352             ["-", "-q", "--target-version=py36"],
1353             input=BytesIO(source.encode("utf8")),
1354         )
1355         self.assertEqual(result.exit_code, 0)
1356         actual = result.output
1357         self.assertFormatEqual(actual, expected)
1358
1359     def test_include_exclude(self) -> None:
1360         path = THIS_DIR / "data" / "include_exclude_tests"
1361         include = re.compile(r"\.pyi?$")
1362         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1363         report = black.Report()
1364         gitignore = PathSpec.from_lines("gitwildmatch", [])
1365         sources: List[Path] = []
1366         expected = [
1367             Path(path / "b/dont_exclude/a.py"),
1368             Path(path / "b/dont_exclude/a.pyi"),
1369         ]
1370         this_abs = THIS_DIR.resolve()
1371         sources.extend(
1372             black.gen_python_files(
1373                 path.iterdir(),
1374                 this_abs,
1375                 include,
1376                 exclude,
1377                 None,
1378                 None,
1379                 report,
1380                 gitignore,
1381             )
1382         )
1383         self.assertEqual(sorted(expected), sorted(sources))
1384
1385     def test_gitignore_used_as_default(self) -> None:
1386         path = Path(THIS_DIR / "data" / "include_exclude_tests")
1387         include = re.compile(r"\.pyi?$")
1388         extend_exclude = re.compile(r"/exclude/")
1389         src = str(path / "b/")
1390         report = black.Report()
1391         expected: List[Path] = [
1392             path / "b/.definitely_exclude/a.py",
1393             path / "b/.definitely_exclude/a.pyi",
1394         ]
1395         sources = list(
1396             black.get_sources(
1397                 ctx=FakeContext(),
1398                 src=(src,),
1399                 quiet=True,
1400                 verbose=False,
1401                 include=include,
1402                 exclude=None,
1403                 extend_exclude=extend_exclude,
1404                 force_exclude=None,
1405                 report=report,
1406                 stdin_filename=None,
1407             )
1408         )
1409         self.assertEqual(sorted(expected), sorted(sources))
1410
1411     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1412     def test_exclude_for_issue_1572(self) -> None:
1413         # Exclude shouldn't touch files that were explicitly given to Black through the
1414         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
1415         # https://github.com/psf/black/issues/1572
1416         path = THIS_DIR / "data" / "include_exclude_tests"
1417         include = ""
1418         exclude = r"/exclude/|a\.py"
1419         src = str(path / "b/exclude/a.py")
1420         report = black.Report()
1421         expected = [Path(path / "b/exclude/a.py")]
1422         sources = list(
1423             black.get_sources(
1424                 ctx=FakeContext(),
1425                 src=(src,),
1426                 quiet=True,
1427                 verbose=False,
1428                 include=re.compile(include),
1429                 exclude=re.compile(exclude),
1430                 extend_exclude=None,
1431                 force_exclude=None,
1432                 report=report,
1433                 stdin_filename=None,
1434             )
1435         )
1436         self.assertEqual(sorted(expected), sorted(sources))
1437
1438     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1439     def test_get_sources_with_stdin(self) -> None:
1440         include = ""
1441         exclude = r"/exclude/|a\.py"
1442         src = "-"
1443         report = black.Report()
1444         expected = [Path("-")]
1445         sources = list(
1446             black.get_sources(
1447                 ctx=FakeContext(),
1448                 src=(src,),
1449                 quiet=True,
1450                 verbose=False,
1451                 include=re.compile(include),
1452                 exclude=re.compile(exclude),
1453                 extend_exclude=None,
1454                 force_exclude=None,
1455                 report=report,
1456                 stdin_filename=None,
1457             )
1458         )
1459         self.assertEqual(sorted(expected), sorted(sources))
1460
1461     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1462     def test_get_sources_with_stdin_filename(self) -> None:
1463         include = ""
1464         exclude = r"/exclude/|a\.py"
1465         src = "-"
1466         report = black.Report()
1467         stdin_filename = str(THIS_DIR / "data/collections.py")
1468         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1469         sources = list(
1470             black.get_sources(
1471                 ctx=FakeContext(),
1472                 src=(src,),
1473                 quiet=True,
1474                 verbose=False,
1475                 include=re.compile(include),
1476                 exclude=re.compile(exclude),
1477                 extend_exclude=None,
1478                 force_exclude=None,
1479                 report=report,
1480                 stdin_filename=stdin_filename,
1481             )
1482         )
1483         self.assertEqual(sorted(expected), sorted(sources))
1484
1485     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1486     def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
1487         # Exclude shouldn't exclude stdin_filename since it is mimicking the
1488         # file being passed directly. This is the same as
1489         # test_exclude_for_issue_1572
1490         path = THIS_DIR / "data" / "include_exclude_tests"
1491         include = ""
1492         exclude = r"/exclude/|a\.py"
1493         src = "-"
1494         report = black.Report()
1495         stdin_filename = str(path / "b/exclude/a.py")
1496         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1497         sources = list(
1498             black.get_sources(
1499                 ctx=FakeContext(),
1500                 src=(src,),
1501                 quiet=True,
1502                 verbose=False,
1503                 include=re.compile(include),
1504                 exclude=re.compile(exclude),
1505                 extend_exclude=None,
1506                 force_exclude=None,
1507                 report=report,
1508                 stdin_filename=stdin_filename,
1509             )
1510         )
1511         self.assertEqual(sorted(expected), sorted(sources))
1512
1513     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1514     def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
1515         # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
1516         # file being passed directly. This is the same as
1517         # test_exclude_for_issue_1572
1518         path = THIS_DIR / "data" / "include_exclude_tests"
1519         include = ""
1520         extend_exclude = r"/exclude/|a\.py"
1521         src = "-"
1522         report = black.Report()
1523         stdin_filename = str(path / "b/exclude/a.py")
1524         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1525         sources = list(
1526             black.get_sources(
1527                 ctx=FakeContext(),
1528                 src=(src,),
1529                 quiet=True,
1530                 verbose=False,
1531                 include=re.compile(include),
1532                 exclude=re.compile(""),
1533                 extend_exclude=re.compile(extend_exclude),
1534                 force_exclude=None,
1535                 report=report,
1536                 stdin_filename=stdin_filename,
1537             )
1538         )
1539         self.assertEqual(sorted(expected), sorted(sources))
1540
1541     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1542     def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
1543         # Force exclude should exclude the file when passing it through
1544         # stdin_filename
1545         path = THIS_DIR / "data" / "include_exclude_tests"
1546         include = ""
1547         force_exclude = r"/exclude/|a\.py"
1548         src = "-"
1549         report = black.Report()
1550         stdin_filename = str(path / "b/exclude/a.py")
1551         sources = list(
1552             black.get_sources(
1553                 ctx=FakeContext(),
1554                 src=(src,),
1555                 quiet=True,
1556                 verbose=False,
1557                 include=re.compile(include),
1558                 exclude=re.compile(""),
1559                 extend_exclude=None,
1560                 force_exclude=re.compile(force_exclude),
1561                 report=report,
1562                 stdin_filename=stdin_filename,
1563             )
1564         )
1565         self.assertEqual([], sorted(sources))
1566
1567     def test_reformat_one_with_stdin(self) -> None:
1568         with patch(
1569             "black.format_stdin_to_stdout",
1570             return_value=lambda *args, **kwargs: black.Changed.YES,
1571         ) as fsts:
1572             report = MagicMock()
1573             path = Path("-")
1574             black.reformat_one(
1575                 path,
1576                 fast=True,
1577                 write_back=black.WriteBack.YES,
1578                 mode=DEFAULT_MODE,
1579                 report=report,
1580             )
1581             fsts.assert_called_once()
1582             report.done.assert_called_with(path, black.Changed.YES)
1583
1584     def test_reformat_one_with_stdin_filename(self) -> None:
1585         with patch(
1586             "black.format_stdin_to_stdout",
1587             return_value=lambda *args, **kwargs: black.Changed.YES,
1588         ) as fsts:
1589             report = MagicMock()
1590             p = "foo.py"
1591             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1592             expected = Path(p)
1593             black.reformat_one(
1594                 path,
1595                 fast=True,
1596                 write_back=black.WriteBack.YES,
1597                 mode=DEFAULT_MODE,
1598                 report=report,
1599             )
1600             fsts.assert_called_once_with(
1601                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1602             )
1603             # __BLACK_STDIN_FILENAME__ should have been stripped
1604             report.done.assert_called_with(expected, black.Changed.YES)
1605
1606     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1607         with patch(
1608             "black.format_stdin_to_stdout",
1609             return_value=lambda *args, **kwargs: black.Changed.YES,
1610         ) as fsts:
1611             report = MagicMock()
1612             p = "foo.pyi"
1613             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1614             expected = Path(p)
1615             black.reformat_one(
1616                 path,
1617                 fast=True,
1618                 write_back=black.WriteBack.YES,
1619                 mode=DEFAULT_MODE,
1620                 report=report,
1621             )
1622             fsts.assert_called_once_with(
1623                 fast=True,
1624                 write_back=black.WriteBack.YES,
1625                 mode=replace(DEFAULT_MODE, is_pyi=True),
1626             )
1627             # __BLACK_STDIN_FILENAME__ should have been stripped
1628             report.done.assert_called_with(expected, black.Changed.YES)
1629
1630     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1631         with patch(
1632             "black.format_stdin_to_stdout",
1633             return_value=lambda *args, **kwargs: black.Changed.YES,
1634         ) as fsts:
1635             report = MagicMock()
1636             # Even with an existing file, since we are forcing stdin, black
1637             # should output to stdout and not modify the file inplace
1638             p = Path(str(THIS_DIR / "data/collections.py"))
1639             # Make sure is_file actually returns True
1640             self.assertTrue(p.is_file())
1641             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1642             expected = Path(p)
1643             black.reformat_one(
1644                 path,
1645                 fast=True,
1646                 write_back=black.WriteBack.YES,
1647                 mode=DEFAULT_MODE,
1648                 report=report,
1649             )
1650             fsts.assert_called_once()
1651             # __BLACK_STDIN_FILENAME__ should have been stripped
1652             report.done.assert_called_with(expected, black.Changed.YES)
1653
1654     def test_reformat_one_with_stdin_empty(self) -> None:
1655         output = io.StringIO()
1656         with patch("io.TextIOWrapper", lambda *args, **kwargs: output):
1657             try:
1658                 black.format_stdin_to_stdout(
1659                     fast=True,
1660                     content="",
1661                     write_back=black.WriteBack.YES,
1662                     mode=DEFAULT_MODE,
1663                 )
1664             except io.UnsupportedOperation:
1665                 pass  # StringIO does not support detach
1666             assert output.getvalue() == ""
1667
1668     def test_gitignore_exclude(self) -> None:
1669         path = THIS_DIR / "data" / "include_exclude_tests"
1670         include = re.compile(r"\.pyi?$")
1671         exclude = re.compile(r"")
1672         report = black.Report()
1673         gitignore = PathSpec.from_lines(
1674             "gitwildmatch", ["exclude/", ".definitely_exclude"]
1675         )
1676         sources: List[Path] = []
1677         expected = [
1678             Path(path / "b/dont_exclude/a.py"),
1679             Path(path / "b/dont_exclude/a.pyi"),
1680         ]
1681         this_abs = THIS_DIR.resolve()
1682         sources.extend(
1683             black.gen_python_files(
1684                 path.iterdir(),
1685                 this_abs,
1686                 include,
1687                 exclude,
1688                 None,
1689                 None,
1690                 report,
1691                 gitignore,
1692             )
1693         )
1694         self.assertEqual(sorted(expected), sorted(sources))
1695
1696     def test_nested_gitignore(self) -> None:
1697         path = Path(THIS_DIR / "data" / "nested_gitignore_tests")
1698         include = re.compile(r"\.pyi?$")
1699         exclude = re.compile(r"")
1700         root_gitignore = black.files.get_gitignore(path)
1701         report = black.Report()
1702         expected: List[Path] = [
1703             Path(path / "x.py"),
1704             Path(path / "root/b.py"),
1705             Path(path / "root/c.py"),
1706             Path(path / "root/child/c.py"),
1707         ]
1708         this_abs = THIS_DIR.resolve()
1709         sources = list(
1710             black.gen_python_files(
1711                 path.iterdir(),
1712                 this_abs,
1713                 include,
1714                 exclude,
1715                 None,
1716                 None,
1717                 report,
1718                 root_gitignore,
1719             )
1720         )
1721         self.assertEqual(sorted(expected), sorted(sources))
1722
1723     def test_empty_include(self) -> None:
1724         path = THIS_DIR / "data" / "include_exclude_tests"
1725         report = black.Report()
1726         gitignore = PathSpec.from_lines("gitwildmatch", [])
1727         empty = re.compile(r"")
1728         sources: List[Path] = []
1729         expected = [
1730             Path(path / "b/exclude/a.pie"),
1731             Path(path / "b/exclude/a.py"),
1732             Path(path / "b/exclude/a.pyi"),
1733             Path(path / "b/dont_exclude/a.pie"),
1734             Path(path / "b/dont_exclude/a.py"),
1735             Path(path / "b/dont_exclude/a.pyi"),
1736             Path(path / "b/.definitely_exclude/a.pie"),
1737             Path(path / "b/.definitely_exclude/a.py"),
1738             Path(path / "b/.definitely_exclude/a.pyi"),
1739             Path(path / ".gitignore"),
1740             Path(path / "pyproject.toml"),
1741         ]
1742         this_abs = THIS_DIR.resolve()
1743         sources.extend(
1744             black.gen_python_files(
1745                 path.iterdir(),
1746                 this_abs,
1747                 empty,
1748                 re.compile(black.DEFAULT_EXCLUDES),
1749                 None,
1750                 None,
1751                 report,
1752                 gitignore,
1753             )
1754         )
1755         self.assertEqual(sorted(expected), sorted(sources))
1756
1757     def test_extend_exclude(self) -> None:
1758         path = THIS_DIR / "data" / "include_exclude_tests"
1759         report = black.Report()
1760         gitignore = PathSpec.from_lines("gitwildmatch", [])
1761         sources: List[Path] = []
1762         expected = [
1763             Path(path / "b/exclude/a.py"),
1764             Path(path / "b/dont_exclude/a.py"),
1765         ]
1766         this_abs = THIS_DIR.resolve()
1767         sources.extend(
1768             black.gen_python_files(
1769                 path.iterdir(),
1770                 this_abs,
1771                 re.compile(black.DEFAULT_INCLUDES),
1772                 re.compile(r"\.pyi$"),
1773                 re.compile(r"\.definitely_exclude"),
1774                 None,
1775                 report,
1776                 gitignore,
1777             )
1778         )
1779         self.assertEqual(sorted(expected), sorted(sources))
1780
1781     def test_invalid_cli_regex(self) -> None:
1782         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1783             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1784
1785     def test_required_version_matches_version(self) -> None:
1786         self.invokeBlack(
1787             ["--required-version", black.__version__], exit_code=0, ignore_config=True
1788         )
1789
1790     def test_required_version_does_not_match_version(self) -> None:
1791         self.invokeBlack(
1792             ["--required-version", "20.99b"], exit_code=1, ignore_config=True
1793         )
1794
1795     def test_preserves_line_endings(self) -> None:
1796         with TemporaryDirectory() as workspace:
1797             test_file = Path(workspace) / "test.py"
1798             for nl in ["\n", "\r\n"]:
1799                 contents = nl.join(["def f(  ):", "    pass"])
1800                 test_file.write_bytes(contents.encode())
1801                 ff(test_file, write_back=black.WriteBack.YES)
1802                 updated_contents: bytes = test_file.read_bytes()
1803                 self.assertIn(nl.encode(), updated_contents)
1804                 if nl == "\n":
1805                     self.assertNotIn(b"\r\n", updated_contents)
1806
1807     def test_preserves_line_endings_via_stdin(self) -> None:
1808         for nl in ["\n", "\r\n"]:
1809             contents = nl.join(["def f(  ):", "    pass"])
1810             runner = BlackRunner()
1811             result = runner.invoke(
1812                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1813             )
1814             self.assertEqual(result.exit_code, 0)
1815             output = result.stdout_bytes
1816             self.assertIn(nl.encode("utf8"), output)
1817             if nl == "\n":
1818                 self.assertNotIn(b"\r\n", output)
1819
1820     def test_assert_equivalent_different_asts(self) -> None:
1821         with self.assertRaises(AssertionError):
1822             black.assert_equivalent("{}", "None")
1823
1824     def test_symlink_out_of_root_directory(self) -> None:
1825         path = MagicMock()
1826         root = THIS_DIR.resolve()
1827         child = MagicMock()
1828         include = re.compile(black.DEFAULT_INCLUDES)
1829         exclude = re.compile(black.DEFAULT_EXCLUDES)
1830         report = black.Report()
1831         gitignore = PathSpec.from_lines("gitwildmatch", [])
1832         # `child` should behave like a symlink which resolved path is clearly
1833         # outside of the `root` directory.
1834         path.iterdir.return_value = [child]
1835         child.resolve.return_value = Path("/a/b/c")
1836         child.as_posix.return_value = "/a/b/c"
1837         child.is_symlink.return_value = True
1838         try:
1839             list(
1840                 black.gen_python_files(
1841                     path.iterdir(),
1842                     root,
1843                     include,
1844                     exclude,
1845                     None,
1846                     None,
1847                     report,
1848                     gitignore,
1849                 )
1850             )
1851         except ValueError as ve:
1852             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1853         path.iterdir.assert_called_once()
1854         child.resolve.assert_called_once()
1855         child.is_symlink.assert_called_once()
1856         # `child` should behave like a strange file which resolved path is clearly
1857         # outside of the `root` directory.
1858         child.is_symlink.return_value = False
1859         with self.assertRaises(ValueError):
1860             list(
1861                 black.gen_python_files(
1862                     path.iterdir(),
1863                     root,
1864                     include,
1865                     exclude,
1866                     None,
1867                     None,
1868                     report,
1869                     gitignore,
1870                 )
1871             )
1872         path.iterdir.assert_called()
1873         self.assertEqual(path.iterdir.call_count, 2)
1874         child.resolve.assert_called()
1875         self.assertEqual(child.resolve.call_count, 2)
1876         child.is_symlink.assert_called()
1877         self.assertEqual(child.is_symlink.call_count, 2)
1878
1879     def test_shhh_click(self) -> None:
1880         try:
1881             from click import _unicodefun
1882         except ModuleNotFoundError:
1883             self.skipTest("Incompatible Click version")
1884         if not hasattr(_unicodefun, "_verify_python3_env"):
1885             self.skipTest("Incompatible Click version")
1886         # First, let's see if Click is crashing with a preferred ASCII charset.
1887         with patch("locale.getpreferredencoding") as gpe:
1888             gpe.return_value = "ASCII"
1889             with self.assertRaises(RuntimeError):
1890                 _unicodefun._verify_python3_env()  # type: ignore
1891         # Now, let's silence Click...
1892         black.patch_click()
1893         # ...and confirm it's silent.
1894         with patch("locale.getpreferredencoding") as gpe:
1895             gpe.return_value = "ASCII"
1896             try:
1897                 _unicodefun._verify_python3_env()  # type: ignore
1898             except RuntimeError as re:
1899                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1900
1901     def test_root_logger_not_used_directly(self) -> None:
1902         def fail(*args: Any, **kwargs: Any) -> None:
1903             self.fail("Record created with root logger")
1904
1905         with patch.multiple(
1906             logging.root,
1907             debug=fail,
1908             info=fail,
1909             warning=fail,
1910             error=fail,
1911             critical=fail,
1912             log=fail,
1913         ):
1914             ff(THIS_DIR / "util.py")
1915
1916     def test_invalid_config_return_code(self) -> None:
1917         tmp_file = Path(black.dump_to_file())
1918         try:
1919             tmp_config = Path(black.dump_to_file())
1920             tmp_config.unlink()
1921             args = ["--config", str(tmp_config), str(tmp_file)]
1922             self.invokeBlack(args, exit_code=2, ignore_config=False)
1923         finally:
1924             tmp_file.unlink()
1925
1926     def test_parse_pyproject_toml(self) -> None:
1927         test_toml_file = THIS_DIR / "test.toml"
1928         config = black.parse_pyproject_toml(str(test_toml_file))
1929         self.assertEqual(config["verbose"], 1)
1930         self.assertEqual(config["check"], "no")
1931         self.assertEqual(config["diff"], "y")
1932         self.assertEqual(config["color"], True)
1933         self.assertEqual(config["line_length"], 79)
1934         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1935         self.assertEqual(config["exclude"], r"\.pyi?$")
1936         self.assertEqual(config["include"], r"\.py?$")
1937
1938     def test_read_pyproject_toml(self) -> None:
1939         test_toml_file = THIS_DIR / "test.toml"
1940         fake_ctx = FakeContext()
1941         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1942         config = fake_ctx.default_map
1943         self.assertEqual(config["verbose"], "1")
1944         self.assertEqual(config["check"], "no")
1945         self.assertEqual(config["diff"], "y")
1946         self.assertEqual(config["color"], "True")
1947         self.assertEqual(config["line_length"], "79")
1948         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1949         self.assertEqual(config["exclude"], r"\.pyi?$")
1950         self.assertEqual(config["include"], r"\.py?$")
1951
1952     def test_find_project_root(self) -> None:
1953         with TemporaryDirectory() as workspace:
1954             root = Path(workspace)
1955             test_dir = root / "test"
1956             test_dir.mkdir()
1957
1958             src_dir = root / "src"
1959             src_dir.mkdir()
1960
1961             root_pyproject = root / "pyproject.toml"
1962             root_pyproject.touch()
1963             src_pyproject = src_dir / "pyproject.toml"
1964             src_pyproject.touch()
1965             src_python = src_dir / "foo.py"
1966             src_python.touch()
1967
1968             self.assertEqual(
1969                 black.find_project_root((src_dir, test_dir)), root.resolve()
1970             )
1971             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
1972             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
1973
1974     @patch(
1975         "black.files.find_user_pyproject_toml",
1976         black.files.find_user_pyproject_toml.__wrapped__,
1977     )
1978     def test_find_user_pyproject_toml_linux(self) -> None:
1979         if system() == "Windows":
1980             return
1981
1982         # Test if XDG_CONFIG_HOME is checked
1983         with TemporaryDirectory() as workspace:
1984             tmp_user_config = Path(workspace) / "black"
1985             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1986                 self.assertEqual(
1987                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
1988                 )
1989
1990         # Test fallback for XDG_CONFIG_HOME
1991         with patch.dict("os.environ"):
1992             os.environ.pop("XDG_CONFIG_HOME", None)
1993             fallback_user_config = Path("~/.config").expanduser() / "black"
1994             self.assertEqual(
1995                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
1996             )
1997
1998     def test_find_user_pyproject_toml_windows(self) -> None:
1999         if system() != "Windows":
2000             return
2001
2002         user_config_path = Path.home() / ".black"
2003         self.assertEqual(
2004             black.files.find_user_pyproject_toml(), user_config_path.resolve()
2005         )
2006
2007     def test_bpo_33660_workaround(self) -> None:
2008         if system() == "Windows":
2009             return
2010
2011         # https://bugs.python.org/issue33660
2012
2013         old_cwd = Path.cwd()
2014         try:
2015             root = Path("/")
2016             os.chdir(str(root))
2017             path = Path("workspace") / "project"
2018             report = black.Report(verbose=True)
2019             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
2020             self.assertEqual(normalized_path, "workspace/project")
2021         finally:
2022             os.chdir(str(old_cwd))
2023
2024     def test_newline_comment_interaction(self) -> None:
2025         source = "class A:\\\r\n# type: ignore\n pass\n"
2026         output = black.format_str(source, mode=DEFAULT_MODE)
2027         black.assert_stable(source, output, mode=DEFAULT_MODE)
2028
2029     def test_bpo_2142_workaround(self) -> None:
2030
2031         # https://bugs.python.org/issue2142
2032
2033         source, _ = read_data("missing_final_newline.py")
2034         # read_data adds a trailing newline
2035         source = source.rstrip()
2036         expected, _ = read_data("missing_final_newline.diff")
2037         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
2038         diff_header = re.compile(
2039             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
2040             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
2041         )
2042         try:
2043             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
2044             self.assertEqual(result.exit_code, 0)
2045         finally:
2046             os.unlink(tmp_file)
2047         actual = result.output
2048         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
2049         self.assertEqual(actual, expected)
2050
2051     @pytest.mark.python2
2052     def test_docstring_reformat_for_py27(self) -> None:
2053         """
2054         Check that stripping trailing whitespace from Python 2 docstrings
2055         doesn't trigger a "not equivalent to source" error
2056         """
2057         source = (
2058             b'def foo():\r\n    """Testing\r\n    Testing """\r\n    print "Foo"\r\n'
2059         )
2060         expected = 'def foo():\n    """Testing\n    Testing"""\n    print "Foo"\n'
2061
2062         result = CliRunner().invoke(
2063             black.main,
2064             ["-", "-q", "--target-version=py27"],
2065             input=BytesIO(source),
2066         )
2067
2068         self.assertEqual(result.exit_code, 0)
2069         actual = result.output
2070         self.assertFormatEqual(actual, expected)
2071
2072     @staticmethod
2073     def compare_results(
2074         result: click.testing.Result, expected_value: str, expected_exit_code: int
2075     ) -> None:
2076         """Helper method to test the value and exit code of a click Result."""
2077         assert (
2078             result.output == expected_value
2079         ), "The output did not match the expected value."
2080         assert result.exit_code == expected_exit_code, "The exit code is incorrect."
2081
2082     def test_code_option(self) -> None:
2083         """Test the code option with no changes."""
2084         code = 'print("Hello world")\n'
2085         args = ["--code", code]
2086         result = CliRunner().invoke(black.main, args)
2087
2088         self.compare_results(result, code, 0)
2089
2090     def test_code_option_changed(self) -> None:
2091         """Test the code option when changes are required."""
2092         code = "print('hello world')"
2093         formatted = black.format_str(code, mode=DEFAULT_MODE)
2094
2095         args = ["--code", code]
2096         result = CliRunner().invoke(black.main, args)
2097
2098         self.compare_results(result, formatted, 0)
2099
2100     def test_code_option_check(self) -> None:
2101         """Test the code option when check is passed."""
2102         args = ["--check", "--code", 'print("Hello world")\n']
2103         result = CliRunner().invoke(black.main, args)
2104         self.compare_results(result, "", 0)
2105
2106     def test_code_option_check_changed(self) -> None:
2107         """Test the code option when changes are required, and check is passed."""
2108         args = ["--check", "--code", "print('hello world')"]
2109         result = CliRunner().invoke(black.main, args)
2110         self.compare_results(result, "", 1)
2111
2112     def test_code_option_diff(self) -> None:
2113         """Test the code option when diff is passed."""
2114         code = "print('hello world')"
2115         formatted = black.format_str(code, mode=DEFAULT_MODE)
2116         result_diff = diff(code, formatted, "STDIN", "STDOUT")
2117
2118         args = ["--diff", "--code", code]
2119         result = CliRunner().invoke(black.main, args)
2120
2121         # Remove time from diff
2122         output = DIFF_TIME.sub("", result.output)
2123
2124         assert output == result_diff, "The output did not match the expected value."
2125         assert result.exit_code == 0, "The exit code is incorrect."
2126
2127     def test_code_option_color_diff(self) -> None:
2128         """Test the code option when color and diff are passed."""
2129         code = "print('hello world')"
2130         formatted = black.format_str(code, mode=DEFAULT_MODE)
2131
2132         result_diff = diff(code, formatted, "STDIN", "STDOUT")
2133         result_diff = color_diff(result_diff)
2134
2135         args = ["--diff", "--color", "--code", code]
2136         result = CliRunner().invoke(black.main, args)
2137
2138         # Remove time from diff
2139         output = DIFF_TIME.sub("", result.output)
2140
2141         assert output == result_diff, "The output did not match the expected value."
2142         assert result.exit_code == 0, "The exit code is incorrect."
2143
2144     def test_code_option_safe(self) -> None:
2145         """Test that the code option throws an error when the sanity checks fail."""
2146         # Patch black.assert_equivalent to ensure the sanity checks fail
2147         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
2148             code = 'print("Hello world")'
2149             error_msg = f"{code}\nerror: cannot format <string>: \n"
2150
2151             args = ["--safe", "--code", code]
2152             result = CliRunner().invoke(black.main, args)
2153
2154             self.compare_results(result, error_msg, 123)
2155
2156     def test_code_option_fast(self) -> None:
2157         """Test that the code option ignores errors when the sanity checks fail."""
2158         # Patch black.assert_equivalent to ensure the sanity checks fail
2159         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
2160             code = 'print("Hello world")'
2161             formatted = black.format_str(code, mode=DEFAULT_MODE)
2162
2163             args = ["--fast", "--code", code]
2164             result = CliRunner().invoke(black.main, args)
2165
2166             self.compare_results(result, formatted, 0)
2167
2168     def test_code_option_config(self) -> None:
2169         """
2170         Test that the code option finds the pyproject.toml in the current directory.
2171         """
2172         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
2173             # Make sure we are in the project root with the pyproject file
2174             if not Path("tests").exists():
2175                 os.chdir("..")
2176
2177             args = ["--code", "print"]
2178             CliRunner().invoke(black.main, args)
2179
2180             pyproject_path = Path(Path().cwd(), "pyproject.toml").resolve()
2181             assert (
2182                 len(parse.mock_calls) >= 1
2183             ), "Expected config parse to be called with the current directory."
2184
2185             _, call_args, _ = parse.mock_calls[0]
2186             assert (
2187                 call_args[0].lower() == str(pyproject_path).lower()
2188             ), "Incorrect config loaded."
2189
2190     def test_code_option_parent_config(self) -> None:
2191         """
2192         Test that the code option finds the pyproject.toml in the parent directory.
2193         """
2194         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
2195             # Make sure we are in the tests directory
2196             if Path("tests").exists():
2197                 os.chdir("tests")
2198
2199             args = ["--code", "print"]
2200             CliRunner().invoke(black.main, args)
2201
2202             pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
2203             assert (
2204                 len(parse.mock_calls) >= 1
2205             ), "Expected config parse to be called with the current directory."
2206
2207             _, call_args, _ = parse.mock_calls[0]
2208             assert (
2209                 call_args[0].lower() == str(pyproject_path).lower()
2210             ), "Incorrect config loaded."
2211
2212
# Black's own source, one entry per line with line endings kept; tracefunc
# indexes into this to show the source line of a traced call.
with open(black.__file__, "r", encoding="utf-8") as _bf:
    black_source_lines = list(_bf)
2215
2216
def tracefunc(frame: types.FrameType, event: str, arg: Any) -> Callable:
    """Show function calls `from black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    Only "call" events whose code lives in ``black/__init__.py`` are printed,
    as ``<indent><lineno>:<source line>``, where the indent grows with the
    current stack depth. Always returns itself so tracing continues.
    """
    if event != "call":
        return tracefunc

    filename = frame.f_code.co_filename
    # Filter *before* touching black_source_lines: that list only holds
    # black's source, so indexing it with a line number from any other file
    # could raise IndexError (which aborts tracing) or print unrelated text.
    # Normalizing separators lets the match also work on Windows paths.
    if "black/__init__.py" not in filename.replace("\\", "/"):
        return tracefunc

    # inspect.stack() is expensive, so only compute it for frames we print.
    # NOTE(review): 19 presumably offsets the test-harness frames below the
    # traced code -- confirm; a negative value just yields no indentation.
    stack = (len(inspect.stack()) - 19) * 2
    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1
    funcname = black_source_lines[func_sig_lineno].strip()
    # A decorated function's reported line may be a decorator; advance to
    # the first non-decorator line.
    while funcname.startswith("@"):
        func_sig_lineno += 1
        funcname = black_source_lines[func_sig_lineno].strip()
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc
2237
2238
if __name__ == "__main__":
    # Allow running this file directly as a script: hands test discovery to
    # unittest for the module named ``test_black``.
    # NOTE(review): importing by that name presumably relies on this file's
    # directory being on sys.path (true when run as a script) -- confirm.
    unittest.main(module="test_black")