git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

2d0a7dfd4e233ebecc906115c637e15dfcc145ba
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import inspect
5 import io
6 import logging
7 import multiprocessing
8 import os
9 import sys
10 import types
11 import unittest
12 from concurrent.futures import ThreadPoolExecutor
13 from contextlib import contextmanager
14 from dataclasses import replace
15 from io import BytesIO
16 from pathlib import Path
17 from platform import system
18 from tempfile import TemporaryDirectory
19 from typing import (
20     Any,
21     Callable,
22     Dict,
23     Iterator,
24     List,
25     Optional,
26     Sequence,
27     TypeVar,
28     Union,
29 )
30 from unittest.mock import MagicMock, patch
31
32 import click
33 import pytest
34 import re
35 from click import unstyle
36 from click.testing import CliRunner
37 from pathspec import PathSpec
38
39 import black
40 import black.files
41 from black import Feature, TargetVersion
42 from black import re_compile_maybe_verbose as compile_pattern
43 from black.cache import get_cache_file
44 from black.debug import DebugVisitor
45 from black.output import color_diff, diff
46 from black.report import Report
47
48 # Import other test classes
49 from tests.util import (
50     DATA_DIR,
51     DEFAULT_MODE,
52     DETERMINISTIC_HEADER,
53     PROJECT_ROOT,
54     PY36_VERSIONS,
55     THIS_DIR,
56     BlackBaseTestCase,
57     assert_format,
58     change_directory,
59     dump_to_stderr,
60     ff,
61     fs,
62     read_data,
63 )
64
# Location of this test module.
THIS_FILE = Path(__file__)
# One ``--target-version=<name>`` CLI flag per entry of PY36_VERSIONS.
PY36_ARGS = ["--target-version=" + version.name.lower() for version in PY36_VERSIONS]
# Black's default exclude / include expressions, compiled once for the tests.
DEFAULT_EXCLUDE = compile_pattern(black.const.DEFAULT_EXCLUDES)
DEFAULT_INCLUDE = compile_pattern(black.const.DEFAULT_INCLUDES)
T = TypeVar("T")
R = TypeVar("R")

# Match the time output in a diff, but nothing else
DIFF_TIME = re.compile(r"\t[\d\-:+\. ]+")
74
75
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch ``black.cache.CACHE_DIR`` to a path inside a temporary workspace.

    With ``exists=True`` the patched path is the temporary directory itself;
    with ``exists=False`` it is a ``new`` child of that directory, which this
    helper does not create. The patched path is yielded, and both the patch
    and the workspace are undone when the block exits.
    """
    with TemporaryDirectory() as workspace:
        target = Path(workspace) if exists else Path(workspace) / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
84
85
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a fresh asyncio event loop for the duration of the block.

    A new loop is created through the current event loop policy and set as the
    current loop. Nothing is yielded; the loop is closed on exit, even if the
    block raises.
    """
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        fresh_loop.close()
96
97
class FakeContext(click.Context):
    """Stand-in ``click.Context`` for calling functions that require one.

    ``click.Context.__init__`` is not called; the only attribute set up is an
    empty ``default_map``.
    """

    def __init__(self) -> None:
        self.default_map: Dict[str, Any] = dict()
103
104
class FakeParameter(click.Parameter):
    """Stand-in ``click.Parameter`` for calling functions that require one."""

    def __init__(self) -> None:
        """Create the fake without calling ``click.Parameter.__init__``."""
110
111
class BlackRunner(CliRunner):
    """A ``CliRunner`` that keeps STDOUT and STDERR separate for Black's CLI tests."""

    def __init__(self) -> None:
        # mix_stderr=False is what keeps the two streams apart in the result.
        CliRunner.__init__(self, mix_stderr=False)
117
118
def invokeBlack(
    args: List[str], exit_code: int = 0, ignore_config: bool = True
) -> None:
    """Invoke ``black.main`` through ``BlackRunner`` and assert on its exit code.

    When ``ignore_config`` is true, ``--verbose --config <THIS_DIR>/empty.toml``
    is prepended to ``args`` (the caller's list is not mutated). The assertion
    message carries the final argument list, the decoded stdout and stderr, and
    the exception recorded on the result.
    """
    cli_args = args
    if ignore_config:
        cli_args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
    result = BlackRunner().invoke(black.main, cli_args, catch_exceptions=False)
    assert result.stdout_bytes is not None
    assert result.stderr_bytes is not None
    stdout_text = result.stdout_bytes.decode()
    stderr_text = result.stderr_bytes.decode()
    msg = (
        f"Failed with args: {cli_args}\n"
        f"stdout: {stdout_text!r}\n"
        f"stderr: {stderr_text!r}\n"
        f"exception: {result.exception}"
    )
    assert result.exit_code == exit_code, msg
