git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Add compatible configuration files. (psf#1789) (#1792)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import multiprocessing
3 import asyncio
4 import logging
5 from concurrent.futures import ThreadPoolExecutor
6 from contextlib import contextmanager
7 from dataclasses import replace
8 import inspect
9 from io import BytesIO, TextIOWrapper
10 import os
11 from pathlib import Path
12 from platform import system
13 import regex as re
14 import sys
15 from tempfile import TemporaryDirectory
16 import types
17 from typing import (
18     Any,
19     BinaryIO,
20     Callable,
21     Dict,
22     Generator,
23     List,
24     Iterator,
25     TypeVar,
26 )
27 import unittest
28 from unittest.mock import patch, MagicMock
29
30 import click
31 from click import unstyle
32 from click.testing import CliRunner
33
34 import black
35 from black import Feature, TargetVersion
36
37 from pathspec import PathSpec
38
39 # Import other test classes
40 from tests.util import (
41     THIS_DIR,
42     read_data,
43     DETERMINISTIC_HEADER,
44     BlackBaseTestCase,
45     DEFAULT_MODE,
46     fs,
47     ff,
48     dump_to_stderr,
49 )
50 from .test_primer import PrimerCLITests  # noqa: F401
51
52
THIS_FILE = Path(__file__)
PY36_VERSIONS = {
    TargetVersion.PY36,
    TargetVersion.PY37,
    TargetVersion.PY38,
    TargetVersion.PY39,
}
# CLI flags selecting every version in PY36_VERSIONS. The set is iterated in
# name-sorted order because set iteration order is not guaranteed to be the
# same between interpreter runs; a fixed order keeps invocations reproducible.
PY36_ARGS = [
    f"--target-version={version.name.lower()}"
    for version in sorted(PY36_VERSIONS, key=lambda v: v.name)
]
T = TypeVar("T")
R = TypeVar("R")
63
64
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point ``black.CACHE_DIR`` at a throwaway directory while the block runs.

    Yields the patched path. With ``exists=False`` the yielded path is a
    ``new`` subdirectory of the temporary directory that is not created here.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
73
74
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a brand-new asyncio event loop for the duration of the block.

    The loop comes from the current event loop policy, is made the current
    loop, and is closed on exit even if the body raises.
    """
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        fresh_loop.close()
85
86
class FakeContext(click.Context):
    """Minimal click Context stand-in for functions that require one.

    ``click.Context.__init__`` is deliberately not called; the only state set
    up is an empty ``default_map``.
    """

    def __init__(self) -> None:
        empty_defaults: Dict[str, Any] = {}
        self.default_map = empty_defaults
92
93
class FakeParameter(click.Parameter):
    """Minimal click Parameter stand-in for functions that require one."""

    def __init__(self) -> None:
        """Build the object without running ``click.Parameter.__init__``."""
99
100
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    After each invocation the raw bytes written to stdout and stderr are
    available as ``stdout_bytes`` and ``stderr_bytes``.

    This is a hack that can be removed once we depend on Click 7.x"""

    def __init__(self) -> None:
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Run click's isolation, additionally redirecting stderr to our buffer.

        On exit the captured stdout/stderr bytes are stored on the runner and
        the previous ``sys.stderr`` is restored.
        """
        with super().isolation(*args, **kwargs) as output:
            # Save the stream before replacing it so the ``finally`` can
            # always restore it.
            hold_stderr = sys.stderr
            stderr_wrapper = TextIOWrapper(self.stderrbuf, encoding=self.charset)
            sys.stderr = stderr_wrapper
            try:
                yield output
            finally:
                # Text layers buffer writes; flush them so output written
                # without an explicit flush (e.g. ``print(..., file=sys.stderr)``)
                # reaches the byte buffers before they are read.
                sys.stdout.flush()
                stderr_wrapper.flush()
                self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                self.stderr_bytes = self.stderrbuf.getvalue()
                sys.stderr = hold_stderr
124
125
126 class BlackTestCase(BlackBaseTestCase):
127     def invokeBlack(
128         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
129     ) -> None:
130         runner = BlackRunner()
131         if ignore_config:
132             args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
133         result = runner.invoke(black.main, args)
134         self.assertEqual(
135             result.exit_code,
136             exit_code,
137             msg=(
138                 f"Failed with args: {args}\n"
139                 f"stdout: {runner.stdout_bytes.decode()!r}\n"
140                 f"stderr: {runner.stderr_bytes.decode()!r}\n"
141                 f"exception: {result.exception}"
142             ),
143         )
144
145     @patch("black.dump_to_file", dump_to_stderr)
146     def test_empty(self) -> None:
147         source = expected = ""
148         actual = fs(source)
149         self.assertFormatEqual(expected, actual)
150         black.assert_equivalent(source, actual)
151         black.assert_stable(source, actual, DEFAULT_MODE)
152
153     def test_empty_ff(self) -> None:
154         expected = ""
155         tmp_file = Path(black.dump_to_file())
156         try:
157             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
158             with open(tmp_file, encoding="utf8") as f:
159                 actual = f.read()
160         finally:
161             os.unlink(tmp_file)
162         self.assertFormatEqual(expected, actual)
163
164     def test_piping(self) -> None:
165         source, expected = read_data("src/black/__init__", data=False)
166         result = BlackRunner().invoke(
167             black.main,
168             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
169             input=BytesIO(source.encode("utf8")),
170         )
171         self.assertEqual(result.exit_code, 0)
172         self.assertFormatEqual(expected, result.output)
173         black.assert_equivalent(source, result.output)
174         black.assert_stable(source, result.output, DEFAULT_MODE)
175
176     def test_piping_diff(self) -> None:
177         diff_header = re.compile(
178             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
179             r"\+\d\d\d\d"
180         )
181         source, _ = read_data("expression.py")
182         expected, _ = read_data("expression.diff")
183         config = THIS_DIR / "data" / "empty_pyproject.toml"
184         args = [
185             "-",
186             "--fast",
187             f"--line-length={black.DEFAULT_LINE_LENGTH}",
188             "--diff",
189             f"--config={config}",
190         ]
191         result = BlackRunner().invoke(
192             black.main, args, input=BytesIO(source.encode("utf8"))
193         )
194         self.assertEqual(result.exit_code, 0)
195         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
196         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
197         self.assertEqual(expected, actual)
198
199     def test_piping_diff_with_color(self) -> None:
200         source, _ = read_data("expression.py")
201         config = THIS_DIR / "data" / "empty_pyproject.toml"
202         args = [
203             "-",
204             "--fast",
205             f"--line-length={black.DEFAULT_LINE_LENGTH}",
206             "--diff",
207             "--color",
208             f"--config={config}",
209         ]
210         result = BlackRunner().invoke(
211             black.main, args, input=BytesIO(source.encode("utf8"))
212         )
213         actual = result.output
214         # Again, the contents are checked in a different test, so only look for colors.
215         self.assertIn("\033[1;37m", actual)
216         self.assertIn("\033[36m", actual)
217         self.assertIn("\033[32m", actual)
218         self.assertIn("\033[31m", actual)
219         self.assertIn("\033[0m", actual)
220
221     @patch("black.dump_to_file", dump_to_stderr)
222     def _test_wip(self) -> None:
223         source, expected = read_data("wip")
224         sys.settrace(tracefunc)
225         mode = replace(
226             DEFAULT_MODE,
227             experimental_string_processing=False,
228             target_versions={black.TargetVersion.PY38},
229         )
230         actual = fs(source, mode=mode)
231         sys.settrace(None)
232         self.assertFormatEqual(expected, actual)
233         black.assert_equivalent(source, actual)
234         black.assert_stable(source, actual, black.FileMode())
235
236     @unittest.expectedFailure
237     @patch("black.dump_to_file", dump_to_stderr)
238     def test_trailing_comma_optional_parens_stability1(self) -> None:
239         source, _expected = read_data("trailing_comma_optional_parens1")
240         actual = fs(source)
241         black.assert_stable(source, actual, DEFAULT_MODE)
242
243     @unittest.expectedFailure
244     @patch("black.dump_to_file", dump_to_stderr)
245     def test_trailing_comma_optional_parens_stability2(self) -> None:
246         source, _expected = read_data("trailing_comma_optional_parens2")
247         actual = fs(source)
248         black.assert_stable(source, actual, DEFAULT_MODE)
249
250     @unittest.expectedFailure
251     @patch("black.dump_to_file", dump_to_stderr)
252     def test_trailing_comma_optional_parens_stability3(self) -> None:
253         source, _expected = read_data("trailing_comma_optional_parens3")
254         actual = fs(source)
255         black.assert_stable(source, actual, DEFAULT_MODE)
256
257     @patch("black.dump_to_file", dump_to_stderr)
258     def test_pep_572(self) -> None:
259         source, expected = read_data("pep_572")
260         actual = fs(source)
261         self.assertFormatEqual(expected, actual)
262         black.assert_stable(source, actual, DEFAULT_MODE)
263         if sys.version_info >= (3, 8):
264             black.assert_equivalent(source, actual)
265
266     def test_pep_572_version_detection(self) -> None:
267         source, _ = read_data("pep_572")
268         root = black.lib2to3_parse(source)
269         features = black.get_features_used(root)
270         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
271         versions = black.detect_target_versions(root)
272         self.assertIn(black.TargetVersion.PY38, versions)
273
274     def test_expression_ff(self) -> None:
275         source, expected = read_data("expression")
276         tmp_file = Path(black.dump_to_file(source))
277         try:
278             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
279             with open(tmp_file, encoding="utf8") as f:
280                 actual = f.read()
281         finally:
282             os.unlink(tmp_file)
283         self.assertFormatEqual(expected, actual)
284         with patch("black.dump_to_file", dump_to_stderr):
285             black.assert_equivalent(source, actual)
286             black.assert_stable(source, actual, DEFAULT_MODE)
287
288     def test_expression_diff(self) -> None:
289         source, _ = read_data("expression.py")
290         expected, _ = read_data("expression.diff")
291         tmp_file = Path(black.dump_to_file(source))
292         diff_header = re.compile(
293             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
294             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
295         )
296         try:
297             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
298             self.assertEqual(result.exit_code, 0)
299         finally:
300             os.unlink(tmp_file)
301         actual = result.output
302         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
303         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
304         if expected != actual:
305             dump = black.dump_to_file(actual)
306             msg = (
307                 "Expected diff isn't equal to the actual. If you made changes to"
308                 " expression.py and this is an anticipated difference, overwrite"
309                 f" tests/data/expression.diff with {dump}"
310             )
311             self.assertEqual(expected, actual, msg)
312
313     def test_expression_diff_with_color(self) -> None:
314         source, _ = read_data("expression.py")
315         expected, _ = read_data("expression.diff")
316         tmp_file = Path(black.dump_to_file(source))
317         try:
318             result = BlackRunner().invoke(
319                 black.main, ["--diff", "--color", str(tmp_file)]
320             )
321         finally:
322             os.unlink(tmp_file)
323         actual = result.output
324         # We check the contents of the diff in `test_expression_diff`. All
325         # we need to check here is that color codes exist in the result.
326         self.assertIn("\033[1;37m", actual)
327         self.assertIn("\033[36m", actual)
328         self.assertIn("\033[32m", actual)
329         self.assertIn("\033[31m", actual)
330         self.assertIn("\033[0m", actual)
331
332     @patch("black.dump_to_file", dump_to_stderr)
333     def test_pep_570(self) -> None:
334         source, expected = read_data("pep_570")
335         actual = fs(source)
336         self.assertFormatEqual(expected, actual)
337         black.assert_stable(source, actual, DEFAULT_MODE)
338         if sys.version_info >= (3, 8):
339             black.assert_equivalent(source, actual)
340
341     def test_detect_pos_only_arguments(self) -> None:
342         source, _ = read_data("pep_570")
343         root = black.lib2to3_parse(source)
344         features = black.get_features_used(root)
345         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
346         versions = black.detect_target_versions(root)
347         self.assertIn(black.TargetVersion.PY38, versions)
348
349     @patch("black.dump_to_file", dump_to_stderr)
350     def test_string_quotes(self) -> None:
351         source, expected = read_data("string_quotes")
352         actual = fs(source)
353         self.assertFormatEqual(expected, actual)
354         black.assert_equivalent(source, actual)
355         black.assert_stable(source, actual, DEFAULT_MODE)
356         mode = replace(DEFAULT_MODE, string_normalization=False)
357         not_normalized = fs(source, mode=mode)
358         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
359         black.assert_equivalent(source, not_normalized)
360         black.assert_stable(source, not_normalized, mode=mode)
361
362     @patch("black.dump_to_file", dump_to_stderr)
363     def test_docstring_no_string_normalization(self) -> None:
364         """Like test_docstring but with string normalization off."""
365         source, expected = read_data("docstring_no_string_normalization")
366         mode = replace(DEFAULT_MODE, string_normalization=False)
367         actual = fs(source, mode=mode)
368         self.assertFormatEqual(expected, actual)
369         black.assert_equivalent(source, actual)
370         black.assert_stable(source, actual, mode)
371
372     def test_long_strings_flag_disabled(self) -> None:
373         """Tests for turning off the string processing logic."""
374         source, expected = read_data("long_strings_flag_disabled")
375         mode = replace(DEFAULT_MODE, experimental_string_processing=False)
376         actual = fs(source, mode=mode)
377         self.assertFormatEqual(expected, actual)
378         black.assert_stable(expected, actual, mode)
379
380     @patch("black.dump_to_file", dump_to_stderr)
381     def test_numeric_literals(self) -> None:
382         source, expected = read_data("numeric_literals")
383         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
384         actual = fs(source, mode=mode)
385         self.assertFormatEqual(expected, actual)
386         black.assert_equivalent(source, actual)
387         black.assert_stable(source, actual, mode)
388
389     @patch("black.dump_to_file", dump_to_stderr)
390     def test_numeric_literals_ignoring_underscores(self) -> None:
391         source, expected = read_data("numeric_literals_skip_underscores")
392         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
393         actual = fs(source, mode=mode)
394         self.assertFormatEqual(expected, actual)
395         black.assert_equivalent(source, actual)
396         black.assert_stable(source, actual, mode)
397
398     @patch("black.dump_to_file", dump_to_stderr)
399     def test_python2_print_function(self) -> None:
400         source, expected = read_data("python2_print_function")
401         mode = replace(DEFAULT_MODE, target_versions={TargetVersion.PY27})
402         actual = fs(source, mode=mode)
403         self.assertFormatEqual(expected, actual)
404         black.assert_equivalent(source, actual)
405         black.assert_stable(source, actual, mode)
406
407     @patch("black.dump_to_file", dump_to_stderr)
408     def test_stub(self) -> None:
409         mode = replace(DEFAULT_MODE, is_pyi=True)
410         source, expected = read_data("stub.pyi")
411         actual = fs(source, mode=mode)
412         self.assertFormatEqual(expected, actual)
413         black.assert_stable(source, actual, mode)
414
415     @patch("black.dump_to_file", dump_to_stderr)
416     def test_async_as_identifier(self) -> None:
417         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
418         source, expected = read_data("async_as_identifier")
419         actual = fs(source)
420         self.assertFormatEqual(expected, actual)
421         major, minor = sys.version_info[:2]
422         if major < 3 or (major <= 3 and minor < 7):
423             black.assert_equivalent(source, actual)
424         black.assert_stable(source, actual, DEFAULT_MODE)
425         # ensure black can parse this when the target is 3.6
426         self.invokeBlack([str(source_path), "--target-version", "py36"])
427         # but not on 3.7, because async/await is no longer an identifier
428         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
429
430     @patch("black.dump_to_file", dump_to_stderr)
431     def test_python37(self) -> None:
432         source_path = (THIS_DIR / "data" / "python37.py").resolve()
433         source, expected = read_data("python37")
434         actual = fs(source)
435         self.assertFormatEqual(expected, actual)
436         major, minor = sys.version_info[:2]
437         if major > 3 or (major == 3 and minor >= 7):
438             black.assert_equivalent(source, actual)
439         black.assert_stable(source, actual, DEFAULT_MODE)
440         # ensure black can parse this when the target is 3.7
441         self.invokeBlack([str(source_path), "--target-version", "py37"])
442         # but not on 3.6, because we use async as a reserved keyword
443         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
444
445     @patch("black.dump_to_file", dump_to_stderr)
446     def test_python38(self) -> None:
447         source, expected = read_data("python38")
448         actual = fs(source)
449         self.assertFormatEqual(expected, actual)
450         major, minor = sys.version_info[:2]
451         if major > 3 or (major == 3 and minor >= 8):
452             black.assert_equivalent(source, actual)
453         black.assert_stable(source, actual, DEFAULT_MODE)
454
455     @patch("black.dump_to_file", dump_to_stderr)
456     def test_python39(self) -> None:
457         source, expected = read_data("python39")
458         actual = fs(source)
459         self.assertFormatEqual(expected, actual)
460         major, minor = sys.version_info[:2]
461         if major > 3 or (major == 3 and minor >= 9):
462             black.assert_equivalent(source, actual)
463         black.assert_stable(source, actual, DEFAULT_MODE)
464
465     def test_tab_comment_indentation(self) -> None:
466         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
467         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
468         self.assertFormatEqual(contents_spc, fs(contents_spc))
469         self.assertFormatEqual(contents_spc, fs(contents_tab))
470
471         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
472         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
473         self.assertFormatEqual(contents_spc, fs(contents_spc))
474         self.assertFormatEqual(contents_spc, fs(contents_tab))
475
476         # mixed tabs and spaces (valid Python 2 code)
477         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
478         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
479         self.assertFormatEqual(contents_spc, fs(contents_spc))
480         self.assertFormatEqual(contents_spc, fs(contents_tab))
481
482         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
483         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
484         self.assertFormatEqual(contents_spc, fs(contents_spc))
485         self.assertFormatEqual(contents_spc, fs(contents_tab))
486
487     def test_report_verbose(self) -> None:
488         report = black.Report(verbose=True)
489         out_lines = []
490         err_lines = []
491
492         def out(msg: str, **kwargs: Any) -> None:
493             out_lines.append(msg)
494
495         def err(msg: str, **kwargs: Any) -> None:
496             err_lines.append(msg)
497
498         with patch("black.out", out), patch("black.err", err):
499             report.done(Path("f1"), black.Changed.NO)
500             self.assertEqual(len(out_lines), 1)
501             self.assertEqual(len(err_lines), 0)
502             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
503             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
504             self.assertEqual(report.return_code, 0)
505             report.done(Path("f2"), black.Changed.YES)
506             self.assertEqual(len(out_lines), 2)
507             self.assertEqual(len(err_lines), 0)
508             self.assertEqual(out_lines[-1], "reformatted f2")
509             self.assertEqual(
510                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
511             )
512             report.done(Path("f3"), black.Changed.CACHED)
513             self.assertEqual(len(out_lines), 3)
514             self.assertEqual(len(err_lines), 0)
515             self.assertEqual(
516                 out_lines[-1], "f3 wasn't modified on disk since last run."
517             )
518             self.assertEqual(
519                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
520             )
521             self.assertEqual(report.return_code, 0)
522             report.check = True
523             self.assertEqual(report.return_code, 1)
524             report.check = False
525             report.failed(Path("e1"), "boom")
526             self.assertEqual(len(out_lines), 3)
527             self.assertEqual(len(err_lines), 1)
528             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
529             self.assertEqual(
530                 unstyle(str(report)),
531                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
532                 " reformat.",
533             )
534             self.assertEqual(report.return_code, 123)
535             report.done(Path("f3"), black.Changed.YES)
536             self.assertEqual(len(out_lines), 4)
537             self.assertEqual(len(err_lines), 1)
538             self.assertEqual(out_lines[-1], "reformatted f3")
539             self.assertEqual(
540                 unstyle(str(report)),
541                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
542                 " reformat.",
543             )
544             self.assertEqual(report.return_code, 123)
545             report.failed(Path("e2"), "boom")
546             self.assertEqual(len(out_lines), 4)
547             self.assertEqual(len(err_lines), 2)
548             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
549             self.assertEqual(
550                 unstyle(str(report)),
551                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
552                 " reformat.",
553             )
554             self.assertEqual(report.return_code, 123)
555             report.path_ignored(Path("wat"), "no match")
556             self.assertEqual(len(out_lines), 5)
557             self.assertEqual(len(err_lines), 2)
558             self.assertEqual(out_lines[-1], "wat ignored: no match")
559             self.assertEqual(
560                 unstyle(str(report)),
561                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
562                 " reformat.",
563             )
564             self.assertEqual(report.return_code, 123)
565             report.done(Path("f4"), black.Changed.NO)
566             self.assertEqual(len(out_lines), 6)
567             self.assertEqual(len(err_lines), 2)
568             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
569             self.assertEqual(
570                 unstyle(str(report)),
571                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
572                 " reformat.",
573             )
574             self.assertEqual(report.return_code, 123)
575             report.check = True
576             self.assertEqual(
577                 unstyle(str(report)),
578                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
579                 " would fail to reformat.",
580             )
581             report.check = False
582             report.diff = True
583             self.assertEqual(
584                 unstyle(str(report)),
585                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
586                 " would fail to reformat.",
587             )
588
589     def test_report_quiet(self) -> None:
590         report = black.Report(quiet=True)
591         out_lines = []
592         err_lines = []
593
594         def out(msg: str, **kwargs: Any) -> None:
595             out_lines.append(msg)
596
597         def err(msg: str, **kwargs: Any) -> None:
598             err_lines.append(msg)
599
600         with patch("black.out", out), patch("black.err", err):
601             report.done(Path("f1"), black.Changed.NO)
602             self.assertEqual(len(out_lines), 0)
603             self.assertEqual(len(err_lines), 0)
604             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
605             self.assertEqual(report.return_code, 0)
606             report.done(Path("f2"), black.Changed.YES)
607             self.assertEqual(len(out_lines), 0)
608             self.assertEqual(len(err_lines), 0)
609             self.assertEqual(
610                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
611             )
612             report.done(Path("f3"), black.Changed.CACHED)
613             self.assertEqual(len(out_lines), 0)
614             self.assertEqual(len(err_lines), 0)
615             self.assertEqual(
616                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
617             )
618             self.assertEqual(report.return_code, 0)
619             report.check = True
620             self.assertEqual(report.return_code, 1)
621             report.check = False
622             report.failed(Path("e1"), "boom")
623             self.assertEqual(len(out_lines), 0)
624             self.assertEqual(len(err_lines), 1)
625             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
626             self.assertEqual(
627                 unstyle(str(report)),
628                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
629                 " reformat.",
630             )
631             self.assertEqual(report.return_code, 123)
632             report.done(Path("f3"), black.Changed.YES)
633             self.assertEqual(len(out_lines), 0)
634             self.assertEqual(len(err_lines), 1)
635             self.assertEqual(
636                 unstyle(str(report)),
637                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
638                 " reformat.",
639             )
640             self.assertEqual(report.return_code, 123)
641             report.failed(Path("e2"), "boom")
642             self.assertEqual(len(out_lines), 0)
643             self.assertEqual(len(err_lines), 2)
644             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
645             self.assertEqual(
646                 unstyle(str(report)),
647                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
648                 " reformat.",
649             )
650             self.assertEqual(report.return_code, 123)
651             report.path_ignored(Path("wat"), "no match")
652             self.assertEqual(len(out_lines), 0)
653             self.assertEqual(len(err_lines), 2)
654             self.assertEqual(
655                 unstyle(str(report)),
656                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
657                 " reformat.",
658             )
659             self.assertEqual(report.return_code, 123)
660             report.done(Path("f4"), black.Changed.NO)
661             self.assertEqual(len(out_lines), 0)
662             self.assertEqual(len(err_lines), 2)
663             self.assertEqual(
664                 unstyle(str(report)),
665                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
666                 " reformat.",
667             )
668             self.assertEqual(report.return_code, 123)
669             report.check = True
670             self.assertEqual(
671                 unstyle(str(report)),
672                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
673                 " would fail to reformat.",
674             )
675             report.check = False
676             report.diff = True
677             self.assertEqual(
678                 unstyle(str(report)),
679                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
680                 " would fail to reformat.",
681             )
682
    def test_report_normal(self) -> None:
        """Exercise ``black.Report`` when constructed with no arguments.

        ``black.out`` / ``black.err`` are patched so every message the report
        emits is captured. Report methods are then called one after another,
        and after each step the captured lines, the summary string
        (``str(report)``) and ``return_code`` are checked. The steps are
        cumulative: every assertion depends on all calls made before it.
        """
        report = black.Report()
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            # Stand-in for black.out: record the message instead of printing.
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            # Stand-in for black.err: record the message instead of printing.
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            # An unchanged file is counted but produces no output here.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # A reformatted file is announced on the normal output channel.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file counts as "left unchanged" and prints nothing.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled, a reformatted file makes the return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure is reported on the error channel and forces code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Reporting f3 again is counted again: the report does not
            # de-duplicate paths, so f3 now also counts as reformatted.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # An ignored path neither prints nor changes the summary counts.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check and diff switch the summary to "would be" phrasing.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
779
780     def test_lib2to3_parse(self) -> None:
781         with self.assertRaises(black.InvalidInput):
782             black.lib2to3_parse("invalid syntax")
783
784         straddling = "x + y"
785         black.lib2to3_parse(straddling)
786         black.lib2to3_parse(straddling, {TargetVersion.PY27})
787         black.lib2to3_parse(straddling, {TargetVersion.PY36})
788         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
789
790         py2_only = "print x"
791         black.lib2to3_parse(py2_only)
792         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
793         with self.assertRaises(black.InvalidInput):
794             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
795         with self.assertRaises(black.InvalidInput):
796             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
797
798         py3_only = "exec(x, end=y)"
799         black.lib2to3_parse(py3_only)
800         with self.assertRaises(black.InvalidInput):
801             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
802         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
803         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
804
805     def test_get_features_used_decorator(self) -> None:
806         # Test the feature detection of new decorator syntax
807         # since this makes some test cases of test_get_features_used()
808         # fails if it fails, this is tested first so that a useful case
809         # is identified
810         simples, relaxed = read_data("decorators")
811         # skip explanation comments at the top of the file
812         for simple_test in simples.split("##")[1:]:
813             node = black.lib2to3_parse(simple_test)
814             decorator = str(node.children[0].children[0]).strip()
815             self.assertNotIn(
816                 Feature.RELAXED_DECORATORS,
817                 black.get_features_used(node),
818                 msg=(
819                     f"decorator '{decorator}' follows python<=3.8 syntax"
820                     "but is detected as 3.9+"
821                     # f"The full node is\n{node!r}"
822                 ),
823             )
824         # skip the '# output' comment at the top of the output part
825         for relaxed_test in relaxed.split("##")[1:]:
826             node = black.lib2to3_parse(relaxed_test)
827             decorator = str(node.children[0].children[0]).strip()
828             self.assertIn(
829                 Feature.RELAXED_DECORATORS,
830                 black.get_features_used(node),
831                 msg=(
832                     f"decorator '{decorator}' uses python3.9+ syntax"
833                     "but is detected as python<=3.8"
834                     # f"The full node is\n{node!r}"
835                 ),
836             )
837
838     def test_get_features_used(self) -> None:
839         node = black.lib2to3_parse("def f(*, arg): ...\n")
840         self.assertEqual(black.get_features_used(node), set())
841         node = black.lib2to3_parse("def f(*, arg,): ...\n")
842         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
843         node = black.lib2to3_parse("f(*arg,)\n")
844         self.assertEqual(
845             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
846         )
847         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
848         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
849         node = black.lib2to3_parse("123_456\n")
850         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
851         node = black.lib2to3_parse("123456\n")
852         self.assertEqual(black.get_features_used(node), set())
853         source, expected = read_data("function")
854         node = black.lib2to3_parse(source)
855         expected_features = {
856             Feature.TRAILING_COMMA_IN_CALL,
857             Feature.TRAILING_COMMA_IN_DEF,
858             Feature.F_STRINGS,
859         }
860         self.assertEqual(black.get_features_used(node), expected_features)
861         node = black.lib2to3_parse(expected)
862         self.assertEqual(black.get_features_used(node), expected_features)
863         source, expected = read_data("expression")
864         node = black.lib2to3_parse(source)
865         self.assertEqual(black.get_features_used(node), set())
866         node = black.lib2to3_parse(expected)
867         self.assertEqual(black.get_features_used(node), set())
868
869     def test_get_future_imports(self) -> None:
870         node = black.lib2to3_parse("\n")
871         self.assertEqual(set(), black.get_future_imports(node))
872         node = black.lib2to3_parse("from __future__ import black\n")
873         self.assertEqual({"black"}, black.get_future_imports(node))
874         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
875         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
876         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
877         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
878         node = black.lib2to3_parse(
879             "from __future__ import multiple\nfrom __future__ import imports\n"
880         )
881         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
882         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
883         self.assertEqual({"black"}, black.get_future_imports(node))
884         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
885         self.assertEqual({"black"}, black.get_future_imports(node))
886         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
887         self.assertEqual(set(), black.get_future_imports(node))
888         node = black.lib2to3_parse("from some.module import black\n")
889         self.assertEqual(set(), black.get_future_imports(node))
890         node = black.lib2to3_parse(
891             "from __future__ import unicode_literals as _unicode_literals"
892         )
893         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
894         node = black.lib2to3_parse(
895             "from __future__ import unicode_literals as _lol, print"
896         )
897         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
898
899     def test_debug_visitor(self) -> None:
900         source, _ = read_data("debug_visitor.py")
901         expected, _ = read_data("debug_visitor.out")
902         out_lines = []
903         err_lines = []
904
905         def out(msg: str, **kwargs: Any) -> None:
906             out_lines.append(msg)
907
908         def err(msg: str, **kwargs: Any) -> None:
909             err_lines.append(msg)
910
911         with patch("black.out", out), patch("black.err", err):
912             black.DebugVisitor.show(source)
913         actual = "\n".join(out_lines) + "\n"
914         log_name = ""
915         if expected != actual:
916             log_name = black.dump_to_file(*out_lines)
917         self.assertEqual(
918             expected,
919             actual,
920             f"AST print out is different. Actual version dumped to {log_name}",
921         )
922
923     def test_format_file_contents(self) -> None:
924         empty = ""
925         mode = DEFAULT_MODE
926         with self.assertRaises(black.NothingChanged):
927             black.format_file_contents(empty, mode=mode, fast=False)
928         just_nl = "\n"
929         with self.assertRaises(black.NothingChanged):
930             black.format_file_contents(just_nl, mode=mode, fast=False)
931         same = "j = [1, 2, 3]\n"
932         with self.assertRaises(black.NothingChanged):
933             black.format_file_contents(same, mode=mode, fast=False)
934         different = "j = [1,2,3]"
935         expected = same
936         actual = black.format_file_contents(different, mode=mode, fast=False)
937         self.assertEqual(expected, actual)
938         invalid = "return if you can"
939         with self.assertRaises(black.InvalidInput) as e:
940             black.format_file_contents(invalid, mode=mode, fast=False)
941         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
942
943     def test_endmarker(self) -> None:
944         n = black.lib2to3_parse("\n")
945         self.assertEqual(n.type, black.syms.file_input)
946         self.assertEqual(len(n.children), 1)
947         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
948
949     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
950     def test_assertFormatEqual(self) -> None:
951         out_lines = []
952         err_lines = []
953
954         def out(msg: str, **kwargs: Any) -> None:
955             out_lines.append(msg)
956
957         def err(msg: str, **kwargs: Any) -> None:
958             err_lines.append(msg)
959
960         with patch("black.out", out), patch("black.err", err):
961             with self.assertRaises(AssertionError):
962                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
963
964         out_str = "".join(out_lines)
965         self.assertTrue("Expected tree:" in out_str)
966         self.assertTrue("Actual tree:" in out_str)
967         self.assertEqual("".join(err_lines), "")
968
969     def test_cache_broken_file(self) -> None:
970         mode = DEFAULT_MODE
971         with cache_dir() as workspace:
972             cache_file = black.get_cache_file(mode)
973             with cache_file.open("w") as fobj:
974                 fobj.write("this is not a pickle")
975             self.assertEqual(black.read_cache(mode), {})
976             src = (workspace / "test.py").resolve()
977             with src.open("w") as fobj:
978                 fobj.write("print('hello')")
979             self.invokeBlack([str(src)])
980             cache = black.read_cache(mode)
981             self.assertIn(src, cache)
982
983     def test_cache_single_file_already_cached(self) -> None:
984         mode = DEFAULT_MODE
985         with cache_dir() as workspace:
986             src = (workspace / "test.py").resolve()
987             with src.open("w") as fobj:
988                 fobj.write("print('hello')")
989             black.write_cache({}, [src], mode)
990             self.invokeBlack([str(src)])
991             with src.open("r") as fobj:
992                 self.assertEqual(fobj.read(), "print('hello')")
993
994     @event_loop()
995     def test_cache_multiple_files(self) -> None:
996         mode = DEFAULT_MODE
997         with cache_dir() as workspace, patch(
998             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
999         ):
1000             one = (workspace / "one.py").resolve()
1001             with one.open("w") as fobj:
1002                 fobj.write("print('hello')")
1003             two = (workspace / "two.py").resolve()
1004             with two.open("w") as fobj:
1005                 fobj.write("print('hello')")
1006             black.write_cache({}, [one], mode)
1007             self.invokeBlack([str(workspace)])
1008             with one.open("r") as fobj:
1009                 self.assertEqual(fobj.read(), "print('hello')")
1010             with two.open("r") as fobj:
1011                 self.assertEqual(fobj.read(), 'print("hello")\n')
1012             cache = black.read_cache(mode)
1013             self.assertIn(one, cache)
1014             self.assertIn(two, cache)
1015
1016     def test_no_cache_when_writeback_diff(self) -> None:
1017         mode = DEFAULT_MODE
1018         with cache_dir() as workspace:
1019             src = (workspace / "test.py").resolve()
1020             with src.open("w") as fobj:
1021                 fobj.write("print('hello')")
1022             with patch("black.read_cache") as read_cache, patch(
1023                 "black.write_cache"
1024             ) as write_cache:
1025                 self.invokeBlack([str(src), "--diff"])
1026                 cache_file = black.get_cache_file(mode)
1027                 self.assertFalse(cache_file.exists())
1028                 write_cache.assert_not_called()
1029                 read_cache.assert_not_called()
1030
1031     def test_no_cache_when_writeback_color_diff(self) -> None:
1032         mode = DEFAULT_MODE
1033         with cache_dir() as workspace:
1034             src = (workspace / "test.py").resolve()
1035             with src.open("w") as fobj:
1036                 fobj.write("print('hello')")
1037             with patch("black.read_cache") as read_cache, patch(
1038                 "black.write_cache"
1039             ) as write_cache:
1040                 self.invokeBlack([str(src), "--diff", "--color"])
1041                 cache_file = black.get_cache_file(mode)
1042                 self.assertFalse(cache_file.exists())
1043                 write_cache.assert_not_called()
1044                 read_cache.assert_not_called()
1045
1046     @event_loop()
1047     def test_output_locking_when_writeback_diff(self) -> None:
1048         with cache_dir() as workspace:
1049             for tag in range(0, 4):
1050                 src = (workspace / f"test{tag}.py").resolve()
1051                 with src.open("w") as fobj:
1052                     fobj.write("print('hello')")
1053             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1054                 self.invokeBlack(["--diff", str(workspace)], exit_code=0)
1055                 # this isn't quite doing what we want, but if it _isn't_
1056                 # called then we cannot be using the lock it provides
1057                 mgr.assert_called()
1058
1059     @event_loop()
1060     def test_output_locking_when_writeback_color_diff(self) -> None:
1061         with cache_dir() as workspace:
1062             for tag in range(0, 4):
1063                 src = (workspace / f"test{tag}.py").resolve()
1064                 with src.open("w") as fobj:
1065                     fobj.write("print('hello')")
1066             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1067                 self.invokeBlack(["--diff", "--color", str(workspace)], exit_code=0)
1068                 # this isn't quite doing what we want, but if it _isn't_
1069                 # called then we cannot be using the lock it provides
1070                 mgr.assert_called()
1071
1072     def test_no_cache_when_stdin(self) -> None:
1073         mode = DEFAULT_MODE
1074         with cache_dir():
1075             result = CliRunner().invoke(
1076                 black.main, ["-"], input=BytesIO(b"print('hello')")
1077             )
1078             self.assertEqual(result.exit_code, 0)
1079             cache_file = black.get_cache_file(mode)
1080             self.assertFalse(cache_file.exists())
1081
1082     def test_read_cache_no_cachefile(self) -> None:
1083         mode = DEFAULT_MODE
1084         with cache_dir():
1085             self.assertEqual(black.read_cache(mode), {})
1086
1087     def test_write_cache_read_cache(self) -> None:
1088         mode = DEFAULT_MODE
1089         with cache_dir() as workspace:
1090             src = (workspace / "test.py").resolve()
1091             src.touch()
1092             black.write_cache({}, [src], mode)
1093             cache = black.read_cache(mode)
1094             self.assertIn(src, cache)
1095             self.assertEqual(cache[src], black.get_cache_info(src))
1096
1097     def test_filter_cached(self) -> None:
1098         with TemporaryDirectory() as workspace:
1099             path = Path(workspace)
1100             uncached = (path / "uncached").resolve()
1101             cached = (path / "cached").resolve()
1102             cached_but_changed = (path / "changed").resolve()
1103             uncached.touch()
1104             cached.touch()
1105             cached_but_changed.touch()
1106             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
1107             todo, done = black.filter_cached(
1108                 cache, {uncached, cached, cached_but_changed}
1109             )
1110             self.assertEqual(todo, {uncached, cached_but_changed})
1111             self.assertEqual(done, {cached})
1112
1113     def test_write_cache_creates_directory_if_needed(self) -> None:
1114         mode = DEFAULT_MODE
1115         with cache_dir(exists=False) as workspace:
1116             self.assertFalse(workspace.exists())
1117             black.write_cache({}, [], mode)
1118             self.assertTrue(workspace.exists())
1119
1120     @event_loop()
1121     def test_failed_formatting_does_not_get_cached(self) -> None:
1122         mode = DEFAULT_MODE
1123         with cache_dir() as workspace, patch(
1124             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1125         ):
1126             failing = (workspace / "failing.py").resolve()
1127             with failing.open("w") as fobj:
1128                 fobj.write("not actually python")
1129             clean = (workspace / "clean.py").resolve()
1130             with clean.open("w") as fobj:
1131                 fobj.write('print("hello")\n')
1132             self.invokeBlack([str(workspace)], exit_code=123)
1133             cache = black.read_cache(mode)
1134             self.assertNotIn(failing, cache)
1135             self.assertIn(clean, cache)
1136
1137     def test_write_cache_write_fail(self) -> None:
1138         mode = DEFAULT_MODE
1139         with cache_dir(), patch.object(Path, "open") as mock:
1140             mock.side_effect = OSError
1141             black.write_cache({}, [], mode)
1142
1143     @event_loop()
1144     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
1145     def test_works_in_mono_process_only_environment(self) -> None:
1146         with cache_dir() as workspace:
1147             for f in [
1148                 (workspace / "one.py").resolve(),
1149                 (workspace / "two.py").resolve(),
1150             ]:
1151                 f.write_text('print("hello")\n')
1152             self.invokeBlack([str(workspace)])
1153
1154     @event_loop()
1155     def test_check_diff_use_together(self) -> None:
1156         with cache_dir():
1157             # Files which will be reformatted.
1158             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1159             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1160             # Files which will not be reformatted.
1161             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1162             self.invokeBlack([str(src2), "--diff", "--check"])
1163             # Multi file command.
1164             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1165
1166     def test_no_files(self) -> None:
1167         with cache_dir():
1168             # Without an argument, black exits with error code 0.
1169             self.invokeBlack([])
1170
1171     def test_broken_symlink(self) -> None:
1172         with cache_dir() as workspace:
1173             symlink = workspace / "broken_link.py"
1174             try:
1175                 symlink.symlink_to("nonexistent.py")
1176             except OSError as e:
1177                 self.skipTest(f"Can't create symlinks: {e}")
1178             self.invokeBlack([str(workspace.resolve())])
1179
1180     def test_read_cache_line_lengths(self) -> None:
1181         mode = DEFAULT_MODE
1182         short_mode = replace(DEFAULT_MODE, line_length=1)
1183         with cache_dir() as workspace:
1184             path = (workspace / "file.py").resolve()
1185             path.touch()
1186             black.write_cache({}, [path], mode)
1187             one = black.read_cache(mode)
1188             self.assertIn(path, one)
1189             two = black.read_cache(short_mode)
1190             self.assertNotIn(path, two)
1191
1192     def test_single_file_force_pyi(self) -> None:
1193         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1194         contents, expected = read_data("force_pyi")
1195         with cache_dir() as workspace:
1196             path = (workspace / "file.py").resolve()
1197             with open(path, "w") as fh:
1198                 fh.write(contents)
1199             self.invokeBlack([str(path), "--pyi"])
1200             with open(path, "r") as fh:
1201                 actual = fh.read()
1202             # verify cache with --pyi is separate
1203             pyi_cache = black.read_cache(pyi_mode)
1204             self.assertIn(path, pyi_cache)
1205             normal_cache = black.read_cache(DEFAULT_MODE)
1206             self.assertNotIn(path, normal_cache)
1207         self.assertFormatEqual(expected, actual)
1208         black.assert_equivalent(contents, actual)
1209         black.assert_stable(contents, actual, pyi_mode)
1210
1211     @event_loop()
1212     def test_multi_file_force_pyi(self) -> None:
1213         reg_mode = DEFAULT_MODE
1214         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1215         contents, expected = read_data("force_pyi")
1216         with cache_dir() as workspace:
1217             paths = [
1218                 (workspace / "file1.py").resolve(),
1219                 (workspace / "file2.py").resolve(),
1220             ]
1221             for path in paths:
1222                 with open(path, "w") as fh:
1223                     fh.write(contents)
1224             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1225             for path in paths:
1226                 with open(path, "r") as fh:
1227                     actual = fh.read()
1228                 self.assertEqual(actual, expected)
1229             # verify cache with --pyi is separate
1230             pyi_cache = black.read_cache(pyi_mode)
1231             normal_cache = black.read_cache(reg_mode)
1232             for path in paths:
1233                 self.assertIn(path, pyi_cache)
1234                 self.assertNotIn(path, normal_cache)
1235
1236     def test_pipe_force_pyi(self) -> None:
1237         source, expected = read_data("force_pyi")
1238         result = CliRunner().invoke(
1239             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1240         )
1241         self.assertEqual(result.exit_code, 0)
1242         actual = result.output
1243         self.assertFormatEqual(actual, expected)
1244
1245     def test_single_file_force_py36(self) -> None:
1246         reg_mode = DEFAULT_MODE
1247         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1248         source, expected = read_data("force_py36")
1249         with cache_dir() as workspace:
1250             path = (workspace / "file.py").resolve()
1251             with open(path, "w") as fh:
1252                 fh.write(source)
1253             self.invokeBlack([str(path), *PY36_ARGS])
1254             with open(path, "r") as fh:
1255                 actual = fh.read()
1256             # verify cache with --target-version is separate
1257             py36_cache = black.read_cache(py36_mode)
1258             self.assertIn(path, py36_cache)
1259             normal_cache = black.read_cache(reg_mode)
1260             self.assertNotIn(path, normal_cache)
1261         self.assertEqual(actual, expected)
1262
1263     @event_loop()
1264     def test_multi_file_force_py36(self) -> None:
1265         reg_mode = DEFAULT_MODE
1266         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1267         source, expected = read_data("force_py36")
1268         with cache_dir() as workspace:
1269             paths = [
1270                 (workspace / "file1.py").resolve(),
1271                 (workspace / "file2.py").resolve(),
1272             ]
1273             for path in paths:
1274                 with open(path, "w") as fh:
1275                     fh.write(source)
1276             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1277             for path in paths:
1278                 with open(path, "r") as fh:
1279                     actual = fh.read()
1280                 self.assertEqual(actual, expected)
1281             # verify cache with --target-version is separate
1282             pyi_cache = black.read_cache(py36_mode)
1283             normal_cache = black.read_cache(reg_mode)
1284             for path in paths:
1285                 self.assertIn(path, pyi_cache)
1286                 self.assertNotIn(path, normal_cache)
1287
1288     def test_pipe_force_py36(self) -> None:
1289         source, expected = read_data("force_py36")
1290         result = CliRunner().invoke(
1291             black.main,
1292             ["-", "-q", "--target-version=py36"],
1293             input=BytesIO(source.encode("utf8")),
1294         )
1295         self.assertEqual(result.exit_code, 0)
1296         actual = result.output
1297         self.assertFormatEqual(actual, expected)
1298
1299     def test_include_exclude(self) -> None:
1300         path = THIS_DIR / "data" / "include_exclude_tests"
1301         include = re.compile(r"\.pyi?$")
1302         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1303         report = black.Report()
1304         gitignore = PathSpec.from_lines("gitwildmatch", [])
1305         sources: List[Path] = []
1306         expected = [
1307             Path(path / "b/dont_exclude/a.py"),
1308             Path(path / "b/dont_exclude/a.pyi"),
1309         ]
1310         this_abs = THIS_DIR.resolve()
1311         sources.extend(
1312             black.gen_python_files(
1313                 path.iterdir(), this_abs, include, exclude, None, report, gitignore
1314             )
1315         )
1316         self.assertEqual(sorted(expected), sorted(sources))
1317
1318     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1319     def test_exclude_for_issue_1572(self) -> None:
1320         # Exclude shouldn't touch files that were explicitly given to Black through the
1321         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
1322         # https://github.com/psf/black/issues/1572
1323         path = THIS_DIR / "data" / "include_exclude_tests"
1324         include = ""
1325         exclude = r"/exclude/|a\.py"
1326         src = str(path / "b/exclude/a.py")
1327         report = black.Report()
1328         expected = [Path(path / "b/exclude/a.py")]
1329         sources = list(
1330             black.get_sources(
1331                 ctx=FakeContext(),
1332                 src=(src,),
1333                 quiet=True,
1334                 verbose=False,
1335                 include=include,
1336                 exclude=exclude,
1337                 force_exclude=None,
1338                 report=report,
1339             )
1340         )
1341         self.assertEqual(sorted(expected), sorted(sources))
1342
1343     def test_gitignore_exclude(self) -> None:
1344         path = THIS_DIR / "data" / "include_exclude_tests"
1345         include = re.compile(r"\.pyi?$")
1346         exclude = re.compile(r"")
1347         report = black.Report()
1348         gitignore = PathSpec.from_lines(
1349             "gitwildmatch", ["exclude/", ".definitely_exclude"]
1350         )
1351         sources: List[Path] = []
1352         expected = [
1353             Path(path / "b/dont_exclude/a.py"),
1354             Path(path / "b/dont_exclude/a.pyi"),
1355         ]
1356         this_abs = THIS_DIR.resolve()
1357         sources.extend(
1358             black.gen_python_files(
1359                 path.iterdir(), this_abs, include, exclude, None, report, gitignore
1360             )
1361         )
1362         self.assertEqual(sorted(expected), sorted(sources))
1363
1364     def test_empty_include(self) -> None:
1365         path = THIS_DIR / "data" / "include_exclude_tests"
1366         report = black.Report()
1367         gitignore = PathSpec.from_lines("gitwildmatch", [])
1368         empty = re.compile(r"")
1369         sources: List[Path] = []
1370         expected = [
1371             Path(path / "b/exclude/a.pie"),
1372             Path(path / "b/exclude/a.py"),
1373             Path(path / "b/exclude/a.pyi"),
1374             Path(path / "b/dont_exclude/a.pie"),
1375             Path(path / "b/dont_exclude/a.py"),
1376             Path(path / "b/dont_exclude/a.pyi"),
1377             Path(path / "b/.definitely_exclude/a.pie"),
1378             Path(path / "b/.definitely_exclude/a.py"),
1379             Path(path / "b/.definitely_exclude/a.pyi"),
1380         ]
1381         this_abs = THIS_DIR.resolve()
1382         sources.extend(
1383             black.gen_python_files(
1384                 path.iterdir(),
1385                 this_abs,
1386                 empty,
1387                 re.compile(black.DEFAULT_EXCLUDES),
1388                 None,
1389                 report,
1390                 gitignore,
1391             )
1392         )
1393         self.assertEqual(sorted(expected), sorted(sources))
1394
1395     def test_empty_exclude(self) -> None:
1396         path = THIS_DIR / "data" / "include_exclude_tests"
1397         report = black.Report()
1398         gitignore = PathSpec.from_lines("gitwildmatch", [])
1399         empty = re.compile(r"")
1400         sources: List[Path] = []
1401         expected = [
1402             Path(path / "b/dont_exclude/a.py"),
1403             Path(path / "b/dont_exclude/a.pyi"),
1404             Path(path / "b/exclude/a.py"),
1405             Path(path / "b/exclude/a.pyi"),
1406             Path(path / "b/.definitely_exclude/a.py"),
1407             Path(path / "b/.definitely_exclude/a.pyi"),
1408         ]
1409         this_abs = THIS_DIR.resolve()
1410         sources.extend(
1411             black.gen_python_files(
1412                 path.iterdir(),
1413                 this_abs,
1414                 re.compile(black.DEFAULT_INCLUDES),
1415                 empty,
1416                 None,
1417                 report,
1418                 gitignore,
1419             )
1420         )
1421         self.assertEqual(sorted(expected), sorted(sources))
1422
1423     def test_invalid_include_exclude(self) -> None:
1424         for option in ["--include", "--exclude"]:
1425             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1426
1427     def test_preserves_line_endings(self) -> None:
1428         with TemporaryDirectory() as workspace:
1429             test_file = Path(workspace) / "test.py"
1430             for nl in ["\n", "\r\n"]:
1431                 contents = nl.join(["def f(  ):", "    pass"])
1432                 test_file.write_bytes(contents.encode())
1433                 ff(test_file, write_back=black.WriteBack.YES)
1434                 updated_contents: bytes = test_file.read_bytes()
1435                 self.assertIn(nl.encode(), updated_contents)
1436                 if nl == "\n":
1437                     self.assertNotIn(b"\r\n", updated_contents)
1438
1439     def test_preserves_line_endings_via_stdin(self) -> None:
1440         for nl in ["\n", "\r\n"]:
1441             contents = nl.join(["def f(  ):", "    pass"])
1442             runner = BlackRunner()
1443             result = runner.invoke(
1444                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1445             )
1446             self.assertEqual(result.exit_code, 0)
1447             output = runner.stdout_bytes
1448             self.assertIn(nl.encode("utf8"), output)
1449             if nl == "\n":
1450                 self.assertNotIn(b"\r\n", output)
1451
1452     def test_assert_equivalent_different_asts(self) -> None:
1453         with self.assertRaises(AssertionError):
1454             black.assert_equivalent("{}", "None")
1455
1456     def test_symlink_out_of_root_directory(self) -> None:
1457         path = MagicMock()
1458         root = THIS_DIR.resolve()
1459         child = MagicMock()
1460         include = re.compile(black.DEFAULT_INCLUDES)
1461         exclude = re.compile(black.DEFAULT_EXCLUDES)
1462         report = black.Report()
1463         gitignore = PathSpec.from_lines("gitwildmatch", [])
1464         # `child` should behave like a symlink which resolved path is clearly
1465         # outside of the `root` directory.
1466         path.iterdir.return_value = [child]
1467         child.resolve.return_value = Path("/a/b/c")
1468         child.as_posix.return_value = "/a/b/c"
1469         child.is_symlink.return_value = True
1470         try:
1471             list(
1472                 black.gen_python_files(
1473                     path.iterdir(), root, include, exclude, None, report, gitignore
1474                 )
1475             )
1476         except ValueError as ve:
1477             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1478         path.iterdir.assert_called_once()
1479         child.resolve.assert_called_once()
1480         child.is_symlink.assert_called_once()
1481         # `child` should behave like a strange file which resolved path is clearly
1482         # outside of the `root` directory.
1483         child.is_symlink.return_value = False
1484         with self.assertRaises(ValueError):
1485             list(
1486                 black.gen_python_files(
1487                     path.iterdir(), root, include, exclude, None, report, gitignore
1488                 )
1489             )
1490         path.iterdir.assert_called()
1491         self.assertEqual(path.iterdir.call_count, 2)
1492         child.resolve.assert_called()
1493         self.assertEqual(child.resolve.call_count, 2)
1494         child.is_symlink.assert_called()
1495         self.assertEqual(child.is_symlink.call_count, 2)
1496
1497     def test_shhh_click(self) -> None:
1498         try:
1499             from click import _unicodefun  # type: ignore
1500         except ModuleNotFoundError:
1501             self.skipTest("Incompatible Click version")
1502         if not hasattr(_unicodefun, "_verify_python3_env"):
1503             self.skipTest("Incompatible Click version")
1504         # First, let's see if Click is crashing with a preferred ASCII charset.
1505         with patch("locale.getpreferredencoding") as gpe:
1506             gpe.return_value = "ASCII"
1507             with self.assertRaises(RuntimeError):
1508                 _unicodefun._verify_python3_env()
1509         # Now, let's silence Click...
1510         black.patch_click()
1511         # ...and confirm it's silent.
1512         with patch("locale.getpreferredencoding") as gpe:
1513             gpe.return_value = "ASCII"
1514             try:
1515                 _unicodefun._verify_python3_env()
1516             except RuntimeError as re:
1517                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1518
1519     def test_root_logger_not_used_directly(self) -> None:
1520         def fail(*args: Any, **kwargs: Any) -> None:
1521             self.fail("Record created with root logger")
1522
1523         with patch.multiple(
1524             logging.root,
1525             debug=fail,
1526             info=fail,
1527             warning=fail,
1528             error=fail,
1529             critical=fail,
1530             log=fail,
1531         ):
1532             ff(THIS_FILE)
1533
1534     def test_invalid_config_return_code(self) -> None:
1535         tmp_file = Path(black.dump_to_file())
1536         try:
1537             tmp_config = Path(black.dump_to_file())
1538             tmp_config.unlink()
1539             args = ["--config", str(tmp_config), str(tmp_file)]
1540             self.invokeBlack(args, exit_code=2, ignore_config=False)
1541         finally:
1542             tmp_file.unlink()
1543
1544     def test_parse_pyproject_toml(self) -> None:
1545         test_toml_file = THIS_DIR / "test.toml"
1546         config = black.parse_pyproject_toml(str(test_toml_file))
1547         self.assertEqual(config["verbose"], 1)
1548         self.assertEqual(config["check"], "no")
1549         self.assertEqual(config["diff"], "y")
1550         self.assertEqual(config["color"], True)
1551         self.assertEqual(config["line_length"], 79)
1552         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1553         self.assertEqual(config["exclude"], r"\.pyi?$")
1554         self.assertEqual(config["include"], r"\.py?$")
1555
1556     def test_read_pyproject_toml(self) -> None:
1557         test_toml_file = THIS_DIR / "test.toml"
1558         fake_ctx = FakeContext()
1559         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1560         config = fake_ctx.default_map
1561         self.assertEqual(config["verbose"], "1")
1562         self.assertEqual(config["check"], "no")
1563         self.assertEqual(config["diff"], "y")
1564         self.assertEqual(config["color"], "True")
1565         self.assertEqual(config["line_length"], "79")
1566         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1567         self.assertEqual(config["exclude"], r"\.pyi?$")
1568         self.assertEqual(config["include"], r"\.py?$")
1569
1570     def test_find_project_root(self) -> None:
1571         with TemporaryDirectory() as workspace:
1572             root = Path(workspace)
1573             test_dir = root / "test"
1574             test_dir.mkdir()
1575
1576             src_dir = root / "src"
1577             src_dir.mkdir()
1578
1579             root_pyproject = root / "pyproject.toml"
1580             root_pyproject.touch()
1581             src_pyproject = src_dir / "pyproject.toml"
1582             src_pyproject.touch()
1583             src_python = src_dir / "foo.py"
1584             src_python.touch()
1585
1586             self.assertEqual(
1587                 black.find_project_root((src_dir, test_dir)), root.resolve()
1588             )
1589             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
1590             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
1591
1592     def test_bpo_33660_workaround(self) -> None:
1593         if system() == "Windows":
1594             return
1595
1596         # https://bugs.python.org/issue33660
1597
1598         old_cwd = Path.cwd()
1599         try:
1600             root = Path("/")
1601             os.chdir(str(root))
1602             path = Path("workspace") / "project"
1603             report = black.Report(verbose=True)
1604             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
1605             self.assertEqual(normalized_path, "workspace/project")
1606         finally:
1607             os.chdir(str(old_cwd))
1608
1609
1610 with open(black.__file__, "r", encoding="utf-8") as _bf:
1611     black_source_lines = _bf.readlines()
1612
1613
1614 def tracefunc(frame: types.FrameType, event: str, arg: Any) -> Callable:
1615     """Show function calls `from black/__init__.py` as they happen.
1616
1617     Register this with `sys.settrace()` in a test you're debugging.
1618     """
1619     if event != "call":
1620         return tracefunc
1621
1622     stack = len(inspect.stack()) - 19
1623     stack *= 2
1624     filename = frame.f_code.co_filename
1625     lineno = frame.f_lineno
1626     func_sig_lineno = lineno - 1
1627     funcname = black_source_lines[func_sig_lineno].strip()
1628     while funcname.startswith("@"):
1629         func_sig_lineno += 1
1630         funcname = black_source_lines[func_sig_lineno].strip()
1631     if "black/__init__.py" in filename:
1632         print(f"{' ' * stack}{lineno}:{funcname}")
1633     return tracefunc
1634
1635
1636 if __name__ == "__main__":
1637     unittest.main(module="test_black")