]> git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

942446ec3435d03d0aa580f9d7fccbd867bf3de9
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import multiprocessing
3 import asyncio
4 import logging
5 from concurrent.futures import ThreadPoolExecutor
6 from contextlib import contextmanager
7 from dataclasses import replace
8 import inspect
9 import io
10 from io import BytesIO
11 import os
12 from pathlib import Path
13 from platform import system
14 import regex as re
15 import sys
16 from tempfile import TemporaryDirectory
17 import types
18 from typing import (
19     Any,
20     Callable,
21     Dict,
22     List,
23     Iterator,
24     TypeVar,
25 )
26 import pytest
27 import unittest
28 from unittest.mock import patch, MagicMock
29
30 import click
31 from click import unstyle
32 from click.testing import CliRunner
33
34 import black
35 from black import Feature, TargetVersion
36 from black.cache import get_cache_file
37 from black.debug import DebugVisitor
38 from black.output import diff, color_diff
39 from black.report import Report
40 import black.files
41
42 from pathspec import PathSpec
43
44 # Import other test classes
45 from tests.util import (
46     THIS_DIR,
47     change_directory,
48     read_data,
49     DETERMINISTIC_HEADER,
50     BlackBaseTestCase,
51     DEFAULT_MODE,
52     fs,
53     ff,
54     dump_to_stderr,
55 )
56
57
THIS_FILE = Path(__file__)
# Target versions 3.6 through 3.9, for tests exercising 3.6+ syntax
# (e.g. the numeric literal tests pass this as ``target_versions``).
PY36_VERSIONS = {
    TargetVersion.PY36,
    TargetVersion.PY37,
    TargetVersion.PY38,
    TargetVersion.PY39,
}
# The matching ``--target-version=...`` CLI flags. NOTE: built by iterating a
# set, so the order of the flags is not guaranteed to be stable between runs.
PY36_ARGS = [f"--target-version={version.name.lower()}" for version in PY36_VERSIONS]
# Type variables for generic helpers (presumably used further down the file).
T = TypeVar("T")
R = TypeVar("R")

# Match the time output in a diff, but nothing else
# (compiled with the third-party ``regex`` module, imported above as ``re``).
DIFF_TIME = re.compile(r"\t[\d-:+\. ]+")
71
72
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Temporarily point ``black.cache.CACHE_DIR`` at a throwaway directory.

    With ``exists=False`` the patched (and yielded) path is a ``new``
    subdirectory of the temporary directory that has not been created yet.
    The temporary directory is removed when the context exits.
    """
    with TemporaryDirectory() as workspace:
        base = Path(workspace)
        target = base if exists else base / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
81
82
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a fresh asyncio event loop for the duration of the context.

    On exit the loop is closed and the current event loop is reset to
    ``None``, so later code cannot accidentally pick up the closed loop.
    """
    policy = asyncio.get_event_loop_policy()
    loop = policy.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        yield

    finally:
        loop.close()
        # Without this, get_event_loop() would keep returning the closed
        # loop after the context exits (asyncio.run resets it the same way).
        asyncio.set_event_loop(None)
93
94
class FakeContext(click.Context):
    """Stand-in click ``Context`` for calling functions that require one."""

    def __init__(self) -> None:
        # Intentionally bypasses click.Context.__init__; only an empty
        # default_map is set up.
        self.default_map: Dict[str, Any] = dict()
100
101
class FakeParameter(click.Parameter):
    """Stand-in click ``Parameter`` for calling functions that require one."""

    def __init__(self) -> None:
        """Build an empty parameter without running ``click.Parameter.__init__``."""
107
108
class BlackRunner(CliRunner):
    """A click ``CliRunner`` that keeps Black's STDOUT and STDERR apart."""

    def __init__(self) -> None:
        # mix_stderr=False makes stderr available separately from stdout
        # on the invocation result.
        CliRunner.__init__(self, mix_stderr=False)
114
115
116 class BlackTestCase(BlackBaseTestCase):
117     def invokeBlack(
118         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
119     ) -> None:
120         runner = BlackRunner()
121         if ignore_config:
122             args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
123         result = runner.invoke(black.main, args)
124         assert result.stdout_bytes is not None
125         assert result.stderr_bytes is not None
126         self.assertEqual(
127             result.exit_code,
128             exit_code,
129             msg=(
130                 f"Failed with args: {args}\n"
131                 f"stdout: {result.stdout_bytes.decode()!r}\n"
132                 f"stderr: {result.stderr_bytes.decode()!r}\n"
133                 f"exception: {result.exception}"
134             ),
135         )
136
137     @patch("black.dump_to_file", dump_to_stderr)
138     def test_empty(self) -> None:
139         source = expected = ""
140         actual = fs(source)
141         self.assertFormatEqual(expected, actual)
142         black.assert_equivalent(source, actual)
143         black.assert_stable(source, actual, DEFAULT_MODE)
144
145     def test_empty_ff(self) -> None:
146         expected = ""
147         tmp_file = Path(black.dump_to_file())
148         try:
149             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
150             with open(tmp_file, encoding="utf8") as f:
151                 actual = f.read()
152         finally:
153             os.unlink(tmp_file)
154         self.assertFormatEqual(expected, actual)
155
156     def test_piping(self) -> None:
157         source, expected = read_data("src/black/__init__", data=False)
158         result = BlackRunner().invoke(
159             black.main,
160             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
161             input=BytesIO(source.encode("utf8")),
162         )
163         self.assertEqual(result.exit_code, 0)
164         self.assertFormatEqual(expected, result.output)
165         if source != result.output:
166             black.assert_equivalent(source, result.output)
167             black.assert_stable(source, result.output, DEFAULT_MODE)
168
169     def test_piping_diff(self) -> None:
170         diff_header = re.compile(
171             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
172             r"\+\d\d\d\d"
173         )
174         source, _ = read_data("expression.py")
175         expected, _ = read_data("expression.diff")
176         config = THIS_DIR / "data" / "empty_pyproject.toml"
177         args = [
178             "-",
179             "--fast",
180             f"--line-length={black.DEFAULT_LINE_LENGTH}",
181             "--diff",
182             f"--config={config}",
183         ]
184         result = BlackRunner().invoke(
185             black.main, args, input=BytesIO(source.encode("utf8"))
186         )
187         self.assertEqual(result.exit_code, 0)
188         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
189         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
190         self.assertEqual(expected, actual)
191
192     def test_piping_diff_with_color(self) -> None:
193         source, _ = read_data("expression.py")
194         config = THIS_DIR / "data" / "empty_pyproject.toml"
195         args = [
196             "-",
197             "--fast",
198             f"--line-length={black.DEFAULT_LINE_LENGTH}",
199             "--diff",
200             "--color",
201             f"--config={config}",
202         ]
203         result = BlackRunner().invoke(
204             black.main, args, input=BytesIO(source.encode("utf8"))
205         )
206         actual = result.output
207         # Again, the contents are checked in a different test, so only look for colors.
208         self.assertIn("\033[1;37m", actual)
209         self.assertIn("\033[36m", actual)
210         self.assertIn("\033[32m", actual)
211         self.assertIn("\033[31m", actual)
212         self.assertIn("\033[0m", actual)
213
214     @patch("black.dump_to_file", dump_to_stderr)
215     def _test_wip(self) -> None:
216         source, expected = read_data("wip")
217         sys.settrace(tracefunc)
218         mode = replace(
219             DEFAULT_MODE,
220             experimental_string_processing=False,
221             target_versions={black.TargetVersion.PY38},
222         )
223         actual = fs(source, mode=mode)
224         sys.settrace(None)
225         self.assertFormatEqual(expected, actual)
226         black.assert_equivalent(source, actual)
227         black.assert_stable(source, actual, black.FileMode())
228
229     @unittest.expectedFailure
230     @patch("black.dump_to_file", dump_to_stderr)
231     def test_trailing_comma_optional_parens_stability1(self) -> None:
232         source, _expected = read_data("trailing_comma_optional_parens1")
233         actual = fs(source)
234         black.assert_stable(source, actual, DEFAULT_MODE)
235
236     @unittest.expectedFailure
237     @patch("black.dump_to_file", dump_to_stderr)
238     def test_trailing_comma_optional_parens_stability2(self) -> None:
239         source, _expected = read_data("trailing_comma_optional_parens2")
240         actual = fs(source)
241         black.assert_stable(source, actual, DEFAULT_MODE)
242
243     @unittest.expectedFailure
244     @patch("black.dump_to_file", dump_to_stderr)
245     def test_trailing_comma_optional_parens_stability3(self) -> None:
246         source, _expected = read_data("trailing_comma_optional_parens3")
247         actual = fs(source)
248         black.assert_stable(source, actual, DEFAULT_MODE)
249
250     @patch("black.dump_to_file", dump_to_stderr)
251     def test_trailing_comma_optional_parens_stability1_pass2(self) -> None:
252         source, _expected = read_data("trailing_comma_optional_parens1")
253         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
254         black.assert_stable(source, actual, DEFAULT_MODE)
255
256     @patch("black.dump_to_file", dump_to_stderr)
257     def test_trailing_comma_optional_parens_stability2_pass2(self) -> None:
258         source, _expected = read_data("trailing_comma_optional_parens2")
259         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
260         black.assert_stable(source, actual, DEFAULT_MODE)
261
262     @patch("black.dump_to_file", dump_to_stderr)
263     def test_trailing_comma_optional_parens_stability3_pass2(self) -> None:
264         source, _expected = read_data("trailing_comma_optional_parens3")
265         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
266         black.assert_stable(source, actual, DEFAULT_MODE)
267
268     @patch("black.dump_to_file", dump_to_stderr)
269     def test_pep_572(self) -> None:
270         source, expected = read_data("pep_572")
271         actual = fs(source)
272         self.assertFormatEqual(expected, actual)
273         black.assert_stable(source, actual, DEFAULT_MODE)
274         if sys.version_info >= (3, 8):
275             black.assert_equivalent(source, actual)
276
277     @patch("black.dump_to_file", dump_to_stderr)
278     def test_pep_572_remove_parens(self) -> None:
279         source, expected = read_data("pep_572_remove_parens")
280         actual = fs(source)
281         self.assertFormatEqual(expected, actual)
282         black.assert_stable(source, actual, DEFAULT_MODE)
283         if sys.version_info >= (3, 8):
284             black.assert_equivalent(source, actual)
285
286     @patch("black.dump_to_file", dump_to_stderr)
287     def test_pep_572_do_not_remove_parens(self) -> None:
288         source, expected = read_data("pep_572_do_not_remove_parens")
289         # the AST safety checks will fail, but that's expected, just make sure no
290         # parentheses are touched
291         actual = black.format_str(source, mode=DEFAULT_MODE)
292         self.assertFormatEqual(expected, actual)
293
294     def test_pep_572_version_detection(self) -> None:
295         source, _ = read_data("pep_572")
296         root = black.lib2to3_parse(source)
297         features = black.get_features_used(root)
298         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
299         versions = black.detect_target_versions(root)
300         self.assertIn(black.TargetVersion.PY38, versions)
301
302     def test_expression_ff(self) -> None:
303         source, expected = read_data("expression")
304         tmp_file = Path(black.dump_to_file(source))
305         try:
306             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
307             with open(tmp_file, encoding="utf8") as f:
308                 actual = f.read()
309         finally:
310             os.unlink(tmp_file)
311         self.assertFormatEqual(expected, actual)
312         with patch("black.dump_to_file", dump_to_stderr):
313             black.assert_equivalent(source, actual)
314             black.assert_stable(source, actual, DEFAULT_MODE)
315
316     def test_expression_diff(self) -> None:
317         source, _ = read_data("expression.py")
318         config = THIS_DIR / "data" / "empty_pyproject.toml"
319         expected, _ = read_data("expression.diff")
320         tmp_file = Path(black.dump_to_file(source))
321         diff_header = re.compile(
322             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
323             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
324         )
325         try:
326             result = BlackRunner().invoke(
327                 black.main, ["--diff", str(tmp_file), f"--config={config}"]
328             )
329             self.assertEqual(result.exit_code, 0)
330         finally:
331             os.unlink(tmp_file)
332         actual = result.output
333         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
334         if expected != actual:
335             dump = black.dump_to_file(actual)
336             msg = (
337                 "Expected diff isn't equal to the actual. If you made changes to"
338                 " expression.py and this is an anticipated difference, overwrite"
339                 f" tests/data/expression.diff with {dump}"
340             )
341             self.assertEqual(expected, actual, msg)
342
343     def test_expression_diff_with_color(self) -> None:
344         source, _ = read_data("expression.py")
345         config = THIS_DIR / "data" / "empty_pyproject.toml"
346         expected, _ = read_data("expression.diff")
347         tmp_file = Path(black.dump_to_file(source))
348         try:
349             result = BlackRunner().invoke(
350                 black.main, ["--diff", "--color", str(tmp_file), f"--config={config}"]
351             )
352         finally:
353             os.unlink(tmp_file)
354         actual = result.output
355         # We check the contents of the diff in `test_expression_diff`. All
356         # we need to check here is that color codes exist in the result.
357         self.assertIn("\033[1;37m", actual)
358         self.assertIn("\033[36m", actual)
359         self.assertIn("\033[32m", actual)
360         self.assertIn("\033[31m", actual)
361         self.assertIn("\033[0m", actual)
362
363     @patch("black.dump_to_file", dump_to_stderr)
364     def test_pep_570(self) -> None:
365         source, expected = read_data("pep_570")
366         actual = fs(source)
367         self.assertFormatEqual(expected, actual)
368         black.assert_stable(source, actual, DEFAULT_MODE)
369         if sys.version_info >= (3, 8):
370             black.assert_equivalent(source, actual)
371
372     def test_detect_pos_only_arguments(self) -> None:
373         source, _ = read_data("pep_570")
374         root = black.lib2to3_parse(source)
375         features = black.get_features_used(root)
376         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
377         versions = black.detect_target_versions(root)
378         self.assertIn(black.TargetVersion.PY38, versions)
379
380     @patch("black.dump_to_file", dump_to_stderr)
381     def test_string_quotes(self) -> None:
382         source, expected = read_data("string_quotes")
383         mode = black.Mode(experimental_string_processing=True)
384         actual = fs(source, mode=mode)
385         self.assertFormatEqual(expected, actual)
386         black.assert_equivalent(source, actual)
387         black.assert_stable(source, actual, mode)
388         mode = replace(mode, string_normalization=False)
389         not_normalized = fs(source, mode=mode)
390         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
391         black.assert_equivalent(source, not_normalized)
392         black.assert_stable(source, not_normalized, mode=mode)
393
394     @patch("black.dump_to_file", dump_to_stderr)
395     def test_docstring_no_string_normalization(self) -> None:
396         """Like test_docstring but with string normalization off."""
397         source, expected = read_data("docstring_no_string_normalization")
398         mode = replace(DEFAULT_MODE, string_normalization=False)
399         actual = fs(source, mode=mode)
400         self.assertFormatEqual(expected, actual)
401         black.assert_equivalent(source, actual)
402         black.assert_stable(source, actual, mode)
403
404     def test_long_strings_flag_disabled(self) -> None:
405         """Tests for turning off the string processing logic."""
406         source, expected = read_data("long_strings_flag_disabled")
407         mode = replace(DEFAULT_MODE, experimental_string_processing=False)
408         actual = fs(source, mode=mode)
409         self.assertFormatEqual(expected, actual)
410         black.assert_stable(expected, actual, mode)
411
412     @patch("black.dump_to_file", dump_to_stderr)
413     def test_numeric_literals(self) -> None:
414         source, expected = read_data("numeric_literals")
415         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
416         actual = fs(source, mode=mode)
417         self.assertFormatEqual(expected, actual)
418         black.assert_equivalent(source, actual)
419         black.assert_stable(source, actual, mode)
420
421     @patch("black.dump_to_file", dump_to_stderr)
422     def test_numeric_literals_ignoring_underscores(self) -> None:
423         source, expected = read_data("numeric_literals_skip_underscores")
424         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
425         actual = fs(source, mode=mode)
426         self.assertFormatEqual(expected, actual)
427         black.assert_equivalent(source, actual)
428         black.assert_stable(source, actual, mode)
429
430     def test_skip_magic_trailing_comma(self) -> None:
431         source, _ = read_data("expression.py")
432         expected, _ = read_data("expression_skip_magic_trailing_comma.diff")
433         tmp_file = Path(black.dump_to_file(source))
434         diff_header = re.compile(
435             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
436             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
437         )
438         try:
439             result = BlackRunner().invoke(black.main, ["-C", "--diff", str(tmp_file)])
440             self.assertEqual(result.exit_code, 0)
441         finally:
442             os.unlink(tmp_file)
443         actual = result.output
444         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
445         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
446         if expected != actual:
447             dump = black.dump_to_file(actual)
448             msg = (
449                 "Expected diff isn't equal to the actual. If you made changes to"
450                 " expression.py and this is an anticipated difference, overwrite"
451                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
452             )
453             self.assertEqual(expected, actual, msg)
454
455     @pytest.mark.python2
456     @patch("black.dump_to_file", dump_to_stderr)
457     def test_python2_print_function(self) -> None:
458         source, expected = read_data("python2_print_function")
459         mode = replace(DEFAULT_MODE, target_versions={TargetVersion.PY27})
460         actual = fs(source, mode=mode)
461         self.assertFormatEqual(expected, actual)
462         black.assert_equivalent(source, actual)
463         black.assert_stable(source, actual, mode)
464
465     @patch("black.dump_to_file", dump_to_stderr)
466     def test_stub(self) -> None:
467         mode = replace(DEFAULT_MODE, is_pyi=True)
468         source, expected = read_data("stub.pyi")
469         actual = fs(source, mode=mode)
470         self.assertFormatEqual(expected, actual)
471         black.assert_stable(source, actual, mode)
472
473     @patch("black.dump_to_file", dump_to_stderr)
474     def test_async_as_identifier(self) -> None:
475         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
476         source, expected = read_data("async_as_identifier")
477         actual = fs(source)
478         self.assertFormatEqual(expected, actual)
479         major, minor = sys.version_info[:2]
480         if major < 3 or (major <= 3 and minor < 7):
481             black.assert_equivalent(source, actual)
482         black.assert_stable(source, actual, DEFAULT_MODE)
483         # ensure black can parse this when the target is 3.6
484         self.invokeBlack([str(source_path), "--target-version", "py36"])
485         # but not on 3.7, because async/await is no longer an identifier
486         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
487
488     @patch("black.dump_to_file", dump_to_stderr)
489     def test_python37(self) -> None:
490         source_path = (THIS_DIR / "data" / "python37.py").resolve()
491         source, expected = read_data("python37")
492         actual = fs(source)
493         self.assertFormatEqual(expected, actual)
494         major, minor = sys.version_info[:2]
495         if major > 3 or (major == 3 and minor >= 7):
496             black.assert_equivalent(source, actual)
497         black.assert_stable(source, actual, DEFAULT_MODE)
498         # ensure black can parse this when the target is 3.7
499         self.invokeBlack([str(source_path), "--target-version", "py37"])
500         # but not on 3.6, because we use async as a reserved keyword
501         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
502
503     @patch("black.dump_to_file", dump_to_stderr)
504     def test_python38(self) -> None:
505         source, expected = read_data("python38")
506         actual = fs(source)
507         self.assertFormatEqual(expected, actual)
508         major, minor = sys.version_info[:2]
509         if major > 3 or (major == 3 and minor >= 8):
510             black.assert_equivalent(source, actual)
511         black.assert_stable(source, actual, DEFAULT_MODE)
512
513     @patch("black.dump_to_file", dump_to_stderr)
514     def test_python39(self) -> None:
515         source, expected = read_data("python39")
516         actual = fs(source)
517         self.assertFormatEqual(expected, actual)
518         major, minor = sys.version_info[:2]
519         if major > 3 or (major == 3 and minor >= 9):
520             black.assert_equivalent(source, actual)
521         black.assert_stable(source, actual, DEFAULT_MODE)
522
523     def test_tab_comment_indentation(self) -> None:
524         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
525         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
526         self.assertFormatEqual(contents_spc, fs(contents_spc))
527         self.assertFormatEqual(contents_spc, fs(contents_tab))
528
529         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
530         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
531         self.assertFormatEqual(contents_spc, fs(contents_spc))
532         self.assertFormatEqual(contents_spc, fs(contents_tab))
533
534         # mixed tabs and spaces (valid Python 2 code)
535         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
536         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
537         self.assertFormatEqual(contents_spc, fs(contents_spc))
538         self.assertFormatEqual(contents_spc, fs(contents_tab))
539
540         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
541         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
542         self.assertFormatEqual(contents_spc, fs(contents_spc))
543         self.assertFormatEqual(contents_spc, fs(contents_tab))
544
    def test_report_verbose(self) -> None:
        """Check ``Report`` per-file messages and summaries in verbose mode.

        ``black.output._out`` and ``black.output._err`` are patched to collect
        messages. A sequence of done/failed/path_ignored events is then fed to
        the report, and the collected lines, ``str(report)`` and
        ``report.return_code`` are checked after each step.
        """
        report = Report(verbose=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            # Unchanged, reformatted and cached files are each reported via _out.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, a reformatted file makes the return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure is reported via _err and makes the return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Further outcomes keep updating the counts; the code stays 123.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are printed in verbose mode but not counted in the
            # summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff modes switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
646
    def test_report_quiet(self) -> None:
        """Check ``Report`` output and summaries in quiet mode.

        Same event sequence as the verbose test, but with ``quiet=True``.
        Nothing is written through ``_out``; only failures reach ``_err``.
        Summaries and return codes are unaffected by quiet mode.
        """
        report = Report(quiet=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            # Successful outcomes produce no per-file output in quiet mode.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, a reformatted file makes the return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures are still reported via _err and make the return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are neither printed nor counted in quiet mode.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff modes switch the summary to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
740
741     def test_report_normal(self) -> None:
742         report = black.Report()
743         out_lines = []
744         err_lines = []
745
746         def out(msg: str, **kwargs: Any) -> None:
747             out_lines.append(msg)
748
749         def err(msg: str, **kwargs: Any) -> None:
750             err_lines.append(msg)
751
752         with patch("black.output._out", out), patch("black.output._err", err):
753             report.done(Path("f1"), black.Changed.NO)
754             self.assertEqual(len(out_lines), 0)
755             self.assertEqual(len(err_lines), 0)
756             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
757             self.assertEqual(report.return_code, 0)
758             report.done(Path("f2"), black.Changed.YES)
759             self.assertEqual(len(out_lines), 1)
760             self.assertEqual(len(err_lines), 0)
761             self.assertEqual(out_lines[-1], "reformatted f2")
762             self.assertEqual(
763                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
764             )
765             report.done(Path("f3"), black.Changed.CACHED)
766             self.assertEqual(len(out_lines), 1)
767             self.assertEqual(len(err_lines), 0)
768             self.assertEqual(out_lines[-1], "reformatted f2")
769             self.assertEqual(
770                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
771             )
772             self.assertEqual(report.return_code, 0)
773             report.check = True
774             self.assertEqual(report.return_code, 1)
775             report.check = False
776             report.failed(Path("e1"), "boom")
777             self.assertEqual(len(out_lines), 1)
778             self.assertEqual(len(err_lines), 1)
779             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
780             self.assertEqual(
781                 unstyle(str(report)),
782                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
783                 " reformat.",
784             )
785             self.assertEqual(report.return_code, 123)
786             report.done(Path("f3"), black.Changed.YES)
787             self.assertEqual(len(out_lines), 2)
788             self.assertEqual(len(err_lines), 1)
789             self.assertEqual(out_lines[-1], "reformatted f3")
790             self.assertEqual(
791                 unstyle(str(report)),
792                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
793                 " reformat.",
794             )
795             self.assertEqual(report.return_code, 123)
796             report.failed(Path("e2"), "boom")
797             self.assertEqual(len(out_lines), 2)
798             self.assertEqual(len(err_lines), 2)
799             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
800             self.assertEqual(
801                 unstyle(str(report)),
802                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
803                 " reformat.",
804             )
805             self.assertEqual(report.return_code, 123)
806             report.path_ignored(Path("wat"), "no match")
807             self.assertEqual(len(out_lines), 2)
808             self.assertEqual(len(err_lines), 2)
809             self.assertEqual(
810                 unstyle(str(report)),
811                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
812                 " reformat.",
813             )
814             self.assertEqual(report.return_code, 123)
815             report.done(Path("f4"), black.Changed.NO)
816             self.assertEqual(len(out_lines), 2)
817             self.assertEqual(len(err_lines), 2)
818             self.assertEqual(
819                 unstyle(str(report)),
820                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
821                 " reformat.",
822             )
823             self.assertEqual(report.return_code, 123)
824             report.check = True
825             self.assertEqual(
826                 unstyle(str(report)),
827                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
828                 " would fail to reformat.",
829             )
830             report.check = False
831             report.diff = True
832             self.assertEqual(
833                 unstyle(str(report)),
834                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
835                 " would fail to reformat.",
836             )
837
838     def test_lib2to3_parse(self) -> None:
839         with self.assertRaises(black.InvalidInput):
840             black.lib2to3_parse("invalid syntax")
841
842         straddling = "x + y"
843         black.lib2to3_parse(straddling)
844         black.lib2to3_parse(straddling, {TargetVersion.PY27})
845         black.lib2to3_parse(straddling, {TargetVersion.PY36})
846         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
847
848         py2_only = "print x"
849         black.lib2to3_parse(py2_only)
850         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
851         with self.assertRaises(black.InvalidInput):
852             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
853         with self.assertRaises(black.InvalidInput):
854             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
855
856         py3_only = "exec(x, end=y)"
857         black.lib2to3_parse(py3_only)
858         with self.assertRaises(black.InvalidInput):
859             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
860         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
861         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
862
863     def test_get_features_used_decorator(self) -> None:
864         # Test the feature detection of new decorator syntax
865         # since this makes some test cases of test_get_features_used()
866         # fails if it fails, this is tested first so that a useful case
867         # is identified
868         simples, relaxed = read_data("decorators")
869         # skip explanation comments at the top of the file
870         for simple_test in simples.split("##")[1:]:
871             node = black.lib2to3_parse(simple_test)
872             decorator = str(node.children[0].children[0]).strip()
873             self.assertNotIn(
874                 Feature.RELAXED_DECORATORS,
875                 black.get_features_used(node),
876                 msg=(
877                     f"decorator '{decorator}' follows python<=3.8 syntax"
878                     "but is detected as 3.9+"
879                     # f"The full node is\n{node!r}"
880                 ),
881             )
882         # skip the '# output' comment at the top of the output part
883         for relaxed_test in relaxed.split("##")[1:]:
884             node = black.lib2to3_parse(relaxed_test)
885             decorator = str(node.children[0].children[0]).strip()
886             self.assertIn(
887                 Feature.RELAXED_DECORATORS,
888                 black.get_features_used(node),
889                 msg=(
890                     f"decorator '{decorator}' uses python3.9+ syntax"
891                     "but is detected as python<=3.8"
892                     # f"The full node is\n{node!r}"
893                 ),
894             )
895
896     def test_get_features_used(self) -> None:
897         node = black.lib2to3_parse("def f(*, arg): ...\n")
898         self.assertEqual(black.get_features_used(node), set())
899         node = black.lib2to3_parse("def f(*, arg,): ...\n")
900         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
901         node = black.lib2to3_parse("f(*arg,)\n")
902         self.assertEqual(
903             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
904         )
905         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
906         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
907         node = black.lib2to3_parse("123_456\n")
908         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
909         node = black.lib2to3_parse("123456\n")
910         self.assertEqual(black.get_features_used(node), set())
911         source, expected = read_data("function")
912         node = black.lib2to3_parse(source)
913         expected_features = {
914             Feature.TRAILING_COMMA_IN_CALL,
915             Feature.TRAILING_COMMA_IN_DEF,
916             Feature.F_STRINGS,
917         }
918         self.assertEqual(black.get_features_used(node), expected_features)
919         node = black.lib2to3_parse(expected)
920         self.assertEqual(black.get_features_used(node), expected_features)
921         source, expected = read_data("expression")
922         node = black.lib2to3_parse(source)
923         self.assertEqual(black.get_features_used(node), set())
924         node = black.lib2to3_parse(expected)
925         self.assertEqual(black.get_features_used(node), set())
926
927     def test_get_future_imports(self) -> None:
928         node = black.lib2to3_parse("\n")
929         self.assertEqual(set(), black.get_future_imports(node))
930         node = black.lib2to3_parse("from __future__ import black\n")
931         self.assertEqual({"black"}, black.get_future_imports(node))
932         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
933         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
934         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
935         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
936         node = black.lib2to3_parse(
937             "from __future__ import multiple\nfrom __future__ import imports\n"
938         )
939         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
940         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
941         self.assertEqual({"black"}, black.get_future_imports(node))
942         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
943         self.assertEqual({"black"}, black.get_future_imports(node))
944         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
945         self.assertEqual(set(), black.get_future_imports(node))
946         node = black.lib2to3_parse("from some.module import black\n")
947         self.assertEqual(set(), black.get_future_imports(node))
948         node = black.lib2to3_parse(
949             "from __future__ import unicode_literals as _unicode_literals"
950         )
951         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
952         node = black.lib2to3_parse(
953             "from __future__ import unicode_literals as _lol, print"
954         )
955         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
956
957     def test_debug_visitor(self) -> None:
958         source, _ = read_data("debug_visitor.py")
959         expected, _ = read_data("debug_visitor.out")
960         out_lines = []
961         err_lines = []
962
963         def out(msg: str, **kwargs: Any) -> None:
964             out_lines.append(msg)
965
966         def err(msg: str, **kwargs: Any) -> None:
967             err_lines.append(msg)
968
969         with patch("black.debug.out", out):
970             DebugVisitor.show(source)
971         actual = "\n".join(out_lines) + "\n"
972         log_name = ""
973         if expected != actual:
974             log_name = black.dump_to_file(*out_lines)
975         self.assertEqual(
976             expected,
977             actual,
978             f"AST print out is different. Actual version dumped to {log_name}",
979         )
980
981     def test_format_file_contents(self) -> None:
982         empty = ""
983         mode = DEFAULT_MODE
984         with self.assertRaises(black.NothingChanged):
985             black.format_file_contents(empty, mode=mode, fast=False)
986         just_nl = "\n"
987         with self.assertRaises(black.NothingChanged):
988             black.format_file_contents(just_nl, mode=mode, fast=False)
989         same = "j = [1, 2, 3]\n"
990         with self.assertRaises(black.NothingChanged):
991             black.format_file_contents(same, mode=mode, fast=False)
992         different = "j = [1,2,3]"
993         expected = same
994         actual = black.format_file_contents(different, mode=mode, fast=False)
995         self.assertEqual(expected, actual)
996         invalid = "return if you can"
997         with self.assertRaises(black.InvalidInput) as e:
998             black.format_file_contents(invalid, mode=mode, fast=False)
999         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1000
1001     def test_endmarker(self) -> None:
1002         n = black.lib2to3_parse("\n")
1003         self.assertEqual(n.type, black.syms.file_input)
1004         self.assertEqual(len(n.children), 1)
1005         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1006
1007     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1008     def test_assertFormatEqual(self) -> None:
1009         out_lines = []
1010         err_lines = []
1011
1012         def out(msg: str, **kwargs: Any) -> None:
1013             out_lines.append(msg)
1014
1015         def err(msg: str, **kwargs: Any) -> None:
1016             err_lines.append(msg)
1017
1018         with patch("black.output._out", out), patch("black.output._err", err):
1019             with self.assertRaises(AssertionError):
1020                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1021
1022         out_str = "".join(out_lines)
1023         self.assertTrue("Expected tree:" in out_str)
1024         self.assertTrue("Actual tree:" in out_str)
1025         self.assertEqual("".join(err_lines), "")
1026
1027     def test_cache_broken_file(self) -> None:
1028         mode = DEFAULT_MODE
1029         with cache_dir() as workspace:
1030             cache_file = get_cache_file(mode)
1031             with cache_file.open("w") as fobj:
1032                 fobj.write("this is not a pickle")
1033             self.assertEqual(black.read_cache(mode), {})
1034             src = (workspace / "test.py").resolve()
1035             with src.open("w") as fobj:
1036                 fobj.write("print('hello')")
1037             self.invokeBlack([str(src)])
1038             cache = black.read_cache(mode)
1039             self.assertIn(str(src), cache)
1040
1041     def test_cache_single_file_already_cached(self) -> None:
1042         mode = DEFAULT_MODE
1043         with cache_dir() as workspace:
1044             src = (workspace / "test.py").resolve()
1045             with src.open("w") as fobj:
1046                 fobj.write("print('hello')")
1047             black.write_cache({}, [src], mode)
1048             self.invokeBlack([str(src)])
1049             with src.open("r") as fobj:
1050                 self.assertEqual(fobj.read(), "print('hello')")
1051
1052     @event_loop()
1053     def test_cache_multiple_files(self) -> None:
1054         mode = DEFAULT_MODE
1055         with cache_dir() as workspace, patch(
1056             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1057         ):
1058             one = (workspace / "one.py").resolve()
1059             with one.open("w") as fobj:
1060                 fobj.write("print('hello')")
1061             two = (workspace / "two.py").resolve()
1062             with two.open("w") as fobj:
1063                 fobj.write("print('hello')")
1064             black.write_cache({}, [one], mode)
1065             self.invokeBlack([str(workspace)])
1066             with one.open("r") as fobj:
1067                 self.assertEqual(fobj.read(), "print('hello')")
1068             with two.open("r") as fobj:
1069                 self.assertEqual(fobj.read(), 'print("hello")\n')
1070             cache = black.read_cache(mode)
1071             self.assertIn(str(one), cache)
1072             self.assertIn(str(two), cache)
1073
1074     def test_no_cache_when_writeback_diff(self) -> None:
1075         mode = DEFAULT_MODE
1076         with cache_dir() as workspace:
1077             src = (workspace / "test.py").resolve()
1078             with src.open("w") as fobj:
1079                 fobj.write("print('hello')")
1080             with patch("black.read_cache") as read_cache, patch(
1081                 "black.write_cache"
1082             ) as write_cache:
1083                 self.invokeBlack([str(src), "--diff"])
1084                 cache_file = get_cache_file(mode)
1085                 self.assertFalse(cache_file.exists())
1086                 write_cache.assert_not_called()
1087                 read_cache.assert_not_called()
1088
1089     def test_no_cache_when_writeback_color_diff(self) -> None:
1090         mode = DEFAULT_MODE
1091         with cache_dir() as workspace:
1092             src = (workspace / "test.py").resolve()
1093             with src.open("w") as fobj:
1094                 fobj.write("print('hello')")
1095             with patch("black.read_cache") as read_cache, patch(
1096                 "black.write_cache"
1097             ) as write_cache:
1098                 self.invokeBlack([str(src), "--diff", "--color"])
1099                 cache_file = get_cache_file(mode)
1100                 self.assertFalse(cache_file.exists())
1101                 write_cache.assert_not_called()
1102                 read_cache.assert_not_called()
1103
1104     @event_loop()
1105     def test_output_locking_when_writeback_diff(self) -> None:
1106         with cache_dir() as workspace:
1107             for tag in range(0, 4):
1108                 src = (workspace / f"test{tag}.py").resolve()
1109                 with src.open("w") as fobj:
1110                     fobj.write("print('hello')")
1111             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1112                 self.invokeBlack(["--diff", str(workspace)], exit_code=0)
1113                 # this isn't quite doing what we want, but if it _isn't_
1114                 # called then we cannot be using the lock it provides
1115                 mgr.assert_called()
1116
1117     @event_loop()
1118     def test_output_locking_when_writeback_color_diff(self) -> None:
1119         with cache_dir() as workspace:
1120             for tag in range(0, 4):
1121                 src = (workspace / f"test{tag}.py").resolve()
1122                 with src.open("w") as fobj:
1123                     fobj.write("print('hello')")
1124             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1125                 self.invokeBlack(["--diff", "--color", str(workspace)], exit_code=0)
1126                 # this isn't quite doing what we want, but if it _isn't_
1127                 # called then we cannot be using the lock it provides
1128                 mgr.assert_called()
1129
1130     def test_no_cache_when_stdin(self) -> None:
1131         mode = DEFAULT_MODE
1132         with cache_dir():
1133             result = CliRunner().invoke(
1134                 black.main, ["-"], input=BytesIO(b"print('hello')")
1135             )
1136             self.assertEqual(result.exit_code, 0)
1137             cache_file = get_cache_file(mode)
1138             self.assertFalse(cache_file.exists())
1139
1140     def test_read_cache_no_cachefile(self) -> None:
1141         mode = DEFAULT_MODE
1142         with cache_dir():
1143             self.assertEqual(black.read_cache(mode), {})
1144
1145     def test_write_cache_read_cache(self) -> None:
1146         mode = DEFAULT_MODE
1147         with cache_dir() as workspace:
1148             src = (workspace / "test.py").resolve()
1149             src.touch()
1150             black.write_cache({}, [src], mode)
1151             cache = black.read_cache(mode)
1152             self.assertIn(str(src), cache)
1153             self.assertEqual(cache[str(src)], black.get_cache_info(src))
1154
1155     def test_filter_cached(self) -> None:
1156         with TemporaryDirectory() as workspace:
1157             path = Path(workspace)
1158             uncached = (path / "uncached").resolve()
1159             cached = (path / "cached").resolve()
1160             cached_but_changed = (path / "changed").resolve()
1161             uncached.touch()
1162             cached.touch()
1163             cached_but_changed.touch()
1164             cache = {
1165                 str(cached): black.get_cache_info(cached),
1166                 str(cached_but_changed): (0.0, 0),
1167             }
1168             todo, done = black.filter_cached(
1169                 cache, {uncached, cached, cached_but_changed}
1170             )
1171             self.assertEqual(todo, {uncached, cached_but_changed})
1172             self.assertEqual(done, {cached})
1173
1174     def test_write_cache_creates_directory_if_needed(self) -> None:
1175         mode = DEFAULT_MODE
1176         with cache_dir(exists=False) as workspace:
1177             self.assertFalse(workspace.exists())
1178             black.write_cache({}, [], mode)
1179             self.assertTrue(workspace.exists())
1180
1181     @event_loop()
1182     def test_failed_formatting_does_not_get_cached(self) -> None:
1183         mode = DEFAULT_MODE
1184         with cache_dir() as workspace, patch(
1185             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1186         ):
1187             failing = (workspace / "failing.py").resolve()
1188             with failing.open("w") as fobj:
1189                 fobj.write("not actually python")
1190             clean = (workspace / "clean.py").resolve()
1191             with clean.open("w") as fobj:
1192                 fobj.write('print("hello")\n')
1193             self.invokeBlack([str(workspace)], exit_code=123)
1194             cache = black.read_cache(mode)
1195             self.assertNotIn(str(failing), cache)
1196             self.assertIn(str(clean), cache)
1197
1198     def test_write_cache_write_fail(self) -> None:
1199         mode = DEFAULT_MODE
1200         with cache_dir(), patch.object(Path, "open") as mock:
1201             mock.side_effect = OSError
1202             black.write_cache({}, [], mode)
1203
1204     @event_loop()
1205     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
1206     def test_works_in_mono_process_only_environment(self) -> None:
1207         with cache_dir() as workspace:
1208             for f in [
1209                 (workspace / "one.py").resolve(),
1210                 (workspace / "two.py").resolve(),
1211             ]:
1212                 f.write_text('print("hello")\n')
1213             self.invokeBlack([str(workspace)])
1214
1215     @event_loop()
1216     def test_check_diff_use_together(self) -> None:
1217         with cache_dir():
1218             # Files which will be reformatted.
1219             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1220             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1221             # Files which will not be reformatted.
1222             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1223             self.invokeBlack([str(src2), "--diff", "--check"])
1224             # Multi file command.
1225             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1226
1227     def test_no_files(self) -> None:
1228         with cache_dir():
1229             # Without an argument, black exits with error code 0.
1230             self.invokeBlack([])
1231
1232     def test_broken_symlink(self) -> None:
1233         with cache_dir() as workspace:
1234             symlink = workspace / "broken_link.py"
1235             try:
1236                 symlink.symlink_to("nonexistent.py")
1237             except OSError as e:
1238                 self.skipTest(f"Can't create symlinks: {e}")
1239             self.invokeBlack([str(workspace.resolve())])
1240
1241     def test_read_cache_line_lengths(self) -> None:
1242         mode = DEFAULT_MODE
1243         short_mode = replace(DEFAULT_MODE, line_length=1)
1244         with cache_dir() as workspace:
1245             path = (workspace / "file.py").resolve()
1246             path.touch()
1247             black.write_cache({}, [path], mode)
1248             one = black.read_cache(mode)
1249             self.assertIn(str(path), one)
1250             two = black.read_cache(short_mode)
1251             self.assertNotIn(str(path), two)
1252
1253     def test_single_file_force_pyi(self) -> None:
1254         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1255         contents, expected = read_data("force_pyi")
1256         with cache_dir() as workspace:
1257             path = (workspace / "file.py").resolve()
1258             with open(path, "w") as fh:
1259                 fh.write(contents)
1260             self.invokeBlack([str(path), "--pyi"])
1261             with open(path, "r") as fh:
1262                 actual = fh.read()
1263             # verify cache with --pyi is separate
1264             pyi_cache = black.read_cache(pyi_mode)
1265             self.assertIn(str(path), pyi_cache)
1266             normal_cache = black.read_cache(DEFAULT_MODE)
1267             self.assertNotIn(str(path), normal_cache)
1268         self.assertFormatEqual(expected, actual)
1269         black.assert_equivalent(contents, actual)
1270         black.assert_stable(contents, actual, pyi_mode)
1271
1272     @event_loop()
1273     def test_multi_file_force_pyi(self) -> None:
1274         reg_mode = DEFAULT_MODE
1275         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1276         contents, expected = read_data("force_pyi")
1277         with cache_dir() as workspace:
1278             paths = [
1279                 (workspace / "file1.py").resolve(),
1280                 (workspace / "file2.py").resolve(),
1281             ]
1282             for path in paths:
1283                 with open(path, "w") as fh:
1284                     fh.write(contents)
1285             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1286             for path in paths:
1287                 with open(path, "r") as fh:
1288                     actual = fh.read()
1289                 self.assertEqual(actual, expected)
1290             # verify cache with --pyi is separate
1291             pyi_cache = black.read_cache(pyi_mode)
1292             normal_cache = black.read_cache(reg_mode)
1293             for path in paths:
1294                 self.assertIn(str(path), pyi_cache)
1295                 self.assertNotIn(str(path), normal_cache)
1296
1297     def test_pipe_force_pyi(self) -> None:
1298         source, expected = read_data("force_pyi")
1299         result = CliRunner().invoke(
1300             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1301         )
1302         self.assertEqual(result.exit_code, 0)
1303         actual = result.output
1304         self.assertFormatEqual(actual, expected)
1305
1306     def test_single_file_force_py36(self) -> None:
1307         reg_mode = DEFAULT_MODE
1308         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1309         source, expected = read_data("force_py36")
1310         with cache_dir() as workspace:
1311             path = (workspace / "file.py").resolve()
1312             with open(path, "w") as fh:
1313                 fh.write(source)
1314             self.invokeBlack([str(path), *PY36_ARGS])
1315             with open(path, "r") as fh:
1316                 actual = fh.read()
1317             # verify cache with --target-version is separate
1318             py36_cache = black.read_cache(py36_mode)
1319             self.assertIn(str(path), py36_cache)
1320             normal_cache = black.read_cache(reg_mode)
1321             self.assertNotIn(str(path), normal_cache)
1322         self.assertEqual(actual, expected)
1323
1324     @event_loop()
1325     def test_multi_file_force_py36(self) -> None:
1326         reg_mode = DEFAULT_MODE
1327         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1328         source, expected = read_data("force_py36")
1329         with cache_dir() as workspace:
1330             paths = [
1331                 (workspace / "file1.py").resolve(),
1332                 (workspace / "file2.py").resolve(),
1333             ]
1334             for path in paths:
1335                 with open(path, "w") as fh:
1336                     fh.write(source)
1337             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1338             for path in paths:
1339                 with open(path, "r") as fh:
1340                     actual = fh.read()
1341                 self.assertEqual(actual, expected)
1342             # verify cache with --target-version is separate
1343             pyi_cache = black.read_cache(py36_mode)
1344             normal_cache = black.read_cache(reg_mode)
1345             for path in paths:
1346                 self.assertIn(str(path), pyi_cache)
1347                 self.assertNotIn(str(path), normal_cache)
1348
1349     def test_pipe_force_py36(self) -> None:
1350         source, expected = read_data("force_py36")
1351         result = CliRunner().invoke(
1352             black.main,
1353             ["-", "-q", "--target-version=py36"],
1354             input=BytesIO(source.encode("utf8")),
1355         )
1356         self.assertEqual(result.exit_code, 0)
1357         actual = result.output
1358         self.assertFormatEqual(actual, expected)
1359
1360     def test_include_exclude(self) -> None:
1361         path = THIS_DIR / "data" / "include_exclude_tests"
1362         include = re.compile(r"\.pyi?$")
1363         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1364         report = black.Report()
1365         gitignore = PathSpec.from_lines("gitwildmatch", [])
1366         sources: List[Path] = []
1367         expected = [
1368             Path(path / "b/dont_exclude/a.py"),
1369             Path(path / "b/dont_exclude/a.pyi"),
1370         ]
1371         this_abs = THIS_DIR.resolve()
1372         sources.extend(
1373             black.gen_python_files(
1374                 path.iterdir(),
1375                 this_abs,
1376                 include,
1377                 exclude,
1378                 None,
1379                 None,
1380                 report,
1381                 gitignore,
1382                 verbose=False,
1383                 quiet=False,
1384             )
1385         )
1386         self.assertEqual(sorted(expected), sorted(sources))
1387
1388     def test_gitignore_used_as_default(self) -> None:
1389         path = Path(THIS_DIR / "data" / "include_exclude_tests")
1390         include = re.compile(r"\.pyi?$")
1391         extend_exclude = re.compile(r"/exclude/")
1392         src = str(path / "b/")
1393         report = black.Report()
1394         expected: List[Path] = [
1395             path / "b/.definitely_exclude/a.py",
1396             path / "b/.definitely_exclude/a.pyi",
1397         ]
1398         sources = list(
1399             black.get_sources(
1400                 ctx=FakeContext(),
1401                 src=(src,),
1402                 quiet=True,
1403                 verbose=False,
1404                 include=include,
1405                 exclude=None,
1406                 extend_exclude=extend_exclude,
1407                 force_exclude=None,
1408                 report=report,
1409                 stdin_filename=None,
1410             )
1411         )
1412         self.assertEqual(sorted(expected), sorted(sources))
1413
1414     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1415     def test_exclude_for_issue_1572(self) -> None:
1416         # Exclude shouldn't touch files that were explicitly given to Black through the
1417         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
1418         # https://github.com/psf/black/issues/1572
1419         path = THIS_DIR / "data" / "include_exclude_tests"
1420         include = ""
1421         exclude = r"/exclude/|a\.py"
1422         src = str(path / "b/exclude/a.py")
1423         report = black.Report()
1424         expected = [Path(path / "b/exclude/a.py")]
1425         sources = list(
1426             black.get_sources(
1427                 ctx=FakeContext(),
1428                 src=(src,),
1429                 quiet=True,
1430                 verbose=False,
1431                 include=re.compile(include),
1432                 exclude=re.compile(exclude),
1433                 extend_exclude=None,
1434                 force_exclude=None,
1435                 report=report,
1436                 stdin_filename=None,
1437             )
1438         )
1439         self.assertEqual(sorted(expected), sorted(sources))
1440
1441     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1442     def test_get_sources_with_stdin(self) -> None:
1443         include = ""
1444         exclude = r"/exclude/|a\.py"
1445         src = "-"
1446         report = black.Report()
1447         expected = [Path("-")]
1448         sources = list(
1449             black.get_sources(
1450                 ctx=FakeContext(),
1451                 src=(src,),
1452                 quiet=True,
1453                 verbose=False,
1454                 include=re.compile(include),
1455                 exclude=re.compile(exclude),
1456                 extend_exclude=None,
1457                 force_exclude=None,
1458                 report=report,
1459                 stdin_filename=None,
1460             )
1461         )
1462         self.assertEqual(sorted(expected), sorted(sources))
1463
1464     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1465     def test_get_sources_with_stdin_filename(self) -> None:
1466         include = ""
1467         exclude = r"/exclude/|a\.py"
1468         src = "-"
1469         report = black.Report()
1470         stdin_filename = str(THIS_DIR / "data/collections.py")
1471         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1472         sources = list(
1473             black.get_sources(
1474                 ctx=FakeContext(),
1475                 src=(src,),
1476                 quiet=True,
1477                 verbose=False,
1478                 include=re.compile(include),
1479                 exclude=re.compile(exclude),
1480                 extend_exclude=None,
1481                 force_exclude=None,
1482                 report=report,
1483                 stdin_filename=stdin_filename,
1484             )
1485         )
1486         self.assertEqual(sorted(expected), sorted(sources))
1487
1488     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1489     def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
1490         # Exclude shouldn't exclude stdin_filename since it is mimicking the
1491         # file being passed directly. This is the same as
1492         # test_exclude_for_issue_1572
1493         path = THIS_DIR / "data" / "include_exclude_tests"
1494         include = ""
1495         exclude = r"/exclude/|a\.py"
1496         src = "-"
1497         report = black.Report()
1498         stdin_filename = str(path / "b/exclude/a.py")
1499         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1500         sources = list(
1501             black.get_sources(
1502                 ctx=FakeContext(),
1503                 src=(src,),
1504                 quiet=True,
1505                 verbose=False,
1506                 include=re.compile(include),
1507                 exclude=re.compile(exclude),
1508                 extend_exclude=None,
1509                 force_exclude=None,
1510                 report=report,
1511                 stdin_filename=stdin_filename,
1512             )
1513         )
1514         self.assertEqual(sorted(expected), sorted(sources))
1515
1516     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1517     def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
1518         # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
1519         # file being passed directly. This is the same as
1520         # test_exclude_for_issue_1572
1521         path = THIS_DIR / "data" / "include_exclude_tests"
1522         include = ""
1523         extend_exclude = r"/exclude/|a\.py"
1524         src = "-"
1525         report = black.Report()
1526         stdin_filename = str(path / "b/exclude/a.py")
1527         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1528         sources = list(
1529             black.get_sources(
1530                 ctx=FakeContext(),
1531                 src=(src,),
1532                 quiet=True,
1533                 verbose=False,
1534                 include=re.compile(include),
1535                 exclude=re.compile(""),
1536                 extend_exclude=re.compile(extend_exclude),
1537                 force_exclude=None,
1538                 report=report,
1539                 stdin_filename=stdin_filename,
1540             )
1541         )
1542         self.assertEqual(sorted(expected), sorted(sources))
1543
1544     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1545     def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
1546         # Force exclude should exclude the file when passing it through
1547         # stdin_filename
1548         path = THIS_DIR / "data" / "include_exclude_tests"
1549         include = ""
1550         force_exclude = r"/exclude/|a\.py"
1551         src = "-"
1552         report = black.Report()
1553         stdin_filename = str(path / "b/exclude/a.py")
1554         sources = list(
1555             black.get_sources(
1556                 ctx=FakeContext(),
1557                 src=(src,),
1558                 quiet=True,
1559                 verbose=False,
1560                 include=re.compile(include),
1561                 exclude=re.compile(""),
1562                 extend_exclude=None,
1563                 force_exclude=re.compile(force_exclude),
1564                 report=report,
1565                 stdin_filename=stdin_filename,
1566             )
1567         )
1568         self.assertEqual([], sorted(sources))
1569
1570     def test_reformat_one_with_stdin(self) -> None:
1571         with patch(
1572             "black.format_stdin_to_stdout",
1573             return_value=lambda *args, **kwargs: black.Changed.YES,
1574         ) as fsts:
1575             report = MagicMock()
1576             path = Path("-")
1577             black.reformat_one(
1578                 path,
1579                 fast=True,
1580                 write_back=black.WriteBack.YES,
1581                 mode=DEFAULT_MODE,
1582                 report=report,
1583             )
1584             fsts.assert_called_once()
1585             report.done.assert_called_with(path, black.Changed.YES)
1586
1587     def test_reformat_one_with_stdin_filename(self) -> None:
1588         with patch(
1589             "black.format_stdin_to_stdout",
1590             return_value=lambda *args, **kwargs: black.Changed.YES,
1591         ) as fsts:
1592             report = MagicMock()
1593             p = "foo.py"
1594             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1595             expected = Path(p)
1596             black.reformat_one(
1597                 path,
1598                 fast=True,
1599                 write_back=black.WriteBack.YES,
1600                 mode=DEFAULT_MODE,
1601                 report=report,
1602             )
1603             fsts.assert_called_once_with(
1604                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1605             )
1606             # __BLACK_STDIN_FILENAME__ should have been stripped
1607             report.done.assert_called_with(expected, black.Changed.YES)
1608
1609     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1610         with patch(
1611             "black.format_stdin_to_stdout",
1612             return_value=lambda *args, **kwargs: black.Changed.YES,
1613         ) as fsts:
1614             report = MagicMock()
1615             p = "foo.pyi"
1616             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1617             expected = Path(p)
1618             black.reformat_one(
1619                 path,
1620                 fast=True,
1621                 write_back=black.WriteBack.YES,
1622                 mode=DEFAULT_MODE,
1623                 report=report,
1624             )
1625             fsts.assert_called_once_with(
1626                 fast=True,
1627                 write_back=black.WriteBack.YES,
1628                 mode=replace(DEFAULT_MODE, is_pyi=True),
1629             )
1630             # __BLACK_STDIN_FILENAME__ should have been stripped
1631             report.done.assert_called_with(expected, black.Changed.YES)
1632
1633     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1634         with patch(
1635             "black.format_stdin_to_stdout",
1636             return_value=lambda *args, **kwargs: black.Changed.YES,
1637         ) as fsts:
1638             report = MagicMock()
1639             # Even with an existing file, since we are forcing stdin, black
1640             # should output to stdout and not modify the file inplace
1641             p = Path(str(THIS_DIR / "data/collections.py"))
1642             # Make sure is_file actually returns True
1643             self.assertTrue(p.is_file())
1644             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1645             expected = Path(p)
1646             black.reformat_one(
1647                 path,
1648                 fast=True,
1649                 write_back=black.WriteBack.YES,
1650                 mode=DEFAULT_MODE,
1651                 report=report,
1652             )
1653             fsts.assert_called_once()
1654             # __BLACK_STDIN_FILENAME__ should have been stripped
1655             report.done.assert_called_with(expected, black.Changed.YES)
1656
1657     def test_reformat_one_with_stdin_empty(self) -> None:
1658         output = io.StringIO()
1659         with patch("io.TextIOWrapper", lambda *args, **kwargs: output):
1660             try:
1661                 black.format_stdin_to_stdout(
1662                     fast=True,
1663                     content="",
1664                     write_back=black.WriteBack.YES,
1665                     mode=DEFAULT_MODE,
1666                 )
1667             except io.UnsupportedOperation:
1668                 pass  # StringIO does not support detach
1669             assert output.getvalue() == ""
1670
1671     def test_gitignore_exclude(self) -> None:
1672         path = THIS_DIR / "data" / "include_exclude_tests"
1673         include = re.compile(r"\.pyi?$")
1674         exclude = re.compile(r"")
1675         report = black.Report()
1676         gitignore = PathSpec.from_lines(
1677             "gitwildmatch", ["exclude/", ".definitely_exclude"]
1678         )
1679         sources: List[Path] = []
1680         expected = [
1681             Path(path / "b/dont_exclude/a.py"),
1682             Path(path / "b/dont_exclude/a.pyi"),
1683         ]
1684         this_abs = THIS_DIR.resolve()
1685         sources.extend(
1686             black.gen_python_files(
1687                 path.iterdir(),
1688                 this_abs,
1689                 include,
1690                 exclude,
1691                 None,
1692                 None,
1693                 report,
1694                 gitignore,
1695                 verbose=False,
1696                 quiet=False,
1697             )
1698         )
1699         self.assertEqual(sorted(expected), sorted(sources))
1700
1701     def test_nested_gitignore(self) -> None:
1702         path = Path(THIS_DIR / "data" / "nested_gitignore_tests")
1703         include = re.compile(r"\.pyi?$")
1704         exclude = re.compile(r"")
1705         root_gitignore = black.files.get_gitignore(path)
1706         report = black.Report()
1707         expected: List[Path] = [
1708             Path(path / "x.py"),
1709             Path(path / "root/b.py"),
1710             Path(path / "root/c.py"),
1711             Path(path / "root/child/c.py"),
1712         ]
1713         this_abs = THIS_DIR.resolve()
1714         sources = list(
1715             black.gen_python_files(
1716                 path.iterdir(),
1717                 this_abs,
1718                 include,
1719                 exclude,
1720                 None,
1721                 None,
1722                 report,
1723                 root_gitignore,
1724                 verbose=False,
1725                 quiet=False,
1726             )
1727         )
1728         self.assertEqual(sorted(expected), sorted(sources))
1729
1730     def test_empty_include(self) -> None:
1731         path = THIS_DIR / "data" / "include_exclude_tests"
1732         report = black.Report()
1733         gitignore = PathSpec.from_lines("gitwildmatch", [])
1734         empty = re.compile(r"")
1735         sources: List[Path] = []
1736         expected = [
1737             Path(path / "b/exclude/a.pie"),
1738             Path(path / "b/exclude/a.py"),
1739             Path(path / "b/exclude/a.pyi"),
1740             Path(path / "b/dont_exclude/a.pie"),
1741             Path(path / "b/dont_exclude/a.py"),
1742             Path(path / "b/dont_exclude/a.pyi"),
1743             Path(path / "b/.definitely_exclude/a.pie"),
1744             Path(path / "b/.definitely_exclude/a.py"),
1745             Path(path / "b/.definitely_exclude/a.pyi"),
1746             Path(path / ".gitignore"),
1747             Path(path / "pyproject.toml"),
1748         ]
1749         this_abs = THIS_DIR.resolve()
1750         sources.extend(
1751             black.gen_python_files(
1752                 path.iterdir(),
1753                 this_abs,
1754                 empty,
1755                 re.compile(black.DEFAULT_EXCLUDES),
1756                 None,
1757                 None,
1758                 report,
1759                 gitignore,
1760                 verbose=False,
1761                 quiet=False,
1762             )
1763         )
1764         self.assertEqual(sorted(expected), sorted(sources))
1765
1766     def test_extend_exclude(self) -> None:
1767         path = THIS_DIR / "data" / "include_exclude_tests"
1768         report = black.Report()
1769         gitignore = PathSpec.from_lines("gitwildmatch", [])
1770         sources: List[Path] = []
1771         expected = [
1772             Path(path / "b/exclude/a.py"),
1773             Path(path / "b/dont_exclude/a.py"),
1774         ]
1775         this_abs = THIS_DIR.resolve()
1776         sources.extend(
1777             black.gen_python_files(
1778                 path.iterdir(),
1779                 this_abs,
1780                 re.compile(black.DEFAULT_INCLUDES),
1781                 re.compile(r"\.pyi$"),
1782                 re.compile(r"\.definitely_exclude"),
1783                 None,
1784                 report,
1785                 gitignore,
1786                 verbose=False,
1787                 quiet=False,
1788             )
1789         )
1790         self.assertEqual(sorted(expected), sorted(sources))
1791
1792     def test_invalid_cli_regex(self) -> None:
1793         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1794             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1795
1796     def test_required_version_matches_version(self) -> None:
1797         self.invokeBlack(
1798             ["--required-version", black.__version__], exit_code=0, ignore_config=True
1799         )
1800
1801     def test_required_version_does_not_match_version(self) -> None:
1802         self.invokeBlack(
1803             ["--required-version", "20.99b"], exit_code=1, ignore_config=True
1804         )
1805
1806     def test_preserves_line_endings(self) -> None:
1807         with TemporaryDirectory() as workspace:
1808             test_file = Path(workspace) / "test.py"
1809             for nl in ["\n", "\r\n"]:
1810                 contents = nl.join(["def f(  ):", "    pass"])
1811                 test_file.write_bytes(contents.encode())
1812                 ff(test_file, write_back=black.WriteBack.YES)
1813                 updated_contents: bytes = test_file.read_bytes()
1814                 self.assertIn(nl.encode(), updated_contents)
1815                 if nl == "\n":
1816                     self.assertNotIn(b"\r\n", updated_contents)
1817
1818     def test_preserves_line_endings_via_stdin(self) -> None:
1819         for nl in ["\n", "\r\n"]:
1820             contents = nl.join(["def f(  ):", "    pass"])
1821             runner = BlackRunner()
1822             result = runner.invoke(
1823                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1824             )
1825             self.assertEqual(result.exit_code, 0)
1826             output = result.stdout_bytes
1827             self.assertIn(nl.encode("utf8"), output)
1828             if nl == "\n":
1829                 self.assertNotIn(b"\r\n", output)
1830
1831     def test_assert_equivalent_different_asts(self) -> None:
1832         with self.assertRaises(AssertionError):
1833             black.assert_equivalent("{}", "None")
1834
1835     def test_symlink_out_of_root_directory(self) -> None:
1836         path = MagicMock()
1837         root = THIS_DIR.resolve()
1838         child = MagicMock()
1839         include = re.compile(black.DEFAULT_INCLUDES)
1840         exclude = re.compile(black.DEFAULT_EXCLUDES)
1841         report = black.Report()
1842         gitignore = PathSpec.from_lines("gitwildmatch", [])
1843         # `child` should behave like a symlink which resolved path is clearly
1844         # outside of the `root` directory.
1845         path.iterdir.return_value = [child]
1846         child.resolve.return_value = Path("/a/b/c")
1847         child.as_posix.return_value = "/a/b/c"
1848         child.is_symlink.return_value = True
1849         try:
1850             list(
1851                 black.gen_python_files(
1852                     path.iterdir(),
1853                     root,
1854                     include,
1855                     exclude,
1856                     None,
1857                     None,
1858                     report,
1859                     gitignore,
1860                     verbose=False,
1861                     quiet=False,
1862                 )
1863             )
1864         except ValueError as ve:
1865             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1866         path.iterdir.assert_called_once()
1867         child.resolve.assert_called_once()
1868         child.is_symlink.assert_called_once()
1869         # `child` should behave like a strange file which resolved path is clearly
1870         # outside of the `root` directory.
1871         child.is_symlink.return_value = False
1872         with self.assertRaises(ValueError):
1873             list(
1874                 black.gen_python_files(
1875                     path.iterdir(),
1876                     root,
1877                     include,
1878                     exclude,
1879                     None,
1880                     None,
1881                     report,
1882                     gitignore,
1883                     verbose=False,
1884                     quiet=False,
1885                 )
1886             )
1887         path.iterdir.assert_called()
1888         self.assertEqual(path.iterdir.call_count, 2)
1889         child.resolve.assert_called()
1890         self.assertEqual(child.resolve.call_count, 2)
1891         child.is_symlink.assert_called()
1892         self.assertEqual(child.is_symlink.call_count, 2)
1893
1894     def test_shhh_click(self) -> None:
1895         try:
1896             from click import _unicodefun
1897         except ModuleNotFoundError:
1898             self.skipTest("Incompatible Click version")
1899         if not hasattr(_unicodefun, "_verify_python3_env"):
1900             self.skipTest("Incompatible Click version")
1901         # First, let's see if Click is crashing with a preferred ASCII charset.
1902         with patch("locale.getpreferredencoding") as gpe:
1903             gpe.return_value = "ASCII"
1904             with self.assertRaises(RuntimeError):
1905                 _unicodefun._verify_python3_env()  # type: ignore
1906         # Now, let's silence Click...
1907         black.patch_click()
1908         # ...and confirm it's silent.
1909         with patch("locale.getpreferredencoding") as gpe:
1910             gpe.return_value = "ASCII"
1911             try:
1912                 _unicodefun._verify_python3_env()  # type: ignore
1913             except RuntimeError as re:
1914                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1915
1916     def test_root_logger_not_used_directly(self) -> None:
1917         def fail(*args: Any, **kwargs: Any) -> None:
1918             self.fail("Record created with root logger")
1919
1920         with patch.multiple(
1921             logging.root,
1922             debug=fail,
1923             info=fail,
1924             warning=fail,
1925             error=fail,
1926             critical=fail,
1927             log=fail,
1928         ):
1929             ff(THIS_DIR / "util.py")
1930
1931     def test_invalid_config_return_code(self) -> None:
1932         tmp_file = Path(black.dump_to_file())
1933         try:
1934             tmp_config = Path(black.dump_to_file())
1935             tmp_config.unlink()
1936             args = ["--config", str(tmp_config), str(tmp_file)]
1937             self.invokeBlack(args, exit_code=2, ignore_config=False)
1938         finally:
1939             tmp_file.unlink()
1940
1941     def test_parse_pyproject_toml(self) -> None:
1942         test_toml_file = THIS_DIR / "test.toml"
1943         config = black.parse_pyproject_toml(str(test_toml_file))
1944         self.assertEqual(config["verbose"], 1)
1945         self.assertEqual(config["check"], "no")
1946         self.assertEqual(config["diff"], "y")
1947         self.assertEqual(config["color"], True)
1948         self.assertEqual(config["line_length"], 79)
1949         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1950         self.assertEqual(config["exclude"], r"\.pyi?$")
1951         self.assertEqual(config["include"], r"\.py?$")
1952
1953     def test_read_pyproject_toml(self) -> None:
1954         test_toml_file = THIS_DIR / "test.toml"
1955         fake_ctx = FakeContext()
1956         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1957         config = fake_ctx.default_map
1958         self.assertEqual(config["verbose"], "1")
1959         self.assertEqual(config["check"], "no")
1960         self.assertEqual(config["diff"], "y")
1961         self.assertEqual(config["color"], "True")
1962         self.assertEqual(config["line_length"], "79")
1963         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1964         self.assertEqual(config["exclude"], r"\.pyi?$")
1965         self.assertEqual(config["include"], r"\.py?$")
1966
1967     def test_find_project_root(self) -> None:
1968         with TemporaryDirectory() as workspace:
1969             root = Path(workspace)
1970             test_dir = root / "test"
1971             test_dir.mkdir()
1972
1973             src_dir = root / "src"
1974             src_dir.mkdir()
1975
1976             root_pyproject = root / "pyproject.toml"
1977             root_pyproject.touch()
1978             src_pyproject = src_dir / "pyproject.toml"
1979             src_pyproject.touch()
1980             src_python = src_dir / "foo.py"
1981             src_python.touch()
1982
1983             self.assertEqual(
1984                 black.find_project_root((src_dir, test_dir)), root.resolve()
1985             )
1986             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
1987             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
1988
1989     @patch(
1990         "black.files.find_user_pyproject_toml",
1991         black.files.find_user_pyproject_toml.__wrapped__,
1992     )
1993     def test_find_user_pyproject_toml_linux(self) -> None:
1994         if system() == "Windows":
1995             return
1996
1997         # Test if XDG_CONFIG_HOME is checked
1998         with TemporaryDirectory() as workspace:
1999             tmp_user_config = Path(workspace) / "black"
2000             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
2001                 self.assertEqual(
2002                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
2003                 )
2004
2005         # Test fallback for XDG_CONFIG_HOME
2006         with patch.dict("os.environ"):
2007             os.environ.pop("XDG_CONFIG_HOME", None)
2008             fallback_user_config = Path("~/.config").expanduser() / "black"
2009             self.assertEqual(
2010                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
2011             )
2012
2013     def test_find_user_pyproject_toml_windows(self) -> None:
2014         if system() != "Windows":
2015             return
2016
2017         user_config_path = Path.home() / ".black"
2018         self.assertEqual(
2019             black.files.find_user_pyproject_toml(), user_config_path.resolve()
2020         )
2021
2022     def test_bpo_33660_workaround(self) -> None:
2023         if system() == "Windows":
2024             return
2025
2026         # https://bugs.python.org/issue33660
2027         root = Path("/")
2028         with change_directory(root):
2029             path = Path("workspace") / "project"
2030             report = black.Report(verbose=True)
2031             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
2032             self.assertEqual(normalized_path, "workspace/project")
2033
2034     def test_newline_comment_interaction(self) -> None:
2035         source = "class A:\\\r\n# type: ignore\n pass\n"
2036         output = black.format_str(source, mode=DEFAULT_MODE)
2037         black.assert_stable(source, output, mode=DEFAULT_MODE)
2038
2039     def test_bpo_2142_workaround(self) -> None:
2040
2041         # https://bugs.python.org/issue2142
2042
2043         source, _ = read_data("missing_final_newline.py")
2044         # read_data adds a trailing newline
2045         source = source.rstrip()
2046         expected, _ = read_data("missing_final_newline.diff")
2047         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
2048         diff_header = re.compile(
2049             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
2050             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
2051         )
2052         try:
2053             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
2054             self.assertEqual(result.exit_code, 0)
2055         finally:
2056             os.unlink(tmp_file)
2057         actual = result.output
2058         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
2059         self.assertEqual(actual, expected)
2060
2061     @pytest.mark.python2
2062     def test_docstring_reformat_for_py27(self) -> None:
2063         """
2064         Check that stripping trailing whitespace from Python 2 docstrings
2065         doesn't trigger a "not equivalent to source" error
2066         """
2067         source = (
2068             b'def foo():\r\n    """Testing\r\n    Testing """\r\n    print "Foo"\r\n'
2069         )
2070         expected = 'def foo():\n    """Testing\n    Testing"""\n    print "Foo"\n'
2071
2072         result = CliRunner().invoke(
2073             black.main,
2074             ["-", "-q", "--target-version=py27"],
2075             input=BytesIO(source),
2076         )
2077
2078         self.assertEqual(result.exit_code, 0)
2079         actual = result.output
2080         self.assertFormatEqual(actual, expected)
2081
2082     @staticmethod
2083     def compare_results(
2084         result: click.testing.Result, expected_value: str, expected_exit_code: int
2085     ) -> None:
2086         """Helper method to test the value and exit code of a click Result."""
2087         assert (
2088             result.output == expected_value
2089         ), "The output did not match the expected value."
2090         assert result.exit_code == expected_exit_code, "The exit code is incorrect."
2091
2092     def test_code_option(self) -> None:
2093         """Test the code option with no changes."""
2094         code = 'print("Hello world")\n'
2095         args = ["--code", code]
2096         result = CliRunner().invoke(black.main, args)
2097
2098         self.compare_results(result, code, 0)
2099
2100     def test_code_option_changed(self) -> None:
2101         """Test the code option when changes are required."""
2102         code = "print('hello world')"
2103         formatted = black.format_str(code, mode=DEFAULT_MODE)
2104
2105         args = ["--code", code]
2106         result = CliRunner().invoke(black.main, args)
2107
2108         self.compare_results(result, formatted, 0)
2109
2110     def test_code_option_check(self) -> None:
2111         """Test the code option when check is passed."""
2112         args = ["--check", "--code", 'print("Hello world")\n']
2113         result = CliRunner().invoke(black.main, args)
2114         self.compare_results(result, "", 0)
2115
2116     def test_code_option_check_changed(self) -> None:
2117         """Test the code option when changes are required, and check is passed."""
2118         args = ["--check", "--code", "print('hello world')"]
2119         result = CliRunner().invoke(black.main, args)
2120         self.compare_results(result, "", 1)
2121
2122     def test_code_option_diff(self) -> None:
2123         """Test the code option when diff is passed."""
2124         code = "print('hello world')"
2125         formatted = black.format_str(code, mode=DEFAULT_MODE)
2126         result_diff = diff(code, formatted, "STDIN", "STDOUT")
2127
2128         args = ["--diff", "--code", code]
2129         result = CliRunner().invoke(black.main, args)
2130
2131         # Remove time from diff
2132         output = DIFF_TIME.sub("", result.output)
2133
2134         assert output == result_diff, "The output did not match the expected value."
2135         assert result.exit_code == 0, "The exit code is incorrect."
2136
2137     def test_code_option_color_diff(self) -> None:
2138         """Test the code option when color and diff are passed."""
2139         code = "print('hello world')"
2140         formatted = black.format_str(code, mode=DEFAULT_MODE)
2141
2142         result_diff = diff(code, formatted, "STDIN", "STDOUT")
2143         result_diff = color_diff(result_diff)
2144
2145         args = ["--diff", "--color", "--code", code]
2146         result = CliRunner().invoke(black.main, args)
2147
2148         # Remove time from diff
2149         output = DIFF_TIME.sub("", result.output)
2150
2151         assert output == result_diff, "The output did not match the expected value."
2152         assert result.exit_code == 0, "The exit code is incorrect."
2153
2154     def test_code_option_safe(self) -> None:
2155         """Test that the code option throws an error when the sanity checks fail."""
2156         # Patch black.assert_equivalent to ensure the sanity checks fail
2157         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
2158             code = 'print("Hello world")'
2159             error_msg = f"{code}\nerror: cannot format <string>: \n"
2160
2161             args = ["--safe", "--code", code]
2162             result = CliRunner().invoke(black.main, args)
2163
2164             self.compare_results(result, error_msg, 123)
2165
2166     def test_code_option_fast(self) -> None:
2167         """Test that the code option ignores errors when the sanity checks fail."""
2168         # Patch black.assert_equivalent to ensure the sanity checks fail
2169         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
2170             code = 'print("Hello world")'
2171             formatted = black.format_str(code, mode=DEFAULT_MODE)
2172
2173             args = ["--fast", "--code", code]
2174             result = CliRunner().invoke(black.main, args)
2175
2176             self.compare_results(result, formatted, 0)
2177
2178     def test_code_option_config(self) -> None:
2179         """
2180         Test that the code option finds the pyproject.toml in the current directory.
2181         """
2182         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
2183             args = ["--code", "print"]
2184             CliRunner().invoke(black.main, args)
2185
2186             pyproject_path = Path(Path().cwd(), "pyproject.toml").resolve()
2187             assert (
2188                 len(parse.mock_calls) >= 1
2189             ), "Expected config parse to be called with the current directory."
2190
2191             _, call_args, _ = parse.mock_calls[0]
2192             assert (
2193                 call_args[0].lower() == str(pyproject_path).lower()
2194             ), "Incorrect config loaded."
2195
2196     def test_code_option_parent_config(self) -> None:
2197         """
2198         Test that the code option finds the pyproject.toml in the parent directory.
2199         """
2200         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
2201             with change_directory(Path("tests")):
2202                 args = ["--code", "print"]
2203                 CliRunner().invoke(black.main, args)
2204
2205                 pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
2206                 assert (
2207                     len(parse.mock_calls) >= 1
2208                 ), "Expected config parse to be called with the current directory."
2209
2210                 _, call_args, _ = parse.mock_calls[0]
2211                 assert (
2212                     call_args[0].lower() == str(pyproject_path).lower()
2213                 ), "Incorrect config loaded."
2214
2215
# Black's own source, one entry per line with line endings kept; `tracefunc`
# indexes into it to print the signature line of each traced call.
with open(black.__file__, "r", encoding="utf-8") as _bf:
    black_source_lines = list(_bf)
2218
2219
def tracefunc(frame: types.FrameType, event: str, arg: Any) -> Callable:
    """Show function calls from `black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    For each "call" event whose code object comes from ``black/__init__.py``,
    print ``<indent><lineno>:<signature line>``. The signature line is read from
    ``black_source_lines``, skipping decorator lines, and the indent grows with
    the interpreter stack depth.

    Args:
        frame: the frame that triggered the trace event.
        event: the trace event name; only "call" is acted upon.
        arg: event-specific argument (unused).

    Returns:
        This function itself, so tracing continues into the new scope.
    """
    if event != "call":
        return tracefunc

    # Filter on the file first: ``black_source_lines`` only describes
    # black/__init__.py, so indexing it with another file's line number can
    # raise IndexError, and a trace function that raises is unset by the
    # interpreter. Comparing path parts also works with Windows separators.
    filename = frame.f_code.co_filename
    if Path(filename).parts[-2:] != ("black", "__init__.py"):
        return tracefunc

    # Indent by stack depth. The -19 offset is presumably calibrated to hide
    # the test-harness frames -- NOTE(review): confirm. context=0 gives the
    # same frame count without reading source context for every frame.
    stack = len(inspect.stack(0)) - 19
    stack *= 2
    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1
    funcname = black_source_lines[func_sig_lineno].strip()
    # The reported line may be a decorator; walk forward to the `def` line.
    while funcname.startswith("@"):
        func_sig_lineno += 1
        funcname = black_source_lines[func_sig_lineno].strip()
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc
2240
2241
if __name__ == "__main__":
    # Allow running this file directly as a script; unittest is told to load
    # the tests from the module named "test_black".
    unittest.main(module="test_black")