135
136
137 class BlackTestCase(BlackBaseTestCase):
138     invokeBlack = staticmethod(invokeBlack)
139
140     def test_empty_ff(self) -> None:
141         expected = ""
142         tmp_file = Path(black.dump_to_file())
143         try:
144             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
145             with open(tmp_file, encoding="utf8") as f:
146                 actual = f.read()
147         finally:
148             os.unlink(tmp_file)
149         self.assertFormatEqual(expected, actual)
150
151     def test_piping(self) -> None:
152         source, expected = read_data("src/black/__init__", data=False)
153         result = BlackRunner().invoke(
154             black.main,
155             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
156             input=BytesIO(source.encode("utf8")),
157         )
158         self.assertEqual(result.exit_code, 0)
159         self.assertFormatEqual(expected, result.output)
160         if source != result.output:
161             black.assert_equivalent(source, result.output)
162             black.assert_stable(source, result.output, DEFAULT_MODE)
163
164     def test_piping_diff(self) -> None:
165         diff_header = re.compile(
166             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
167             r"\+\d\d\d\d"
168         )
169         source, _ = read_data("expression.py")
170         expected, _ = read_data("expression.diff")
171         config = THIS_DIR / "data" / "empty_pyproject.toml"
172         args = [
173             "-",
174             "--fast",
175             f"--line-length={black.DEFAULT_LINE_LENGTH}",
176             "--diff",
177             f"--config={config}",
178         ]
179         result = BlackRunner().invoke(
180             black.main, args, input=BytesIO(source.encode("utf8"))
181         )
182         self.assertEqual(result.exit_code, 0)
183         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
184         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
185         self.assertEqual(expected, actual)
186
187     def test_piping_diff_with_color(self) -> None:
188         source, _ = read_data("expression.py")
189         config = THIS_DIR / "data" / "empty_pyproject.toml"
190         args = [
191             "-",
192             "--fast",
193             f"--line-length={black.DEFAULT_LINE_LENGTH}",
194             "--diff",
195             "--color",
196             f"--config={config}",
197         ]
198         result = BlackRunner().invoke(
199             black.main, args, input=BytesIO(source.encode("utf8"))
200         )
201         actual = result.output
202         # Again, the contents are checked in a different test, so only look for colors.
203         self.assertIn("\033[1;37m", actual)
204         self.assertIn("\033[36m", actual)
205         self.assertIn("\033[32m", actual)
206         self.assertIn("\033[31m", actual)
207         self.assertIn("\033[0m", actual)
208
209     @patch("black.dump_to_file", dump_to_stderr)
210     def _test_wip(self) -> None:
211         source, expected = read_data("wip")
212         sys.settrace(tracefunc)
213         mode = replace(
214             DEFAULT_MODE,
215             experimental_string_processing=False,
216             target_versions={black.TargetVersion.PY38},
217         )
218         actual = fs(source, mode=mode)
219         sys.settrace(None)
220         self.assertFormatEqual(expected, actual)
221         black.assert_equivalent(source, actual)
222         black.assert_stable(source, actual, black.FileMode())
223
224     @unittest.expectedFailure
225     @patch("black.dump_to_file", dump_to_stderr)
226     def test_trailing_comma_optional_parens_stability1(self) -> None:
227         source, _expected = read_data("trailing_comma_optional_parens1")
228         actual = fs(source)
229         black.assert_stable(source, actual, DEFAULT_MODE)
230
231     @unittest.expectedFailure
232     @patch("black.dump_to_file", dump_to_stderr)
233     def test_trailing_comma_optional_parens_stability2(self) -> None:
234         source, _expected = read_data("trailing_comma_optional_parens2")
235         actual = fs(source)
236         black.assert_stable(source, actual, DEFAULT_MODE)
237
238     @unittest.expectedFailure
239     @patch("black.dump_to_file", dump_to_stderr)
240     def test_trailing_comma_optional_parens_stability3(self) -> None:
241         source, _expected = read_data("trailing_comma_optional_parens3")
242         actual = fs(source)
243         black.assert_stable(source, actual, DEFAULT_MODE)
244
245     @patch("black.dump_to_file", dump_to_stderr)
246     def test_trailing_comma_optional_parens_stability1_pass2(self) -> None:
247         source, _expected = read_data("trailing_comma_optional_parens1")
248         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
249         black.assert_stable(source, actual, DEFAULT_MODE)
250
251     @patch("black.dump_to_file", dump_to_stderr)
252     def test_trailing_comma_optional_parens_stability2_pass2(self) -> None:
253         source, _expected = read_data("trailing_comma_optional_parens2")
254         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
255         black.assert_stable(source, actual, DEFAULT_MODE)
256
257     @patch("black.dump_to_file", dump_to_stderr)
258     def test_trailing_comma_optional_parens_stability3_pass2(self) -> None:
259         source, _expected = read_data("trailing_comma_optional_parens3")
260         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
261         black.assert_stable(source, actual, DEFAULT_MODE)
262
263     def test_pep_572_version_detection(self) -> None:
264         source, _ = read_data("pep_572")
265         root = black.lib2to3_parse(source)
266         features = black.get_features_used(root)
267         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
268         versions = black.detect_target_versions(root)
269         self.assertIn(black.TargetVersion.PY38, versions)
270
271     def test_expression_ff(self) -> None:
272         source, expected = read_data("expression")
273         tmp_file = Path(black.dump_to_file(source))
274         try:
275             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
276             with open(tmp_file, encoding="utf8") as f:
277                 actual = f.read()
278         finally:
279             os.unlink(tmp_file)
280         self.assertFormatEqual(expected, actual)
281         with patch("black.dump_to_file", dump_to_stderr):
282             black.assert_equivalent(source, actual)
283             black.assert_stable(source, actual, DEFAULT_MODE)
284
285     def test_expression_diff(self) -> None:
286         source, _ = read_data("expression.py")
287         config = THIS_DIR / "data" / "empty_pyproject.toml"
288         expected, _ = read_data("expression.diff")
289         tmp_file = Path(black.dump_to_file(source))
290         diff_header = re.compile(
291             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
292             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
293         )
294         try:
295             result = BlackRunner().invoke(
296                 black.main, ["--diff", str(tmp_file), f"--config={config}"]
297             )
298             self.assertEqual(result.exit_code, 0)
299         finally:
300             os.unlink(tmp_file)
301         actual = result.output
302         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
303         if expected != actual:
304             dump = black.dump_to_file(actual)
305             msg = (
306                 "Expected diff isn't equal to the actual. If you made changes to"
307                 " expression.py and this is an anticipated difference, overwrite"
308                 f" tests/data/expression.diff with {dump}"
309             )
310             self.assertEqual(expected, actual, msg)
311
312     def test_expression_diff_with_color(self) -> None:
313         source, _ = read_data("expression.py")
314         config = THIS_DIR / "data" / "empty_pyproject.toml"
315         expected, _ = read_data("expression.diff")
316         tmp_file = Path(black.dump_to_file(source))
317         try:
318             result = BlackRunner().invoke(
319                 black.main, ["--diff", "--color", str(tmp_file), f"--config={config}"]
320             )
321         finally:
322             os.unlink(tmp_file)
323         actual = result.output
324         # We check the contents of the diff in `test_expression_diff`. All
325         # we need to check here is that color codes exist in the result.
326         self.assertIn("\033[1;37m", actual)
327         self.assertIn("\033[36m", actual)
328         self.assertIn("\033[32m", actual)
329         self.assertIn("\033[31m", actual)
330         self.assertIn("\033[0m", actual)
331
332     def test_detect_pos_only_arguments(self) -> None:
333         source, _ = read_data("pep_570")
334         root = black.lib2to3_parse(source)
335         features = black.get_features_used(root)
336         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
337         versions = black.detect_target_versions(root)
338         self.assertIn(black.TargetVersion.PY38, versions)
339
340     @patch("black.dump_to_file", dump_to_stderr)
341     def test_string_quotes(self) -> None:
342         source, expected = read_data("string_quotes")
343         mode = black.Mode(experimental_string_processing=True)
344         assert_format(source, expected, mode)
345         mode = replace(mode, string_normalization=False)
346         not_normalized = fs(source, mode=mode)
347         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
348         black.assert_equivalent(source, not_normalized)
349         black.assert_stable(source, not_normalized, mode=mode)
350
351     def test_skip_magic_trailing_comma(self) -> None:
352         source, _ = read_data("expression.py")
353         expected, _ = read_data("expression_skip_magic_trailing_comma.diff")
354         tmp_file = Path(black.dump_to_file(source))
355         diff_header = re.compile(
356             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
357             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
358         )
359         try:
360             result = BlackRunner().invoke(black.main, ["-C", "--diff", str(tmp_file)])
361             self.assertEqual(result.exit_code, 0)
362         finally:
363             os.unlink(tmp_file)
364         actual = result.output
365         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
366         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
367         if expected != actual:
368             dump = black.dump_to_file(actual)
369             msg = (
370                 "Expected diff isn't equal to the actual. If you made changes to"
371                 " expression.py and this is an anticipated difference, overwrite"
372                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
373             )
374             self.assertEqual(expected, actual, msg)
375
376     @patch("black.dump_to_file", dump_to_stderr)
377     def test_async_as_identifier(self) -> None:
378         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
379         source, expected = read_data("async_as_identifier")
380         actual = fs(source)
381         self.assertFormatEqual(expected, actual)
382         major, minor = sys.version_info[:2]
383         if major < 3 or (major <= 3 and minor < 7):
384             black.assert_equivalent(source, actual)
385         black.assert_stable(source, actual, DEFAULT_MODE)
386         # ensure black can parse this when the target is 3.6
387         self.invokeBlack([str(source_path), "--target-version", "py36"])
388         # but not on 3.7, because async/await is no longer an identifier
389         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
390
391     @patch("black.dump_to_file", dump_to_stderr)
392     def test_python37(self) -> None:
393         source_path = (THIS_DIR / "data" / "python37.py").resolve()
394         source, expected = read_data("python37")
395         actual = fs(source)
396         self.assertFormatEqual(expected, actual)
397         major, minor = sys.version_info[:2]
398         if major > 3 or (major == 3 and minor >= 7):
399             black.assert_equivalent(source, actual)
400         black.assert_stable(source, actual, DEFAULT_MODE)
401         # ensure black can parse this when the target is 3.7
402         self.invokeBlack([str(source_path), "--target-version", "py37"])
403         # but not on 3.6, because we use async as a reserved keyword
404         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
405
406     def test_tab_comment_indentation(self) -> None:
407         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
408         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
409         self.assertFormatEqual(contents_spc, fs(contents_spc))
410         self.assertFormatEqual(contents_spc, fs(contents_tab))
411
412         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
413         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
414         self.assertFormatEqual(contents_spc, fs(contents_spc))
415         self.assertFormatEqual(contents_spc, fs(contents_tab))
416
417         # mixed tabs and spaces (valid Python 2 code)
418         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
419         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
420         self.assertFormatEqual(contents_spc, fs(contents_spc))
421         self.assertFormatEqual(contents_spc, fs(contents_tab))
422
423         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
424         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
425         self.assertFormatEqual(contents_spc, fs(contents_spc))
426         self.assertFormatEqual(contents_spc, fs(contents_tab))
427
    def test_report_verbose(self) -> None:
        """Walk a ``Report(verbose=True)`` through done/failed/ignored events.

        ``black.output._out`` and ``black.output._err`` are patched with local
        collectors. After every event the test checks how many messages each
        collector holds, the text of the latest one, the summary produced by
        ``str(report)`` (with styling stripped) and ``report.return_code``.
        """
        report = Report(verbose=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            # Replacement for black.output._out: record the message only.
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            # Replacement for black.output._err: record the message only.
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            # Unchanged file: one message on out.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # Reformatted file: a second message on out.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # Cached file: reported on out and counted as left unchanged.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check switched on the return code becomes 1; switch it back off.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # First failure: goes to err, and the return code becomes 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # f3 reported again, this time as reformatted.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Second failure.
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored path: a message on out, but the summary is unchanged.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Another unchanged file.
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # check=True and diff=True each give the "would be" summary wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
529
    def test_report_quiet(self) -> None:
        """Walk a ``Report(quiet=True)`` through done/failed/ignored events.

        ``black.output._out`` and ``black.output._err`` are patched with local
        collectors. Throughout the sequence nothing is expected on the out
        collector; only failures add messages, on the err collector. The
        summary from ``str(report)`` and ``report.return_code`` are checked
        after every event.
        """
        report = Report(quiet=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            # Replacement for black.output._out: record the message only.
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            # Replacement for black.output._err: record the message only.
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            # Unchanged file: no message, but still counted in the summary.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # Reformatted file: no message.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # Cached file: no message, counted as left unchanged.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check switched on the return code becomes 1; switch it back off.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # First failure: goes to err, and the return code becomes 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # f3 reported again, this time as reformatted.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Second failure.
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored path: no message and no change to the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Another unchanged file.
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # check=True and diff=True each give the "would be" summary wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
623
624     def test_report_normal(self) -> None:
625         report = black.Report()
626         out_lines = []
627         err_lines = []
628
629         def out(msg: str, **kwargs: Any) -> None:
630             out_lines.append(msg)
631
632         def err(msg: str, **kwargs: Any) -> None:
633             err_lines.append(msg)
634
635         with patch("black.output._out", out), patch("black.output._err", err):
636             report.done(Path("f1"), black.Changed.NO)
637             self.assertEqual(len(out_lines), 0)
638             self.assertEqual(len(err_lines), 0)
639             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
640             self.assertEqual(report.return_code, 0)
641             report.done(Path("f2"), black.Changed.YES)
642             self.assertEqual(len(out_lines), 1)
643             self.assertEqual(len(err_lines), 0)
644             self.assertEqual(out_lines[-1], "reformatted f2")
645             self.assertEqual(
646                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
647             )
648             report.done(Path("f3"), black.Changed.CACHED)
649             self.assertEqual(len(out_lines), 1)
650             self.assertEqual(len(err_lines), 0)
651             self.assertEqual(out_lines[-1], "reformatted f2")
652             self.assertEqual(
653                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
654             )
655             self.assertEqual(report.return_code, 0)
656             report.check = True
657             self.assertEqual(report.return_code, 1)
658             report.check = False
659             report.failed(Path("e1"), "boom")
660             self.assertEqual(len(out_lines), 1)
661             self.assertEqual(len(err_lines), 1)
662             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
663             self.assertEqual(
664                 unstyle(str(report)),
665                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
666                 " reformat.",
667             )
668             self.assertEqual(report.return_code, 123)
669             report.done(Path("f3"), black.Changed.YES)
670             self.assertEqual(len(out_lines), 2)
671             self.assertEqual(len(err_lines), 1)
672             self.assertEqual(out_lines[-1], "reformatted f3")
673             self.assertEqual(
674                 unstyle(str(report)),
675                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
676                 " reformat.",
677             )
678             self.assertEqual(report.return_code, 123)
679             report.failed(Path("e2"), "boom")
680             self.assertEqual(len(out_lines), 2)
681             self.assertEqual(len(err_lines), 2)
682             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
683             self.assertEqual(
684                 unstyle(str(report)),
685                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
686                 " reformat.",
687             )
688             self.assertEqual(report.return_code, 123)
689             report.path_ignored(Path("wat"), "no match")
690             self.assertEqual(len(out_lines), 2)
691             self.assertEqual(len(err_lines), 2)
692             self.assertEqual(
693                 unstyle(str(report)),
694                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
695                 " reformat.",
696             )
697             self.assertEqual(report.return_code, 123)
698             report.done(Path("f4"), black.Changed.NO)
699             self.assertEqual(len(out_lines), 2)
700             self.assertEqual(len(err_lines), 2)
701             self.assertEqual(
702                 unstyle(str(report)),
703                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
704                 " reformat.",
705             )
706             self.assertEqual(report.return_code, 123)
707             report.check = True
708             self.assertEqual(
709                 unstyle(str(report)),
710                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
711                 " would fail to reformat.",
712             )
713             report.check = False
714             report.diff = True
715             self.assertEqual(
716                 unstyle(str(report)),
717                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
718                 " would fail to reformat.",
719             )
720
721     def test_lib2to3_parse(self) -> None:
722         with self.assertRaises(black.InvalidInput):
723             black.lib2to3_parse("invalid syntax")
724
725         straddling = "x + y"
726         black.lib2to3_parse(straddling)
727         black.lib2to3_parse(straddling, {TargetVersion.PY27})
728         black.lib2to3_parse(straddling, {TargetVersion.PY36})
729         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
730
731         py2_only = "print x"
732         black.lib2to3_parse(py2_only)
733         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
734         with self.assertRaises(black.InvalidInput):
735             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
736         with self.assertRaises(black.InvalidInput):
737             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
738
739         py3_only = "exec(x, end=y)"
740         black.lib2to3_parse(py3_only)
741         with self.assertRaises(black.InvalidInput):
742             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
743         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
744         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
745
746     def test_get_features_used_decorator(self) -> None:
747         # Test the feature detection of new decorator syntax
748         # since this makes some test cases of test_get_features_used()
749         # fails if it fails, this is tested first so that a useful case
750         # is identified
751         simples, relaxed = read_data("decorators")
752         # skip explanation comments at the top of the file
753         for simple_test in simples.split("##")[1:]:
754             node = black.lib2to3_parse(simple_test)
755             decorator = str(node.children[0].children[0]).strip()
756             self.assertNotIn(
757                 Feature.RELAXED_DECORATORS,
758                 black.get_features_used(node),
759                 msg=(
760                     f"decorator '{decorator}' follows python<=3.8 syntax"
761                     "but is detected as 3.9+"
762                     # f"The full node is\n{node!r}"
763                 ),
764             )
765         # skip the '# output' comment at the top of the output part
766         for relaxed_test in relaxed.split("##")[1:]:
767             node = black.lib2to3_parse(relaxed_test)
768             decorator = str(node.children[0].children[0]).strip()
769             self.assertIn(
770                 Feature.RELAXED_DECORATORS,
771                 black.get_features_used(node),
772                 msg=(
773                     f"decorator '{decorator}' uses python3.9+ syntax"
774                     "but is detected as python<=3.8"
775                     # f"The full node is\n{node!r}"
776                 ),
777             )
778
779     def test_get_features_used(self) -> None:
780         node = black.lib2to3_parse("def f(*, arg): ...\n")
781         self.assertEqual(black.get_features_used(node), set())
782         node = black.lib2to3_parse("def f(*, arg,): ...\n")
783         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
784         node = black.lib2to3_parse("f(*arg,)\n")
785         self.assertEqual(
786             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
787         )
788         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
789         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
790         node = black.lib2to3_parse("123_456\n")
791         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
792         node = black.lib2to3_parse("123456\n")
793         self.assertEqual(black.get_features_used(node), set())
794         source, expected = read_data("function")
795         node = black.lib2to3_parse(source)
796         expected_features = {
797             Feature.TRAILING_COMMA_IN_CALL,
798             Feature.TRAILING_COMMA_IN_DEF,
799             Feature.F_STRINGS,
800         }
801         self.assertEqual(black.get_features_used(node), expected_features)
802         node = black.lib2to3_parse(expected)
803         self.assertEqual(black.get_features_used(node), expected_features)
804         source, expected = read_data("expression")
805         node = black.lib2to3_parse(source)
806         self.assertEqual(black.get_features_used(node), set())
807         node = black.lib2to3_parse(expected)
808         self.assertEqual(black.get_features_used(node), set())
809         node = black.lib2to3_parse("lambda a, /, b: ...")
810         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
811         node = black.lib2to3_parse("def fn(a, /, b): ...")
812         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
813
814     def test_get_features_used_for_future_flags(self) -> None:
815         for src, features in [
816             ("from __future__ import annotations", {Feature.FUTURE_ANNOTATIONS}),
817             (
818                 "from __future__ import (other, annotations)",
819                 {Feature.FUTURE_ANNOTATIONS},
820             ),
821             ("a = 1 + 2\nfrom something import annotations", set()),
822             ("from __future__ import x, y", set()),
823         ]:
824             with self.subTest(src=src, features=features):
825                 node = black.lib2to3_parse(src)
826                 future_imports = black.get_future_imports(node)
827                 self.assertEqual(
828                     black.get_features_used(node, future_imports=future_imports),
829                     features,
830                 )
831
832     def test_get_future_imports(self) -> None:
833         node = black.lib2to3_parse("\n")
834         self.assertEqual(set(), black.get_future_imports(node))
835         node = black.lib2to3_parse("from __future__ import black\n")
836         self.assertEqual({"black"}, black.get_future_imports(node))
837         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
838         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
839         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
840         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
841         node = black.lib2to3_parse(
842             "from __future__ import multiple\nfrom __future__ import imports\n"
843         )
844         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
845         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
846         self.assertEqual({"black"}, black.get_future_imports(node))
847         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
848         self.assertEqual({"black"}, black.get_future_imports(node))
849         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
850         self.assertEqual(set(), black.get_future_imports(node))
851         node = black.lib2to3_parse("from some.module import black\n")
852         self.assertEqual(set(), black.get_future_imports(node))
853         node = black.lib2to3_parse(
854             "from __future__ import unicode_literals as _unicode_literals"
855         )
856         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
857         node = black.lib2to3_parse(
858             "from __future__ import unicode_literals as _lol, print"
859         )
860         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
861
862     @pytest.mark.incompatible_with_mypyc
863     def test_debug_visitor(self) -> None:
864         source, _ = read_data("debug_visitor.py")
865         expected, _ = read_data("debug_visitor.out")
866         out_lines = []
867         err_lines = []
868
869         def out(msg: str, **kwargs: Any) -> None:
870             out_lines.append(msg)
871
872         def err(msg: str, **kwargs: Any) -> None:
873             err_lines.append(msg)
874
875         with patch("black.debug.out", out):
876             DebugVisitor.show(source)
877         actual = "\n".join(out_lines) + "\n"
878         log_name = ""
879         if expected != actual:
880             log_name = black.dump_to_file(*out_lines)
881         self.assertEqual(
882             expected,
883             actual,
884             f"AST print out is different. Actual version dumped to {log_name}",
885         )
886
887     def test_format_file_contents(self) -> None:
888         empty = ""
889         mode = DEFAULT_MODE
890         with self.assertRaises(black.NothingChanged):
891             black.format_file_contents(empty, mode=mode, fast=False)
892         just_nl = "\n"
893         with self.assertRaises(black.NothingChanged):
894             black.format_file_contents(just_nl, mode=mode, fast=False)
895         same = "j = [1, 2, 3]\n"
896         with self.assertRaises(black.NothingChanged):
897             black.format_file_contents(same, mode=mode, fast=False)
898         different = "j = [1,2,3]"
899         expected = same
900         actual = black.format_file_contents(different, mode=mode, fast=False)
901         self.assertEqual(expected, actual)
902         invalid = "return if you can"
903         with self.assertRaises(black.InvalidInput) as e:
904             black.format_file_contents(invalid, mode=mode, fast=False)
905         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
906
907     def test_endmarker(self) -> None:
908         n = black.lib2to3_parse("\n")
909         self.assertEqual(n.type, black.syms.file_input)
910         self.assertEqual(len(n.children), 1)
911         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
912
913     @pytest.mark.incompatible_with_mypyc
914     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
915     def test_assertFormatEqual(self) -> None:
916         out_lines = []
917         err_lines = []
918
919         def out(msg: str, **kwargs: Any) -> None:
920             out_lines.append(msg)
921
922         def err(msg: str, **kwargs: Any) -> None:
923             err_lines.append(msg)
924
925         with patch("black.output._out", out), patch("black.output._err", err):
926             with self.assertRaises(AssertionError):
927                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
928
929         out_str = "".join(out_lines)
930         self.assertTrue("Expected tree:" in out_str)
931         self.assertTrue("Actual tree:" in out_str)
932         self.assertEqual("".join(err_lines), "")
933
934     @event_loop()
935     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
936     def test_works_in_mono_process_only_environment(self) -> None:
937         with cache_dir() as workspace:
938             for f in [
939                 (workspace / "one.py").resolve(),
940                 (workspace / "two.py").resolve(),
941             ]:
942                 f.write_text('print("hello")\n')
943             self.invokeBlack([str(workspace)])
944
945     @event_loop()
946     def test_check_diff_use_together(self) -> None:
947         with cache_dir():
948             # Files which will be reformatted.
949             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
950             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
951             # Files which will not be reformatted.
952             src2 = (THIS_DIR / "data" / "composition.py").resolve()
953             self.invokeBlack([str(src2), "--diff", "--check"])
954             # Multi file command.
955             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
956
957     def test_no_files(self) -> None:
958         with cache_dir():
959             # Without an argument, black exits with error code 0.
960             self.invokeBlack([])
961
962     def test_broken_symlink(self) -> None:
963         with cache_dir() as workspace:
964             symlink = workspace / "broken_link.py"
965             try:
966                 symlink.symlink_to("nonexistent.py")
967             except (OSError, NotImplementedError) as e:
968                 self.skipTest(f"Can't create symlinks: {e}")
969             self.invokeBlack([str(workspace.resolve())])
970
971     def test_single_file_force_pyi(self) -> None:
972         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
973         contents, expected = read_data("force_pyi")
974         with cache_dir() as workspace:
975             path = (workspace / "file.py").resolve()
976             with open(path, "w") as fh:
977                 fh.write(contents)
978             self.invokeBlack([str(path), "--pyi"])
979             with open(path, "r") as fh:
980                 actual = fh.read()
981             # verify cache with --pyi is separate
982             pyi_cache = black.read_cache(pyi_mode)
983             self.assertIn(str(path), pyi_cache)
984             normal_cache = black.read_cache(DEFAULT_MODE)
985             self.assertNotIn(str(path), normal_cache)
986         self.assertFormatEqual(expected, actual)
987         black.assert_equivalent(contents, actual)
988         black.assert_stable(contents, actual, pyi_mode)
989
990     @event_loop()
991     def test_multi_file_force_pyi(self) -> None:
992         reg_mode = DEFAULT_MODE
993         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
994         contents, expected = read_data("force_pyi")
995         with cache_dir() as workspace:
996             paths = [
997                 (workspace / "file1.py").resolve(),
998                 (workspace / "file2.py").resolve(),
999             ]
1000             for path in paths:
1001                 with open(path, "w") as fh:
1002                     fh.write(contents)
1003             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1004             for path in paths:
1005                 with open(path, "r") as fh:
1006                     actual = fh.read()
1007                 self.assertEqual(actual, expected)
1008             # verify cache with --pyi is separate
1009             pyi_cache = black.read_cache(pyi_mode)
1010             normal_cache = black.read_cache(reg_mode)
1011             for path in paths:
1012                 self.assertIn(str(path), pyi_cache)
1013                 self.assertNotIn(str(path), normal_cache)
1014
1015     def test_pipe_force_pyi(self) -> None:
1016         source, expected = read_data("force_pyi")
1017         result = CliRunner().invoke(
1018             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1019         )
1020         self.assertEqual(result.exit_code, 0)
1021         actual = result.output
1022         self.assertFormatEqual(actual, expected)
1023
1024     def test_single_file_force_py36(self) -> None:
1025         reg_mode = DEFAULT_MODE
1026         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1027         source, expected = read_data("force_py36")
1028         with cache_dir() as workspace:
1029             path = (workspace / "file.py").resolve()
1030             with open(path, "w") as fh:
1031                 fh.write(source)
1032             self.invokeBlack([str(path), *PY36_ARGS])
1033             with open(path, "r") as fh:
1034                 actual = fh.read()
1035             # verify cache with --target-version is separate
1036             py36_cache = black.read_cache(py36_mode)
1037             self.assertIn(str(path), py36_cache)
1038             normal_cache = black.read_cache(reg_mode)
1039             self.assertNotIn(str(path), normal_cache)
1040         self.assertEqual(actual, expected)
1041
1042     @event_loop()
1043     def test_multi_file_force_py36(self) -> None:
1044         reg_mode = DEFAULT_MODE
1045         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1046         source, expected = read_data("force_py36")
1047         with cache_dir() as workspace:
1048             paths = [
1049                 (workspace / "file1.py").resolve(),
1050                 (workspace / "file2.py").resolve(),
1051             ]
1052             for path in paths:
1053                 with open(path, "w") as fh:
1054                     fh.write(source)
1055             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1056             for path in paths:
1057                 with open(path, "r") as fh:
1058                     actual = fh.read()
1059                 self.assertEqual(actual, expected)
1060             # verify cache with --target-version is separate
1061             pyi_cache = black.read_cache(py36_mode)
1062             normal_cache = black.read_cache(reg_mode)
1063             for path in paths:
1064                 self.assertIn(str(path), pyi_cache)
1065                 self.assertNotIn(str(path), normal_cache)
1066
1067     def test_pipe_force_py36(self) -> None:
1068         source, expected = read_data("force_py36")
1069         result = CliRunner().invoke(
1070             black.main,
1071             ["-", "-q", "--target-version=py36"],
1072             input=BytesIO(source.encode("utf8")),
1073         )
1074         self.assertEqual(result.exit_code, 0)
1075         actual = result.output
1076         self.assertFormatEqual(actual, expected)
1077
1078     @pytest.mark.incompatible_with_mypyc
1079     def test_reformat_one_with_stdin(self) -> None:
1080         with patch(
1081             "black.format_stdin_to_stdout",
1082             return_value=lambda *args, **kwargs: black.Changed.YES,
1083         ) as fsts:
1084             report = MagicMock()
1085             path = Path("-")
1086             black.reformat_one(
1087                 path,
1088                 fast=True,
1089                 write_back=black.WriteBack.YES,
1090                 mode=DEFAULT_MODE,
1091                 report=report,
1092             )
1093             fsts.assert_called_once()
1094             report.done.assert_called_with(path, black.Changed.YES)
1095
1096     @pytest.mark.incompatible_with_mypyc
1097     def test_reformat_one_with_stdin_filename(self) -> None:
1098         with patch(
1099             "black.format_stdin_to_stdout",
1100             return_value=lambda *args, **kwargs: black.Changed.YES,
1101         ) as fsts:
1102             report = MagicMock()
1103             p = "foo.py"
1104             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1105             expected = Path(p)
1106             black.reformat_one(
1107                 path,
1108                 fast=True,
1109                 write_back=black.WriteBack.YES,
1110                 mode=DEFAULT_MODE,
1111                 report=report,
1112             )
1113             fsts.assert_called_once_with(
1114                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1115             )
1116             # __BLACK_STDIN_FILENAME__ should have been stripped
1117             report.done.assert_called_with(expected, black.Changed.YES)
1118
1119     @pytest.mark.incompatible_with_mypyc
1120     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1121         with patch(
1122             "black.format_stdin_to_stdout",
1123             return_value=lambda *args, **kwargs: black.Changed.YES,
1124         ) as fsts:
1125             report = MagicMock()
1126             p = "foo.pyi"
1127             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1128             expected = Path(p)
1129             black.reformat_one(
1130                 path,
1131                 fast=True,
1132                 write_back=black.WriteBack.YES,
1133                 mode=DEFAULT_MODE,
1134                 report=report,
1135             )
1136             fsts.assert_called_once_with(
1137                 fast=True,
1138                 write_back=black.WriteBack.YES,
1139                 mode=replace(DEFAULT_MODE, is_pyi=True),
1140             )
1141             # __BLACK_STDIN_FILENAME__ should have been stripped
1142             report.done.assert_called_with(expected, black.Changed.YES)
1143
1144     @pytest.mark.incompatible_with_mypyc
1145     def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
1146         with patch(
1147             "black.format_stdin_to_stdout",
1148             return_value=lambda *args, **kwargs: black.Changed.YES,
1149         ) as fsts:
1150             report = MagicMock()
1151             p = "foo.ipynb"
1152             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1153             expected = Path(p)
1154             black.reformat_one(
1155                 path,
1156                 fast=True,
1157                 write_back=black.WriteBack.YES,
1158                 mode=DEFAULT_MODE,
1159                 report=report,
1160             )
1161             fsts.assert_called_once_with(
1162                 fast=True,
1163                 write_back=black.WriteBack.YES,
1164                 mode=replace(DEFAULT_MODE, is_ipynb=True),
1165             )
1166             # __BLACK_STDIN_FILENAME__ should have been stripped
1167             report.done.assert_called_with(expected, black.Changed.YES)
1168
1169     @pytest.mark.incompatible_with_mypyc
1170     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1171         with patch(
1172             "black.format_stdin_to_stdout",
1173             return_value=lambda *args, **kwargs: black.Changed.YES,
1174         ) as fsts:
1175             report = MagicMock()
1176             # Even with an existing file, since we are forcing stdin, black
1177             # should output to stdout and not modify the file inplace
1178             p = Path(str(THIS_DIR / "data/collections.py"))
1179             # Make sure is_file actually returns True
1180             self.assertTrue(p.is_file())
1181             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1182             expected = Path(p)
1183             black.reformat_one(
1184                 path,
1185                 fast=True,
1186                 write_back=black.WriteBack.YES,
1187                 mode=DEFAULT_MODE,
1188                 report=report,
1189             )
1190             fsts.assert_called_once()
1191             # __BLACK_STDIN_FILENAME__ should have been stripped
1192             report.done.assert_called_with(expected, black.Changed.YES)
1193
1194     def test_reformat_one_with_stdin_empty(self) -> None:
1195         output = io.StringIO()
1196         with patch("io.TextIOWrapper", lambda *args, **kwargs: output):
1197             try:
1198                 black.format_stdin_to_stdout(
1199                     fast=True,
1200                     content="",
1201                     write_back=black.WriteBack.YES,
1202                     mode=DEFAULT_MODE,
1203                 )
1204             except io.UnsupportedOperation:
1205                 pass  # StringIO does not support detach
1206             assert output.getvalue() == ""
1207
1208     def test_invalid_cli_regex(self) -> None:
1209         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1210             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1211
1212     def test_required_version_matches_version(self) -> None:
1213         self.invokeBlack(
1214             ["--required-version", black.__version__], exit_code=0, ignore_config=True
1215         )
1216
1217     def test_required_version_does_not_match_version(self) -> None:
1218         self.invokeBlack(
1219             ["--required-version", "20.99b"], exit_code=1, ignore_config=True
1220         )
1221
1222     def test_preserves_line_endings(self) -> None:
1223         with TemporaryDirectory() as workspace:
1224             test_file = Path(workspace) / "test.py"
1225             for nl in ["\n", "\r\n"]:
1226                 contents = nl.join(["def f(  ):", "    pass"])
1227                 test_file.write_bytes(contents.encode())
1228                 ff(test_file, write_back=black.WriteBack.YES)
1229                 updated_contents: bytes = test_file.read_bytes()
1230                 self.assertIn(nl.encode(), updated_contents)
1231                 if nl == "\n":
1232                     self.assertNotIn(b"\r\n", updated_contents)
1233
1234     def test_preserves_line_endings_via_stdin(self) -> None:
1235         for nl in ["\n", "\r\n"]:
1236             contents = nl.join(["def f(  ):", "    pass"])
1237             runner = BlackRunner()
1238             result = runner.invoke(
1239                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1240             )
1241             self.assertEqual(result.exit_code, 0)
1242             output = result.stdout_bytes
1243             self.assertIn(nl.encode("utf8"), output)
1244             if nl == "\n":
1245                 self.assertNotIn(b"\r\n", output)
1246
1247     def test_assert_equivalent_different_asts(self) -> None:
1248         with self.assertRaises(AssertionError):
1249             black.assert_equivalent("{}", "None")
1250
1251     def test_shhh_click(self) -> None:
1252         try:
1253             from click import _unicodefun
1254         except ModuleNotFoundError:
1255             self.skipTest("Incompatible Click version")
1256         if not hasattr(_unicodefun, "_verify_python3_env"):
1257             self.skipTest("Incompatible Click version")
1258         # First, let's see if Click is crashing with a preferred ASCII charset.
1259         with patch("locale.getpreferredencoding") as gpe:
1260             gpe.return_value = "ASCII"
1261             with self.assertRaises(RuntimeError):
1262                 _unicodefun._verify_python3_env()  # type: ignore
1263         # Now, let's silence Click...
1264         black.patch_click()
1265         # ...and confirm it's silent.
1266         with patch("locale.getpreferredencoding") as gpe:
1267             gpe.return_value = "ASCII"
1268             try:
1269                 _unicodefun._verify_python3_env()  # type: ignore
1270             except RuntimeError as re:
1271                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1272
1273     def test_root_logger_not_used_directly(self) -> None:
1274         def fail(*args: Any, **kwargs: Any) -> None:
1275             self.fail("Record created with root logger")
1276
1277         with patch.multiple(
1278             logging.root,
1279             debug=fail,
1280             info=fail,
1281             warning=fail,
1282             error=fail,
1283             critical=fail,
1284             log=fail,
1285         ):
1286             ff(THIS_DIR / "util.py")
1287
1288     def test_invalid_config_return_code(self) -> None:
1289         tmp_file = Path(black.dump_to_file())
1290         try:
1291             tmp_config = Path(black.dump_to_file())
1292             tmp_config.unlink()
1293             args = ["--config", str(tmp_config), str(tmp_file)]
1294             self.invokeBlack(args, exit_code=2, ignore_config=False)
1295         finally:
1296             tmp_file.unlink()
1297
1298     def test_parse_pyproject_toml(self) -> None:
1299         test_toml_file = THIS_DIR / "test.toml"
1300         config = black.parse_pyproject_toml(str(test_toml_file))
1301         self.assertEqual(config["verbose"], 1)
1302         self.assertEqual(config["check"], "no")
1303         self.assertEqual(config["diff"], "y")
1304         self.assertEqual(config["color"], True)
1305         self.assertEqual(config["line_length"], 79)
1306         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1307         self.assertEqual(config["exclude"], r"\.pyi?$")
1308         self.assertEqual(config["include"], r"\.py?$")
1309
1310     def test_read_pyproject_toml(self) -> None:
1311         test_toml_file = THIS_DIR / "test.toml"
1312         fake_ctx = FakeContext()
1313         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1314         config = fake_ctx.default_map
1315         self.assertEqual(config["verbose"], "1")
1316         self.assertEqual(config["check"], "no")
1317         self.assertEqual(config["diff"], "y")
1318         self.assertEqual(config["color"], "True")
1319         self.assertEqual(config["line_length"], "79")
1320         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1321         self.assertEqual(config["exclude"], r"\.pyi?$")
1322         self.assertEqual(config["include"], r"\.py?$")
1323
1324     @pytest.mark.incompatible_with_mypyc
1325     def test_find_project_root(self) -> None:
1326         with TemporaryDirectory() as workspace:
1327             root = Path(workspace)
1328             test_dir = root / "test"
1329             test_dir.mkdir()
1330
1331             src_dir = root / "src"
1332             src_dir.mkdir()
1333
1334             root_pyproject = root / "pyproject.toml"
1335             root_pyproject.touch()
1336             src_pyproject = src_dir / "pyproject.toml"
1337             src_pyproject.touch()
1338             src_python = src_dir / "foo.py"
1339             src_python.touch()
1340
1341             self.assertEqual(
1342                 black.find_project_root((src_dir, test_dir)), root.resolve()
1343             )
1344             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
1345             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
1346
1347     @patch(
1348         "black.files.find_user_pyproject_toml",
1349         black.files.find_user_pyproject_toml.__wrapped__,
1350     )
1351     def test_find_user_pyproject_toml_linux(self) -> None:
1352         if system() == "Windows":
1353             return
1354
1355         # Test if XDG_CONFIG_HOME is checked
1356         with TemporaryDirectory() as workspace:
1357             tmp_user_config = Path(workspace) / "black"
1358             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1359                 self.assertEqual(
1360                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
1361                 )
1362
1363         # Test fallback for XDG_CONFIG_HOME
1364         with patch.dict("os.environ"):
1365             os.environ.pop("XDG_CONFIG_HOME", None)
1366             fallback_user_config = Path("~/.config").expanduser() / "black"
1367             self.assertEqual(
1368                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
1369             )
1370
1371     def test_find_user_pyproject_toml_windows(self) -> None:
1372         if system() != "Windows":
1373             return
1374
1375         user_config_path = Path.home() / ".black"
1376         self.assertEqual(
1377             black.files.find_user_pyproject_toml(), user_config_path.resolve()
1378         )
1379
1380     def test_bpo_33660_workaround(self) -> None:
1381         if system() == "Windows":
1382             return
1383
1384         # https://bugs.python.org/issue33660
1385         root = Path("/")
1386         with change_directory(root):
1387             path = Path("workspace") / "project"
1388             report = black.Report(verbose=True)
1389             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
1390             self.assertEqual(normalized_path, "workspace/project")
1391
1392     def test_newline_comment_interaction(self) -> None:
1393         source = "class A:\\\r\n# type: ignore\n pass\n"
1394         output = black.format_str(source, mode=DEFAULT_MODE)
1395         black.assert_stable(source, output, mode=DEFAULT_MODE)
1396
1397     def test_bpo_2142_workaround(self) -> None:
1398
1399         # https://bugs.python.org/issue2142
1400
1401         source, _ = read_data("missing_final_newline.py")
1402         # read_data adds a trailing newline
1403         source = source.rstrip()
1404         expected, _ = read_data("missing_final_newline.diff")
1405         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
1406         diff_header = re.compile(
1407             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
1408             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
1409         )
1410         try:
1411             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
1412             self.assertEqual(result.exit_code, 0)
1413         finally:
1414             os.unlink(tmp_file)
1415         actual = result.output
1416         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
1417         self.assertEqual(actual, expected)
1418
1419     @pytest.mark.python2
1420     def test_docstring_reformat_for_py27(self) -> None:
1421         """
1422         Check that stripping trailing whitespace from Python 2 docstrings
1423         doesn't trigger a "not equivalent to source" error
1424         """
1425         source = (
1426             b'def foo():\r\n    """Testing\r\n    Testing """\r\n    print "Foo"\r\n'
1427         )
1428         expected = 'def foo():\n    """Testing\n    Testing"""\n    print "Foo"\n'
1429
1430         result = BlackRunner().invoke(
1431             black.main,
1432             ["-", "-q", "--target-version=py27"],
1433             input=BytesIO(source),
1434         )
1435
1436         self.assertEqual(result.exit_code, 0)
1437         actual = result.stdout
1438         self.assertFormatEqual(actual, expected)
1439
1440     @staticmethod
1441     def compare_results(
1442         result: click.testing.Result, expected_value: str, expected_exit_code: int
1443     ) -> None:
1444         """Helper method to test the value and exit code of a click Result."""
1445         assert (
1446             result.output == expected_value
1447         ), "The output did not match the expected value."
1448         assert result.exit_code == expected_exit_code, "The exit code is incorrect."
1449
1450     def test_code_option(self) -> None:
1451         """Test the code option with no changes."""
1452         code = 'print("Hello world")\n'
1453         args = ["--code", code]
1454         result = CliRunner().invoke(black.main, args)
1455
1456         self.compare_results(result, code, 0)
1457
1458     def test_code_option_changed(self) -> None:
1459         """Test the code option when changes are required."""
1460         code = "print('hello world')"
1461         formatted = black.format_str(code, mode=DEFAULT_MODE)
1462
1463         args = ["--code", code]
1464         result = CliRunner().invoke(black.main, args)
1465
1466         self.compare_results(result, formatted, 0)
1467
1468     def test_code_option_check(self) -> None:
1469         """Test the code option when check is passed."""
1470         args = ["--check", "--code", 'print("Hello world")\n']
1471         result = CliRunner().invoke(black.main, args)
1472         self.compare_results(result, "", 0)
1473
1474     def test_code_option_check_changed(self) -> None:
1475         """Test the code option when changes are required, and check is passed."""
1476         args = ["--check", "--code", "print('hello world')"]
1477         result = CliRunner().invoke(black.main, args)
1478         self.compare_results(result, "", 1)
1479
1480     def test_code_option_diff(self) -> None:
1481         """Test the code option when diff is passed."""
1482         code = "print('hello world')"
1483         formatted = black.format_str(code, mode=DEFAULT_MODE)
1484         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1485
1486         args = ["--diff", "--code", code]
1487         result = CliRunner().invoke(black.main, args)
1488
1489         # Remove time from diff
1490         output = DIFF_TIME.sub("", result.output)
1491
1492         assert output == result_diff, "The output did not match the expected value."
1493         assert result.exit_code == 0, "The exit code is incorrect."
1494
1495     def test_code_option_color_diff(self) -> None:
1496         """Test the code option when color and diff are passed."""
1497         code = "print('hello world')"
1498         formatted = black.format_str(code, mode=DEFAULT_MODE)
1499
1500         result_diff = diff(code, formatted, "STDIN", "STDOUT")
1501         result_diff = color_diff(result_diff)
1502
1503         args = ["--diff", "--color", "--code", code]
1504         result = CliRunner().invoke(black.main, args)
1505
1506         # Remove time from diff
1507         output = DIFF_TIME.sub("", result.output)
1508
1509         assert output == result_diff, "The output did not match the expected value."
1510         assert result.exit_code == 0, "The exit code is incorrect."
1511
1512     @pytest.mark.incompatible_with_mypyc
1513     def test_code_option_safe(self) -> None:
1514         """Test that the code option throws an error when the sanity checks fail."""
1515         # Patch black.assert_equivalent to ensure the sanity checks fail
1516         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1517             code = 'print("Hello world")'
1518             error_msg = f"{code}\nerror: cannot format <string>: \n"
1519
1520             args = ["--safe", "--code", code]
1521             result = CliRunner().invoke(black.main, args)
1522
1523             self.compare_results(result, error_msg, 123)
1524
1525     def test_code_option_fast(self) -> None:
1526         """Test that the code option ignores errors when the sanity checks fail."""
1527         # Patch black.assert_equivalent to ensure the sanity checks fail
1528         with patch.object(black, "assert_equivalent", side_effect=AssertionError):
1529             code = 'print("Hello world")'
1530             formatted = black.format_str(code, mode=DEFAULT_MODE)
1531
1532             args = ["--fast", "--code", code]
1533             result = CliRunner().invoke(black.main, args)
1534
1535             self.compare_results(result, formatted, 0)
1536
1537     @pytest.mark.incompatible_with_mypyc
1538     def test_code_option_config(self) -> None:
1539         """
1540         Test that the code option finds the pyproject.toml in the current directory.
1541         """
1542         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1543             args = ["--code", "print"]
1544             # This is the only directory known to contain a pyproject.toml
1545             with change_directory(PROJECT_ROOT):
1546                 CliRunner().invoke(black.main, args)
1547                 pyproject_path = Path(Path.cwd(), "pyproject.toml").resolve()
1548
1549             assert (
1550                 len(parse.mock_calls) >= 1
1551             ), "Expected config parse to be called with the current directory."
1552
1553             _, call_args, _ = parse.mock_calls[0]
1554             assert (
1555                 call_args[0].lower() == str(pyproject_path).lower()
1556             ), "Incorrect config loaded."
1557
1558     @pytest.mark.incompatible_with_mypyc
1559     def test_code_option_parent_config(self) -> None:
1560         """
1561         Test that the code option finds the pyproject.toml in the parent directory.
1562         """
1563         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
1564             with change_directory(THIS_DIR):
1565                 args = ["--code", "print"]
1566                 CliRunner().invoke(black.main, args)
1567
1568                 pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
1569                 assert (
1570                     len(parse.mock_calls) >= 1
1571                 ), "Expected config parse to be called with the current directory."
1572
1573                 _, call_args, _ = parse.mock_calls[0]
1574                 assert (
1575                     call_args[0].lower() == str(pyproject_path).lower()
1576                 ), "Incorrect config loaded."
1577
1578     def test_for_handled_unexpected_eof_error(self) -> None:
1579         """
1580         Test that an unexpected EOF SyntaxError is nicely presented.
1581         """
1582         with pytest.raises(black.parsing.InvalidInput) as exc_info:
1583             black.lib2to3_parse("print(", {})
1584
1585         exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")
1586
1587
class TestCaching:
    """Tests for reading, writing and honouring Black's file cache."""

    def test_cache_broken_file(self) -> None:
        """A cache file that is not a pickle reads as empty; a run refills it."""
        with cache_dir() as workspace:
            get_cache_file(DEFAULT_MODE).write_text("this is not a pickle")
            assert black.read_cache(DEFAULT_MODE) == {}
            target = (workspace / "test.py").resolve()
            target.write_text("print('hello')")
            invokeBlack([str(target)])
            assert str(target) in black.read_cache(DEFAULT_MODE)

    def test_cache_single_file_already_cached(self) -> None:
        """A file already recorded in the cache is left untouched."""
        unformatted = "print('hello')"
        with cache_dir() as workspace:
            target = (workspace / "test.py").resolve()
            target.write_text(unformatted)
            black.write_cache({}, [target], DEFAULT_MODE)
            invokeBlack([str(target)])
            assert target.read_text() == unformatted

    @event_loop()
    def test_cache_multiple_files(self) -> None:
        """Only the uncached file is reformatted; both files end up cached."""
        with cache_dir() as workspace, patch(
            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            cached = (workspace / "one.py").resolve()
            cached.write_text("print('hello')")
            uncached = (workspace / "two.py").resolve()
            uncached.write_text("print('hello')")
            black.write_cache({}, [cached], DEFAULT_MODE)
            invokeBlack([str(workspace)])
            assert cached.read_text() == "print('hello')"
            assert uncached.read_text() == 'print("hello")\n'
            cache = black.read_cache(DEFAULT_MODE)
            assert str(cached) in cache
            assert str(uncached) in cache

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    def test_no_cache_when_writeback_diff(self, color: bool) -> None:
        """With --diff the cache is neither read nor written."""
        with cache_dir() as workspace:
            target = (workspace / "test.py").resolve()
            target.write_text("print('hello')")
            cmd = [str(target), "--diff"]
            if color:
                cmd.append("--color")
            with patch("black.read_cache") as read_cache, patch(
                "black.write_cache"
            ) as write_cache:
                invokeBlack(cmd)
                assert get_cache_file(DEFAULT_MODE).exists() is False
                write_cache.assert_not_called()
                read_cache.assert_not_called()

    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
    @event_loop()
    def test_output_locking_when_writeback_diff(self, color: bool) -> None:
        """Running --diff over a directory of files calls ``black.Manager``."""
        with cache_dir() as workspace:
            for tag in range(4):
                (workspace / f"test{tag}.py").resolve().write_text("print('hello')")
            cmd = ["--diff", str(workspace)]
            if color:
                cmd.append("--color")
            with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
                invokeBlack(cmd, exit_code=0)
                # this isn't quite doing what we want, but if it _isn't_
                # called then we cannot be using the lock it provides
                mgr.assert_called()

    def test_no_cache_when_stdin(self) -> None:
        """Formatting source read from stdin does not create a cache file."""
        with cache_dir():
            stdin = BytesIO(b"print('hello')")
            result = CliRunner().invoke(black.main, ["-"], input=stdin)
            assert not result.exit_code
            assert not get_cache_file(DEFAULT_MODE).exists()

    def test_read_cache_no_cachefile(self) -> None:
        """Reading the cache when no cache file exists gives an empty dict."""
        with cache_dir():
            assert black.read_cache(DEFAULT_MODE) == {}

    def test_write_cache_read_cache(self) -> None:
        """A written entry is read back together with the file's cache info."""
        with cache_dir() as workspace:
            target = (workspace / "test.py").resolve()
            target.touch()
            black.write_cache({}, [target], DEFAULT_MODE)
            cache = black.read_cache(DEFAULT_MODE)
            key = str(target)
            assert key in cache
            assert cache[key] == black.get_cache_info(target)

    def test_filter_cached(self) -> None:
        """filter_cached separates paths still to do from those already done."""
        with TemporaryDirectory() as workspace:
            base = Path(workspace)
            uncached, cached, cached_but_changed = (
                (base / name).resolve() for name in ("uncached", "cached", "changed")
            )
            for file in (uncached, cached, cached_but_changed):
                file.touch()
            cache = {
                str(cached): black.get_cache_info(cached),
                # Deliberately stale info, so this entry counts as changed.
                str(cached_but_changed): (0.0, 0),
            }
            todo, done = black.filter_cached(
                cache, {uncached, cached, cached_but_changed}
            )
            assert todo == {uncached, cached_but_changed}
            assert done == {cached}

    def test_write_cache_creates_directory_if_needed(self) -> None:
        """write_cache creates the cache directory when it is missing."""
        with cache_dir(exists=False) as workspace:
            assert not workspace.exists()
            black.write_cache({}, [], DEFAULT_MODE)
            assert workspace.exists()

    @event_loop()
    def test_failed_formatting_does_not_get_cached(self) -> None:
        """A file Black fails on (exit code 123) is kept out of the cache."""
        with cache_dir() as workspace, patch(
            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
        ):
            failing = (workspace / "failing.py").resolve()
            failing.write_text("not actually python")
            clean = (workspace / "clean.py").resolve()
            clean.write_text('print("hello")\n')
            invokeBlack([str(workspace)], exit_code=123)
            cache = black.read_cache(DEFAULT_MODE)
            assert str(failing) not in cache
            assert str(clean) in cache

    def test_write_cache_write_fail(self) -> None:
        """write_cache does not raise when ``Path.open`` fails with OSError."""
        with cache_dir(), patch.object(Path, "open", side_effect=OSError):
            black.write_cache({}, [], DEFAULT_MODE)

    def test_read_cache_line_lengths(self) -> None:
        """A cache written for one line length is not visible for another."""
        short_mode = replace(DEFAULT_MODE, line_length=1)
        with cache_dir() as workspace:
            target = (workspace / "file.py").resolve()
            target.touch()
            black.write_cache({}, [target], DEFAULT_MODE)
            assert str(target) in black.read_cache(DEFAULT_MODE)
            assert str(target) not in black.read_cache(short_mode)
1753
1754
def assert_collected_sources(
    src: Sequence[Union[str, Path]],
    expected: Sequence[Union[str, Path]],
    *,
    exclude: Optional[str] = None,
    include: Optional[str] = None,
    extend_exclude: Optional[str] = None,
    force_exclude: Optional[str] = None,
    stdin_filename: Optional[str] = None,
) -> None:
    """Assert that ``black.get_sources`` collects exactly *expected* for *src*.

    Pattern arguments are given as strings and compiled with
    ``compile_pattern``; a pattern left as None is passed on as None, except
    *include*, which falls back to ``DEFAULT_INCLUDE``. Both sides are
    compared as sorted ``Path`` sequences, so order does not matter.
    """

    def compiled(pattern: Optional[str]) -> Optional["re.Pattern[str]"]:
        # None means "option not given" and must be passed through as None.
        return None if pattern is None else compile_pattern(pattern)

    collected = black.get_sources(
        ctx=FakeContext(),
        src=tuple(str(Path(item)) for item in src),
        quiet=False,
        verbose=False,
        include=DEFAULT_INCLUDE if include is None else compile_pattern(include),
        exclude=compiled(exclude),
        extend_exclude=compiled(extend_exclude),
        force_exclude=compiled(force_exclude),
        report=black.Report(),
        stdin_filename=stdin_filename,
    )
    assert sorted(collected) == sorted(Path(item) for item in expected)
1786
1787
class TestFileCollection:
    """Tests for source discovery: include/exclude patterns, gitignore, stdin."""

    @staticmethod
    def _gen_python_files(
        paths: Iterator[Path],
        root: Path,
        include: "re.Pattern[str]",
        exclude: "re.Pattern[str]",
        report: black.Report,
        gitignore: PathSpec,
    ) -> List[Path]:
        """Return everything ``black.gen_python_files`` yields for *paths*.

        Shared by the tests that drive ``gen_python_files`` directly. The fifth
        and sixth positional arguments are always None (presumably
        extend_exclude and force_exclude -- confirm against
        ``black.gen_python_files``); ``verbose`` and ``quiet`` are both False.
        """
        return list(
            black.gen_python_files(
                paths,
                root,
                include,
                exclude,
                None,
                None,
                report,
                gitignore,
                verbose=False,
                quiet=False,
            )
        )

    def test_include_exclude(self) -> None:
        """Only files matching *include* and not matching *exclude* are collected."""
        path = THIS_DIR / "data" / "include_exclude_tests"
        src = [path]
        expected = [
            Path(path / "b/dont_exclude/a.py"),
            Path(path / "b/dont_exclude/a.pyi"),
        ]
        assert_collected_sources(
            src,
            expected,
            include=r"\.pyi?$",
            exclude=r"/exclude/|/\.definitely_exclude/",
        )

    def test_gitignore_used_as_default(self) -> None:
        """Collect from ``b/`` with only ``extend_exclude`` given (no ``exclude``)."""
        base = Path(DATA_DIR / "include_exclude_tests")
        expected = [
            base / "b/.definitely_exclude/a.py",
            base / "b/.definitely_exclude/a.pyi",
        ]
        src = [base / "b/"]
        assert_collected_sources(src, expected, extend_exclude=r"/exclude/")

    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
    def test_exclude_for_issue_1572(self) -> None:
        # Exclude shouldn't touch files that were explicitly given to Black through the
        # CLI. Exclude is supposed to only apply to the recursive discovery of files.
        # https://github.com/psf/black/issues/1572
        path = DATA_DIR / "include_exclude_tests"
        src = [path / "b/exclude/a.py"]
        expected = [path / "b/exclude/a.py"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    def test_gitignore_exclude(self) -> None:
        """Paths matched by the given gitignore spec are not yielded."""
        path = THIS_DIR / "data" / "include_exclude_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        report = black.Report()
        gitignore = PathSpec.from_lines(
            "gitwildmatch", ["exclude/", ".definitely_exclude"]
        )
        expected = [
            Path(path / "b/dont_exclude/a.py"),
            Path(path / "b/dont_exclude/a.pyi"),
        ]
        this_abs = THIS_DIR.resolve()
        sources = self._gen_python_files(
            path.iterdir(), this_abs, include, exclude, report, gitignore
        )
        assert sorted(expected) == sorted(sources)

    def test_nested_gitignore(self) -> None:
        """Collect from a tree with nested gitignore data, seeded with the root spec."""
        path = Path(THIS_DIR / "data" / "nested_gitignore_tests")
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        root_gitignore = black.files.get_gitignore(path)
        report = black.Report()
        expected: List[Path] = [
            Path(path / "x.py"),
            Path(path / "root/b.py"),
            Path(path / "root/c.py"),
            Path(path / "root/child/c.py"),
        ]
        this_abs = THIS_DIR.resolve()
        sources = self._gen_python_files(
            path.iterdir(), this_abs, include, exclude, report, root_gitignore
        )
        assert sorted(expected) == sorted(sources)

    def test_invalid_gitignore(self) -> None:
        """An unparsable .gitignore gives exit code 1 and is named on stderr."""
        path = THIS_DIR / "data" / "invalid_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_invalid_nested_gitignore(self) -> None:
        """An unparsable nested .gitignore gives exit code 1 and is named on stderr."""
        path = THIS_DIR / "data" / "invalid_nested_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / "a" / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_empty_include(self) -> None:
        """With empty include and exclude patterns every file is collected."""
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            Path(path / "b/exclude/a.pie"),
            Path(path / "b/exclude/a.py"),
            Path(path / "b/exclude/a.pyi"),
            Path(path / "b/dont_exclude/a.pie"),
            Path(path / "b/dont_exclude/a.py"),
            Path(path / "b/dont_exclude/a.pyi"),
            Path(path / "b/.definitely_exclude/a.pie"),
            Path(path / "b/.definitely_exclude/a.py"),
            Path(path / "b/.definitely_exclude/a.pyi"),
            Path(path / ".gitignore"),
            Path(path / "pyproject.toml"),
        ]
        # Setting exclude explicitly to an empty string to block .gitignore usage.
        assert_collected_sources(src, expected, include="", exclude="")

    def test_extend_exclude(self) -> None:
        """Files matching either ``exclude`` or ``extend_exclude`` are left out."""
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            Path(path / "b/exclude/a.py"),
            Path(path / "b/dont_exclude/a.py"),
        ]
        assert_collected_sources(
            src, expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude"
        )

    @pytest.mark.incompatible_with_mypyc
    def test_symlink_out_of_root_directory(self) -> None:
        """A symlink resolving outside *root* is tolerated; a regular path is not.

        The mocked child resolves to ``/a/b/c``. As a symlink it must not make
        ``gen_python_files`` raise ValueError; as a non-symlink it must.
        """
        path = MagicMock()
        root = THIS_DIR.resolve()
        child = MagicMock()
        include = re.compile(black.DEFAULT_INCLUDES)
        exclude = re.compile(black.DEFAULT_EXCLUDES)
        report = black.Report()
        gitignore = PathSpec.from_lines("gitwildmatch", [])
        # `child` should behave like a symlink which resolved path is clearly
        # outside of the `root` directory.
        path.iterdir.return_value = [child]
        child.resolve.return_value = Path("/a/b/c")
        child.as_posix.return_value = "/a/b/c"
        child.is_symlink.return_value = True
        try:
            self._gen_python_files(
                path.iterdir(), root, include, exclude, report, gitignore
            )
        except ValueError as ve:
            # Name the function that is actually under test in the failure.
            pytest.fail(f"`gen_python_files()` failed: {ve}")
        path.iterdir.assert_called_once()
        child.resolve.assert_called_once()
        child.is_symlink.assert_called_once()
        # `child` should behave like a strange file which resolved path is clearly
        # outside of the `root` directory.
        child.is_symlink.return_value = False
        with pytest.raises(ValueError):
            self._gen_python_files(
                path.iterdir(), root, include, exclude, report, gitignore
            )
        # Every mock has now been used once per gen_python_files run.
        path.iterdir.assert_called()
        assert path.iterdir.call_count == 2
        child.resolve.assert_called()
        assert child.resolve.call_count == 2
        child.is_symlink.assert_called()
        assert child.is_symlink.call_count == 2

    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
    def test_get_sources_with_stdin(self) -> None:
        """A bare ``-`` source is collected as ``-``."""
        src = ["-"]
        expected = ["-"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
    def test_get_sources_with_stdin_filename(self) -> None:
        """With a stdin filename, ``-`` is collected as a marker-prefixed path."""
        src = ["-"]
        stdin_filename = str(THIS_DIR / "data/collections.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
    def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
        # Exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        path = DATA_DIR / "include_exclude_tests"
        src = ["-"]
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
    def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
        # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        src = ["-"]
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            extend_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
    def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
        # Force exclude should exclude the file when passing it through
        # stdin_filename
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        assert_collected_sources(
            src=["-"],
            expected=[],
            force_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )
2057
2058
@pytest.mark.python2
@pytest.mark.parametrize("explicit", [True, False], ids=["explicit", "autodetection"])
def test_python_2_deprecation_with_target_version(explicit: bool) -> None:
    """Checking a Python 2 file prints the deprecation notice on stderr.

    Runs once with ``--target-version=py27`` passed explicitly and once
    without it.
    """
    cli_args = [
        "--config",
        str(THIS_DIR / "empty.toml"),
        str(DATA_DIR / "python2.py"),
        "--check",
    ]
    if explicit:
        cli_args += ["--target-version=py27"]
    with cache_dir():
        outcome = BlackRunner().invoke(black.main, cli_args)
    assert "DEPRECATION: Python 2 support will be removed" in outcome.stderr
2073
2074
@pytest.mark.python2
def test_python_2_deprecation_autodetection_extended() -> None:
    """Only the Python 2 samples are detected as targeting exactly PY27.

    Both halves of the ``python2_detection`` data are split on ``###`` into
    individual cases; a failing case is shown as the assertion message.
    """
    # this test has a similar construction to test_get_features_used_decorator
    py2_blob, other_blob = read_data("python2_detection")
    only_py27 = {TargetVersion.PY27}
    for case in py2_blob.split("###"):
        detected = black.detect_target_versions(black.lib2to3_parse(case))
        assert detected == only_py27, case
    for case in other_blob.split("###"):
        detected = black.detect_target_versions(black.lib2to3_parse(case))
        assert detected != only_py27, case
2087
2088
# Load the text of the file `black.__file__` points at, line by line, into
# `black_source_lines`; `tracefunc` below looks function signatures up in it.
try:
    with open(black.__file__, "r", encoding="utf-8") as _bf:
        black_source_lines = _bf.readlines()
except UnicodeDecodeError:
    # A decode failure is only tolerated when `black.COMPILED` is true
    # (presumably because `black.__file__` is then a binary extension module
    # -- confirm). In that case `black_source_lines` stays undefined.
    if not black.COMPILED:
        raise
2095
2096
def tracefunc(
    frame: types.FrameType, event: str, arg: Any
) -> Callable[[types.FrameType, str, Any], Any]:
    """Show function calls `from black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    Only "call" events whose code lives in a file path containing
    ``black/__init__.py`` print anything. Each printed line is indented by
    the current stack depth and shows the line number and the source text of
    the function's signature line (decorator lines are skipped).

    Always returns itself so that tracing stays installed.
    """
    if event != "call":
        return tracefunc

    filename = frame.f_code.co_filename
    if "black/__init__.py" not in filename:
        # Line numbers of other files must never be used to index
        # `black_source_lines`: a call in any longer file would raise
        # IndexError from inside the trace hook.
        return tracefunc

    # The fixed offset of 19 presumably discounts the frames of the test
    # harness so that indentation starts near zero -- TODO confirm.
    stack = len(inspect.stack()) - 19
    stack *= 2
    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1
    funcname = black_source_lines[func_sig_lineno].strip()
    # The frame's line number can point at a decorator; walk down to the
    # first line that is not one.
    while funcname.startswith("@"):
        func_sig_lineno += 1
        funcname = black_source_lines[func_sig_lineno].strip()
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc