git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

29a0731dd9d467dc0bc53a229e17dedf84b87536
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import multiprocessing
3 import asyncio
4 import logging
5 from concurrent.futures import ThreadPoolExecutor
6 from contextlib import contextmanager
7 from dataclasses import replace
8 import inspect
9 from io import BytesIO, TextIOWrapper
10 import os
11 from pathlib import Path
12 from platform import system
13 import regex as re
14 import sys
15 from tempfile import TemporaryDirectory
16 import types
17 from typing import (
18     Any,
19     BinaryIO,
20     Callable,
21     Dict,
22     Generator,
23     List,
24     Iterator,
25     TypeVar,
26 )
27 import pytest
28 import unittest
29 from unittest.mock import patch, MagicMock
30
31 import click
32 from click import unstyle
33 from click.testing import CliRunner
34
35 import black
36 from black import Feature, TargetVersion
37 from black.cache import get_cache_file
38 from black.debug import DebugVisitor
39 from black.report import Report
40 import black.files
41
42 from pathspec import PathSpec
43
44 # Import other test classes
45 from tests.util import (
46     THIS_DIR,
47     read_data,
48     DETERMINISTIC_HEADER,
49     BlackBaseTestCase,
50     DEFAULT_MODE,
51     fs,
52     ff,
53     dump_to_stderr,
54 )
55 from .test_primer import PrimerCLITests  # noqa: F401
56
57
# Path of this test module.
THIS_FILE = Path(__file__)
# The target versions from PY36 up to and including PY39.
PY36_VERSIONS = {
    TargetVersion.PY36,
    TargetVersion.PY37,
    TargetVersion.PY38,
    TargetVersion.PY39,
}
# One `--target-version=<lowercased name>` command-line flag per entry of
# PY36_VERSIONS (in the set's iteration order).
PY36_ARGS = [f"--target-version={version.name.lower()}" for version in PY36_VERSIONS]
# Generic type variables for use in annotations.
T = TypeVar("T")
R = TypeVar("R")
68
69
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch `black.cache.CACHE_DIR` to a path inside a temporary directory.

    With `exists=True` the patched path is the temporary directory itself;
    otherwise it is a "new" child of that directory, which this helper does
    not create. The patched path is yielded to the caller.
    """
    with TemporaryDirectory() as workspace:
        target = Path(workspace) if exists else Path(workspace) / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
78
79
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a fresh asyncio event loop for the duration of the `with` block.

    A new loop is created from the current event loop policy and set as the
    current loop. On exit the loop is closed and then unset, so that later
    code cannot pick up an already-closed loop as the current one.
    """
    policy = asyncio.get_event_loop_policy()
    loop = policy.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        yield

    finally:
        loop.close()
        # Do not leave the closed loop registered as the current event loop.
        asyncio.set_event_loop(None)
90
91
class FakeContext(click.Context):
    """Stand-in click Context for calling functions that require one.

    The parent initializer is not invoked; only `default_map` is set up.
    """

    def __init__(self) -> None:
        self.default_map: Dict[str, Any] = dict()
97
98
class FakeParameter(click.Parameter):
    """Stand-in click Parameter for calling functions that require one."""

    def __init__(self) -> None:
        """Do nothing; the parent initializer is not invoked."""
104
105
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    This is a hack that can be removed once we depend on Click 7.x

    After an invocation, the raw bytes captured from the two streams are
    available as `stdout_bytes` and `stderr_bytes`.
    """

    def __init__(self) -> None:
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Wrap the parent's isolation, redirecting `sys.stderr` to `stderrbuf`.

        Yields whatever the parent's `isolation` yields. On exit the bytes
        written to stdout and stderr are stored on the runner and the previous
        `sys.stderr` is restored.
        """
        with super().isolation(*args, **kwargs) as output:
            # Remember the stream before entering the `try` so the `finally`
            # clause can always restore it.
            hold_stderr = sys.stderr
            sys.stderr = TextIOWrapper(self.stderrbuf, encoding=self.charset)
            try:
                yield output
            finally:
                # A TextIOWrapper may still hold pending text that has not
                # reached the underlying BytesIO; flush before reading it.
                sys.stderr.flush()
                self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                self.stderr_bytes = sys.stderr.buffer.getvalue()  # type: ignore
                sys.stderr = hold_stderr
129
130
131 class BlackTestCase(BlackBaseTestCase):
132     def invokeBlack(
133         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
134     ) -> None:
135         runner = BlackRunner()
136         if ignore_config:
137             args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
138         result = runner.invoke(black.main, args)
139         self.assertEqual(
140             result.exit_code,
141             exit_code,
142             msg=(
143                 f"Failed with args: {args}\n"
144                 f"stdout: {runner.stdout_bytes.decode()!r}\n"
145                 f"stderr: {runner.stderr_bytes.decode()!r}\n"
146                 f"exception: {result.exception}"
147             ),
148         )
149
150     @patch("black.dump_to_file", dump_to_stderr)
151     def test_empty(self) -> None:
152         source = expected = ""
153         actual = fs(source)
154         self.assertFormatEqual(expected, actual)
155         black.assert_equivalent(source, actual)
156         black.assert_stable(source, actual, DEFAULT_MODE)
157
158     def test_empty_ff(self) -> None:
159         expected = ""
160         tmp_file = Path(black.dump_to_file())
161         try:
162             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
163             with open(tmp_file, encoding="utf8") as f:
164                 actual = f.read()
165         finally:
166             os.unlink(tmp_file)
167         self.assertFormatEqual(expected, actual)
168
169     def test_piping(self) -> None:
170         source, expected = read_data("src/black/__init__", data=False)
171         result = BlackRunner().invoke(
172             black.main,
173             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
174             input=BytesIO(source.encode("utf8")),
175         )
176         self.assertEqual(result.exit_code, 0)
177         self.assertFormatEqual(expected, result.output)
178         black.assert_equivalent(source, result.output)
179         black.assert_stable(source, result.output, DEFAULT_MODE)
180
181     def test_piping_diff(self) -> None:
182         diff_header = re.compile(
183             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
184             r"\+\d\d\d\d"
185         )
186         source, _ = read_data("expression.py")
187         expected, _ = read_data("expression.diff")
188         config = THIS_DIR / "data" / "empty_pyproject.toml"
189         args = [
190             "-",
191             "--fast",
192             f"--line-length={black.DEFAULT_LINE_LENGTH}",
193             "--diff",
194             f"--config={config}",
195         ]
196         result = BlackRunner().invoke(
197             black.main, args, input=BytesIO(source.encode("utf8"))
198         )
199         self.assertEqual(result.exit_code, 0)
200         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
201         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
202         self.assertEqual(expected, actual)
203
204     def test_piping_diff_with_color(self) -> None:
205         source, _ = read_data("expression.py")
206         config = THIS_DIR / "data" / "empty_pyproject.toml"
207         args = [
208             "-",
209             "--fast",
210             f"--line-length={black.DEFAULT_LINE_LENGTH}",
211             "--diff",
212             "--color",
213             f"--config={config}",
214         ]
215         result = BlackRunner().invoke(
216             black.main, args, input=BytesIO(source.encode("utf8"))
217         )
218         actual = result.output
219         # Again, the contents are checked in a different test, so only look for colors.
220         self.assertIn("\033[1;37m", actual)
221         self.assertIn("\033[36m", actual)
222         self.assertIn("\033[32m", actual)
223         self.assertIn("\033[31m", actual)
224         self.assertIn("\033[0m", actual)
225
226     @patch("black.dump_to_file", dump_to_stderr)
227     def _test_wip(self) -> None:
228         source, expected = read_data("wip")
229         sys.settrace(tracefunc)
230         mode = replace(
231             DEFAULT_MODE,
232             experimental_string_processing=False,
233             target_versions={black.TargetVersion.PY38},
234         )
235         actual = fs(source, mode=mode)
236         sys.settrace(None)
237         self.assertFormatEqual(expected, actual)
238         black.assert_equivalent(source, actual)
239         black.assert_stable(source, actual, black.FileMode())
240
241     @unittest.expectedFailure
242     @patch("black.dump_to_file", dump_to_stderr)
243     def test_trailing_comma_optional_parens_stability1(self) -> None:
244         source, _expected = read_data("trailing_comma_optional_parens1")
245         actual = fs(source)
246         black.assert_stable(source, actual, DEFAULT_MODE)
247
248     @unittest.expectedFailure
249     @patch("black.dump_to_file", dump_to_stderr)
250     def test_trailing_comma_optional_parens_stability2(self) -> None:
251         source, _expected = read_data("trailing_comma_optional_parens2")
252         actual = fs(source)
253         black.assert_stable(source, actual, DEFAULT_MODE)
254
255     @unittest.expectedFailure
256     @patch("black.dump_to_file", dump_to_stderr)
257     def test_trailing_comma_optional_parens_stability3(self) -> None:
258         source, _expected = read_data("trailing_comma_optional_parens3")
259         actual = fs(source)
260         black.assert_stable(source, actual, DEFAULT_MODE)
261
262     @patch("black.dump_to_file", dump_to_stderr)
263     def test_trailing_comma_optional_parens_stability1_pass2(self) -> None:
264         source, _expected = read_data("trailing_comma_optional_parens1")
265         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
266         black.assert_stable(source, actual, DEFAULT_MODE)
267
268     @patch("black.dump_to_file", dump_to_stderr)
269     def test_trailing_comma_optional_parens_stability2_pass2(self) -> None:
270         source, _expected = read_data("trailing_comma_optional_parens2")
271         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
272         black.assert_stable(source, actual, DEFAULT_MODE)
273
274     @patch("black.dump_to_file", dump_to_stderr)
275     def test_trailing_comma_optional_parens_stability3_pass2(self) -> None:
276         source, _expected = read_data("trailing_comma_optional_parens3")
277         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
278         black.assert_stable(source, actual, DEFAULT_MODE)
279
280     @patch("black.dump_to_file", dump_to_stderr)
281     def test_pep_572(self) -> None:
282         source, expected = read_data("pep_572")
283         actual = fs(source)
284         self.assertFormatEqual(expected, actual)
285         black.assert_stable(source, actual, DEFAULT_MODE)
286         if sys.version_info >= (3, 8):
287             black.assert_equivalent(source, actual)
288
289     @patch("black.dump_to_file", dump_to_stderr)
290     def test_pep_572_remove_parens(self) -> None:
291         source, expected = read_data("pep_572_remove_parens")
292         actual = fs(source)
293         self.assertFormatEqual(expected, actual)
294         black.assert_stable(source, actual, DEFAULT_MODE)
295         if sys.version_info >= (3, 8):
296             black.assert_equivalent(source, actual)
297
298     @patch("black.dump_to_file", dump_to_stderr)
299     def test_pep_572_do_not_remove_parens(self) -> None:
300         source, expected = read_data("pep_572_do_not_remove_parens")
301         # the AST safety checks will fail, but that's expected, just make sure no
302         # parentheses are touched
303         actual = black.format_str(source, mode=DEFAULT_MODE)
304         self.assertFormatEqual(expected, actual)
305
306     def test_pep_572_version_detection(self) -> None:
307         source, _ = read_data("pep_572")
308         root = black.lib2to3_parse(source)
309         features = black.get_features_used(root)
310         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
311         versions = black.detect_target_versions(root)
312         self.assertIn(black.TargetVersion.PY38, versions)
313
314     def test_expression_ff(self) -> None:
315         source, expected = read_data("expression")
316         tmp_file = Path(black.dump_to_file(source))
317         try:
318             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
319             with open(tmp_file, encoding="utf8") as f:
320                 actual = f.read()
321         finally:
322             os.unlink(tmp_file)
323         self.assertFormatEqual(expected, actual)
324         with patch("black.dump_to_file", dump_to_stderr):
325             black.assert_equivalent(source, actual)
326             black.assert_stable(source, actual, DEFAULT_MODE)
327
328     def test_expression_diff(self) -> None:
329         source, _ = read_data("expression.py")
330         config = THIS_DIR / "data" / "empty_pyproject.toml"
331         expected, _ = read_data("expression.diff")
332         tmp_file = Path(black.dump_to_file(source))
333         diff_header = re.compile(
334             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
335             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
336         )
337         try:
338             result = BlackRunner().invoke(
339                 black.main, ["--diff", str(tmp_file), f"--config={config}"]
340             )
341             self.assertEqual(result.exit_code, 0)
342         finally:
343             os.unlink(tmp_file)
344         actual = result.output
345         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
346         if expected != actual:
347             dump = black.dump_to_file(actual)
348             msg = (
349                 "Expected diff isn't equal to the actual. If you made changes to"
350                 " expression.py and this is an anticipated difference, overwrite"
351                 f" tests/data/expression.diff with {dump}"
352             )
353             self.assertEqual(expected, actual, msg)
354
355     def test_expression_diff_with_color(self) -> None:
356         source, _ = read_data("expression.py")
357         config = THIS_DIR / "data" / "empty_pyproject.toml"
358         expected, _ = read_data("expression.diff")
359         tmp_file = Path(black.dump_to_file(source))
360         try:
361             result = BlackRunner().invoke(
362                 black.main, ["--diff", "--color", str(tmp_file), f"--config={config}"]
363             )
364         finally:
365             os.unlink(tmp_file)
366         actual = result.output
367         # We check the contents of the diff in `test_expression_diff`. All
368         # we need to check here is that color codes exist in the result.
369         self.assertIn("\033[1;37m", actual)
370         self.assertIn("\033[36m", actual)
371         self.assertIn("\033[32m", actual)
372         self.assertIn("\033[31m", actual)
373         self.assertIn("\033[0m", actual)
374
375     @patch("black.dump_to_file", dump_to_stderr)
376     def test_pep_570(self) -> None:
377         source, expected = read_data("pep_570")
378         actual = fs(source)
379         self.assertFormatEqual(expected, actual)
380         black.assert_stable(source, actual, DEFAULT_MODE)
381         if sys.version_info >= (3, 8):
382             black.assert_equivalent(source, actual)
383
384     def test_detect_pos_only_arguments(self) -> None:
385         source, _ = read_data("pep_570")
386         root = black.lib2to3_parse(source)
387         features = black.get_features_used(root)
388         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
389         versions = black.detect_target_versions(root)
390         self.assertIn(black.TargetVersion.PY38, versions)
391
392     @patch("black.dump_to_file", dump_to_stderr)
393     def test_string_quotes(self) -> None:
394         source, expected = read_data("string_quotes")
395         mode = black.Mode(experimental_string_processing=True)
396         actual = fs(source, mode=mode)
397         self.assertFormatEqual(expected, actual)
398         black.assert_equivalent(source, actual)
399         black.assert_stable(source, actual, mode)
400         mode = replace(mode, string_normalization=False)
401         not_normalized = fs(source, mode=mode)
402         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
403         black.assert_equivalent(source, not_normalized)
404         black.assert_stable(source, not_normalized, mode=mode)
405
406     @patch("black.dump_to_file", dump_to_stderr)
407     def test_docstring_no_string_normalization(self) -> None:
408         """Like test_docstring but with string normalization off."""
409         source, expected = read_data("docstring_no_string_normalization")
410         mode = replace(DEFAULT_MODE, string_normalization=False)
411         actual = fs(source, mode=mode)
412         self.assertFormatEqual(expected, actual)
413         black.assert_equivalent(source, actual)
414         black.assert_stable(source, actual, mode)
415
416     def test_long_strings_flag_disabled(self) -> None:
417         """Tests for turning off the string processing logic."""
418         source, expected = read_data("long_strings_flag_disabled")
419         mode = replace(DEFAULT_MODE, experimental_string_processing=False)
420         actual = fs(source, mode=mode)
421         self.assertFormatEqual(expected, actual)
422         black.assert_stable(expected, actual, mode)
423
424     @patch("black.dump_to_file", dump_to_stderr)
425     def test_numeric_literals(self) -> None:
426         source, expected = read_data("numeric_literals")
427         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
428         actual = fs(source, mode=mode)
429         self.assertFormatEqual(expected, actual)
430         black.assert_equivalent(source, actual)
431         black.assert_stable(source, actual, mode)
432
433     @patch("black.dump_to_file", dump_to_stderr)
434     def test_numeric_literals_ignoring_underscores(self) -> None:
435         source, expected = read_data("numeric_literals_skip_underscores")
436         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
437         actual = fs(source, mode=mode)
438         self.assertFormatEqual(expected, actual)
439         black.assert_equivalent(source, actual)
440         black.assert_stable(source, actual, mode)
441
442     def test_skip_magic_trailing_comma(self) -> None:
443         source, _ = read_data("expression.py")
444         expected, _ = read_data("expression_skip_magic_trailing_comma.diff")
445         tmp_file = Path(black.dump_to_file(source))
446         diff_header = re.compile(
447             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
448             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
449         )
450         try:
451             result = BlackRunner().invoke(black.main, ["-C", "--diff", str(tmp_file)])
452             self.assertEqual(result.exit_code, 0)
453         finally:
454             os.unlink(tmp_file)
455         actual = result.output
456         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
457         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
458         if expected != actual:
459             dump = black.dump_to_file(actual)
460             msg = (
461                 "Expected diff isn't equal to the actual. If you made changes to"
462                 " expression.py and this is an anticipated difference, overwrite"
463                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
464             )
465             self.assertEqual(expected, actual, msg)
466
467     @pytest.mark.no_python2
468     def test_python2_should_fail_without_optional_install(self) -> None:
469         if sys.version_info < (3, 8):
470             self.skipTest(
471                 "Python 3.6 and 3.7 will install typed-ast to work and as such will be"
472                 " able to parse Python 2 syntax without explicitly specifying the"
473                 " python2 extra"
474             )
475
476         source = "x = 1234l"
477         tmp_file = Path(black.dump_to_file(source))
478         try:
479             runner = BlackRunner()
480             result = runner.invoke(black.main, [str(tmp_file)])
481             self.assertEqual(result.exit_code, 123)
482         finally:
483             os.unlink(tmp_file)
484         actual = (
485             runner.stderr_bytes.decode()
486             .replace("\n", "")
487             .replace("\\n", "")
488             .replace("\\r", "")
489             .replace("\r", "")
490         )
491         msg = (
492             "The requested source code has invalid Python 3 syntax."
493             "If you are trying to format Python 2 files please reinstall Black"
494             " with the 'python2' extra: `python3 -m pip install black[python2]`."
495         )
496         self.assertIn(msg, actual)
497
498     @pytest.mark.python2
499     @patch("black.dump_to_file", dump_to_stderr)
500     def test_python2_print_function(self) -> None:
501         source, expected = read_data("python2_print_function")
502         mode = replace(DEFAULT_MODE, target_versions={TargetVersion.PY27})
503         actual = fs(source, mode=mode)
504         self.assertFormatEqual(expected, actual)
505         black.assert_equivalent(source, actual)
506         black.assert_stable(source, actual, mode)
507
508     @patch("black.dump_to_file", dump_to_stderr)
509     def test_stub(self) -> None:
510         mode = replace(DEFAULT_MODE, is_pyi=True)
511         source, expected = read_data("stub.pyi")
512         actual = fs(source, mode=mode)
513         self.assertFormatEqual(expected, actual)
514         black.assert_stable(source, actual, mode)
515
516     @patch("black.dump_to_file", dump_to_stderr)
517     def test_async_as_identifier(self) -> None:
518         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
519         source, expected = read_data("async_as_identifier")
520         actual = fs(source)
521         self.assertFormatEqual(expected, actual)
522         major, minor = sys.version_info[:2]
523         if major < 3 or (major <= 3 and minor < 7):
524             black.assert_equivalent(source, actual)
525         black.assert_stable(source, actual, DEFAULT_MODE)
526         # ensure black can parse this when the target is 3.6
527         self.invokeBlack([str(source_path), "--target-version", "py36"])
528         # but not on 3.7, because async/await is no longer an identifier
529         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
530
531     @patch("black.dump_to_file", dump_to_stderr)
532     def test_python37(self) -> None:
533         source_path = (THIS_DIR / "data" / "python37.py").resolve()
534         source, expected = read_data("python37")
535         actual = fs(source)
536         self.assertFormatEqual(expected, actual)
537         major, minor = sys.version_info[:2]
538         if major > 3 or (major == 3 and minor >= 7):
539             black.assert_equivalent(source, actual)
540         black.assert_stable(source, actual, DEFAULT_MODE)
541         # ensure black can parse this when the target is 3.7
542         self.invokeBlack([str(source_path), "--target-version", "py37"])
543         # but not on 3.6, because we use async as a reserved keyword
544         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
545
546     @patch("black.dump_to_file", dump_to_stderr)
547     def test_python38(self) -> None:
548         source, expected = read_data("python38")
549         actual = fs(source)
550         self.assertFormatEqual(expected, actual)
551         major, minor = sys.version_info[:2]
552         if major > 3 or (major == 3 and minor >= 8):
553             black.assert_equivalent(source, actual)
554         black.assert_stable(source, actual, DEFAULT_MODE)
555
556     @patch("black.dump_to_file", dump_to_stderr)
557     def test_python39(self) -> None:
558         source, expected = read_data("python39")
559         actual = fs(source)
560         self.assertFormatEqual(expected, actual)
561         major, minor = sys.version_info[:2]
562         if major > 3 or (major == 3 and minor >= 9):
563             black.assert_equivalent(source, actual)
564         black.assert_stable(source, actual, DEFAULT_MODE)
565
566     def test_tab_comment_indentation(self) -> None:
567         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
568         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
569         self.assertFormatEqual(contents_spc, fs(contents_spc))
570         self.assertFormatEqual(contents_spc, fs(contents_tab))
571
572         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
573         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
574         self.assertFormatEqual(contents_spc, fs(contents_spc))
575         self.assertFormatEqual(contents_spc, fs(contents_tab))
576
577         # mixed tabs and spaces (valid Python 2 code)
578         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
579         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
580         self.assertFormatEqual(contents_spc, fs(contents_spc))
581         self.assertFormatEqual(contents_spc, fs(contents_tab))
582
583         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
584         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
585         self.assertFormatEqual(contents_spc, fs(contents_spc))
586         self.assertFormatEqual(contents_spc, fs(contents_tab))
587
    def test_report_verbose(self) -> None:
        """Drive a verbose `Report` through a sequence of events.

        `black.output._out` and `black.output._err` are patched with local
        collectors. After each event the test checks how many messages were
        collected on each side, the text of the latest message, the summary
        produced by `str(report)` (with styling removed) and the return code.
        """
        report = Report(verbose=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            # Collect the message instead of printing it; kwargs are ignored.
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            # Collect the message instead of printing it; kwargs are ignored.
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            # An unchanged file is reported on the "out" side.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # A reformatted file.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file is counted in the "left unchanged" total.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With `check` set, the reformatted file makes the return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure goes to the "err" side and makes the return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Another reformatted file; the return code stays at 123.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # A second failure.
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # An ignored path produces a message but leaves the summary as is.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # One more unchanged file.
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # With `check` set, the summary switches to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            # With `diff` set instead of `check`, the wording is the same.
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
689
690     def test_report_quiet(self) -> None:
691         report = Report(quiet=True)
692         out_lines = []
693         err_lines = []
694
695         def out(msg: str, **kwargs: Any) -> None:
696             out_lines.append(msg)
697
698         def err(msg: str, **kwargs: Any) -> None:
699             err_lines.append(msg)
700
701         with patch("black.output._out", out), patch("black.output._err", err):
702             report.done(Path("f1"), black.Changed.NO)
703             self.assertEqual(len(out_lines), 0)
704             self.assertEqual(len(err_lines), 0)
705             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
706             self.assertEqual(report.return_code, 0)
707             report.done(Path("f2"), black.Changed.YES)
708             self.assertEqual(len(out_lines), 0)
709             self.assertEqual(len(err_lines), 0)
710             self.assertEqual(
711                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
712             )
713             report.done(Path("f3"), black.Changed.CACHED)
714             self.assertEqual(len(out_lines), 0)
715             self.assertEqual(len(err_lines), 0)
716             self.assertEqual(
717                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
718             )
719             self.assertEqual(report.return_code, 0)
720             report.check = True
721             self.assertEqual(report.return_code, 1)
722             report.check = False
723             report.failed(Path("e1"), "boom")
724             self.assertEqual(len(out_lines), 0)
725             self.assertEqual(len(err_lines), 1)
726             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
727             self.assertEqual(
728                 unstyle(str(report)),
729                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
730                 " reformat.",
731             )
732             self.assertEqual(report.return_code, 123)
733             report.done(Path("f3"), black.Changed.YES)
734             self.assertEqual(len(out_lines), 0)
735             self.assertEqual(len(err_lines), 1)
736             self.assertEqual(
737                 unstyle(str(report)),
738                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
739                 " reformat.",
740             )
741             self.assertEqual(report.return_code, 123)
742             report.failed(Path("e2"), "boom")
743             self.assertEqual(len(out_lines), 0)
744             self.assertEqual(len(err_lines), 2)
745             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
746             self.assertEqual(
747                 unstyle(str(report)),
748                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
749                 " reformat.",
750             )
751             self.assertEqual(report.return_code, 123)
752             report.path_ignored(Path("wat"), "no match")
753             self.assertEqual(len(out_lines), 0)
754             self.assertEqual(len(err_lines), 2)
755             self.assertEqual(
756                 unstyle(str(report)),
757                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
758                 " reformat.",
759             )
760             self.assertEqual(report.return_code, 123)
761             report.done(Path("f4"), black.Changed.NO)
762             self.assertEqual(len(out_lines), 0)
763             self.assertEqual(len(err_lines), 2)
764             self.assertEqual(
765                 unstyle(str(report)),
766                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
767                 " reformat.",
768             )
769             self.assertEqual(report.return_code, 123)
770             report.check = True
771             self.assertEqual(
772                 unstyle(str(report)),
773                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
774                 " would fail to reformat.",
775             )
776             report.check = False
777             report.diff = True
778             self.assertEqual(
779                 unstyle(str(report)),
780                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
781                 " would fail to reformat.",
782             )
783
784     def test_report_normal(self) -> None:
785         report = black.Report()
786         out_lines = []
787         err_lines = []
788
789         def out(msg: str, **kwargs: Any) -> None:
790             out_lines.append(msg)
791
792         def err(msg: str, **kwargs: Any) -> None:
793             err_lines.append(msg)
794
795         with patch("black.output._out", out), patch("black.output._err", err):
796             report.done(Path("f1"), black.Changed.NO)
797             self.assertEqual(len(out_lines), 0)
798             self.assertEqual(len(err_lines), 0)
799             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
800             self.assertEqual(report.return_code, 0)
801             report.done(Path("f2"), black.Changed.YES)
802             self.assertEqual(len(out_lines), 1)
803             self.assertEqual(len(err_lines), 0)
804             self.assertEqual(out_lines[-1], "reformatted f2")
805             self.assertEqual(
806                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
807             )
808             report.done(Path("f3"), black.Changed.CACHED)
809             self.assertEqual(len(out_lines), 1)
810             self.assertEqual(len(err_lines), 0)
811             self.assertEqual(out_lines[-1], "reformatted f2")
812             self.assertEqual(
813                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
814             )
815             self.assertEqual(report.return_code, 0)
816             report.check = True
817             self.assertEqual(report.return_code, 1)
818             report.check = False
819             report.failed(Path("e1"), "boom")
820             self.assertEqual(len(out_lines), 1)
821             self.assertEqual(len(err_lines), 1)
822             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
823             self.assertEqual(
824                 unstyle(str(report)),
825                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
826                 " reformat.",
827             )
828             self.assertEqual(report.return_code, 123)
829             report.done(Path("f3"), black.Changed.YES)
830             self.assertEqual(len(out_lines), 2)
831             self.assertEqual(len(err_lines), 1)
832             self.assertEqual(out_lines[-1], "reformatted f3")
833             self.assertEqual(
834                 unstyle(str(report)),
835                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
836                 " reformat.",
837             )
838             self.assertEqual(report.return_code, 123)
839             report.failed(Path("e2"), "boom")
840             self.assertEqual(len(out_lines), 2)
841             self.assertEqual(len(err_lines), 2)
842             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
843             self.assertEqual(
844                 unstyle(str(report)),
845                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
846                 " reformat.",
847             )
848             self.assertEqual(report.return_code, 123)
849             report.path_ignored(Path("wat"), "no match")
850             self.assertEqual(len(out_lines), 2)
851             self.assertEqual(len(err_lines), 2)
852             self.assertEqual(
853                 unstyle(str(report)),
854                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
855                 " reformat.",
856             )
857             self.assertEqual(report.return_code, 123)
858             report.done(Path("f4"), black.Changed.NO)
859             self.assertEqual(len(out_lines), 2)
860             self.assertEqual(len(err_lines), 2)
861             self.assertEqual(
862                 unstyle(str(report)),
863                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
864                 " reformat.",
865             )
866             self.assertEqual(report.return_code, 123)
867             report.check = True
868             self.assertEqual(
869                 unstyle(str(report)),
870                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
871                 " would fail to reformat.",
872             )
873             report.check = False
874             report.diff = True
875             self.assertEqual(
876                 unstyle(str(report)),
877                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
878                 " would fail to reformat.",
879             )
880
881     def test_lib2to3_parse(self) -> None:
882         with self.assertRaises(black.InvalidInput):
883             black.lib2to3_parse("invalid syntax")
884
885         straddling = "x + y"
886         black.lib2to3_parse(straddling)
887         black.lib2to3_parse(straddling, {TargetVersion.PY27})
888         black.lib2to3_parse(straddling, {TargetVersion.PY36})
889         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
890
891         py2_only = "print x"
892         black.lib2to3_parse(py2_only)
893         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
894         with self.assertRaises(black.InvalidInput):
895             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
896         with self.assertRaises(black.InvalidInput):
897             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
898
899         py3_only = "exec(x, end=y)"
900         black.lib2to3_parse(py3_only)
901         with self.assertRaises(black.InvalidInput):
902             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
903         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
904         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
905
906     def test_get_features_used_decorator(self) -> None:
907         # Test the feature detection of new decorator syntax
908         # since this makes some test cases of test_get_features_used()
909         # fails if it fails, this is tested first so that a useful case
910         # is identified
911         simples, relaxed = read_data("decorators")
912         # skip explanation comments at the top of the file
913         for simple_test in simples.split("##")[1:]:
914             node = black.lib2to3_parse(simple_test)
915             decorator = str(node.children[0].children[0]).strip()
916             self.assertNotIn(
917                 Feature.RELAXED_DECORATORS,
918                 black.get_features_used(node),
919                 msg=(
920                     f"decorator '{decorator}' follows python<=3.8 syntax"
921                     "but is detected as 3.9+"
922                     # f"The full node is\n{node!r}"
923                 ),
924             )
925         # skip the '# output' comment at the top of the output part
926         for relaxed_test in relaxed.split("##")[1:]:
927             node = black.lib2to3_parse(relaxed_test)
928             decorator = str(node.children[0].children[0]).strip()
929             self.assertIn(
930                 Feature.RELAXED_DECORATORS,
931                 black.get_features_used(node),
932                 msg=(
933                     f"decorator '{decorator}' uses python3.9+ syntax"
934                     "but is detected as python<=3.8"
935                     # f"The full node is\n{node!r}"
936                 ),
937             )
938
939     def test_get_features_used(self) -> None:
940         node = black.lib2to3_parse("def f(*, arg): ...\n")
941         self.assertEqual(black.get_features_used(node), set())
942         node = black.lib2to3_parse("def f(*, arg,): ...\n")
943         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
944         node = black.lib2to3_parse("f(*arg,)\n")
945         self.assertEqual(
946             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
947         )
948         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
949         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
950         node = black.lib2to3_parse("123_456\n")
951         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
952         node = black.lib2to3_parse("123456\n")
953         self.assertEqual(black.get_features_used(node), set())
954         source, expected = read_data("function")
955         node = black.lib2to3_parse(source)
956         expected_features = {
957             Feature.TRAILING_COMMA_IN_CALL,
958             Feature.TRAILING_COMMA_IN_DEF,
959             Feature.F_STRINGS,
960         }
961         self.assertEqual(black.get_features_used(node), expected_features)
962         node = black.lib2to3_parse(expected)
963         self.assertEqual(black.get_features_used(node), expected_features)
964         source, expected = read_data("expression")
965         node = black.lib2to3_parse(source)
966         self.assertEqual(black.get_features_used(node), set())
967         node = black.lib2to3_parse(expected)
968         self.assertEqual(black.get_features_used(node), set())
969
970     def test_get_future_imports(self) -> None:
971         node = black.lib2to3_parse("\n")
972         self.assertEqual(set(), black.get_future_imports(node))
973         node = black.lib2to3_parse("from __future__ import black\n")
974         self.assertEqual({"black"}, black.get_future_imports(node))
975         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
976         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
977         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
978         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
979         node = black.lib2to3_parse(
980             "from __future__ import multiple\nfrom __future__ import imports\n"
981         )
982         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
983         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
984         self.assertEqual({"black"}, black.get_future_imports(node))
985         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
986         self.assertEqual({"black"}, black.get_future_imports(node))
987         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
988         self.assertEqual(set(), black.get_future_imports(node))
989         node = black.lib2to3_parse("from some.module import black\n")
990         self.assertEqual(set(), black.get_future_imports(node))
991         node = black.lib2to3_parse(
992             "from __future__ import unicode_literals as _unicode_literals"
993         )
994         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
995         node = black.lib2to3_parse(
996             "from __future__ import unicode_literals as _lol, print"
997         )
998         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
999
1000     def test_debug_visitor(self) -> None:
1001         source, _ = read_data("debug_visitor.py")
1002         expected, _ = read_data("debug_visitor.out")
1003         out_lines = []
1004         err_lines = []
1005
1006         def out(msg: str, **kwargs: Any) -> None:
1007             out_lines.append(msg)
1008
1009         def err(msg: str, **kwargs: Any) -> None:
1010             err_lines.append(msg)
1011
1012         with patch("black.debug.out", out):
1013             DebugVisitor.show(source)
1014         actual = "\n".join(out_lines) + "\n"
1015         log_name = ""
1016         if expected != actual:
1017             log_name = black.dump_to_file(*out_lines)
1018         self.assertEqual(
1019             expected,
1020             actual,
1021             f"AST print out is different. Actual version dumped to {log_name}",
1022         )
1023
1024     def test_format_file_contents(self) -> None:
1025         empty = ""
1026         mode = DEFAULT_MODE
1027         with self.assertRaises(black.NothingChanged):
1028             black.format_file_contents(empty, mode=mode, fast=False)
1029         just_nl = "\n"
1030         with self.assertRaises(black.NothingChanged):
1031             black.format_file_contents(just_nl, mode=mode, fast=False)
1032         same = "j = [1, 2, 3]\n"
1033         with self.assertRaises(black.NothingChanged):
1034             black.format_file_contents(same, mode=mode, fast=False)
1035         different = "j = [1,2,3]"
1036         expected = same
1037         actual = black.format_file_contents(different, mode=mode, fast=False)
1038         self.assertEqual(expected, actual)
1039         invalid = "return if you can"
1040         with self.assertRaises(black.InvalidInput) as e:
1041             black.format_file_contents(invalid, mode=mode, fast=False)
1042         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1043
1044     def test_endmarker(self) -> None:
1045         n = black.lib2to3_parse("\n")
1046         self.assertEqual(n.type, black.syms.file_input)
1047         self.assertEqual(len(n.children), 1)
1048         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1049
1050     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1051     def test_assertFormatEqual(self) -> None:
1052         out_lines = []
1053         err_lines = []
1054
1055         def out(msg: str, **kwargs: Any) -> None:
1056             out_lines.append(msg)
1057
1058         def err(msg: str, **kwargs: Any) -> None:
1059             err_lines.append(msg)
1060
1061         with patch("black.output._out", out), patch("black.output._err", err):
1062             with self.assertRaises(AssertionError):
1063                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1064
1065         out_str = "".join(out_lines)
1066         self.assertTrue("Expected tree:" in out_str)
1067         self.assertTrue("Actual tree:" in out_str)
1068         self.assertEqual("".join(err_lines), "")
1069
1070     def test_cache_broken_file(self) -> None:
1071         mode = DEFAULT_MODE
1072         with cache_dir() as workspace:
1073             cache_file = get_cache_file(mode)
1074             with cache_file.open("w") as fobj:
1075                 fobj.write("this is not a pickle")
1076             self.assertEqual(black.read_cache(mode), {})
1077             src = (workspace / "test.py").resolve()
1078             with src.open("w") as fobj:
1079                 fobj.write("print('hello')")
1080             self.invokeBlack([str(src)])
1081             cache = black.read_cache(mode)
1082             self.assertIn(str(src), cache)
1083
1084     def test_cache_single_file_already_cached(self) -> None:
1085         mode = DEFAULT_MODE
1086         with cache_dir() as workspace:
1087             src = (workspace / "test.py").resolve()
1088             with src.open("w") as fobj:
1089                 fobj.write("print('hello')")
1090             black.write_cache({}, [src], mode)
1091             self.invokeBlack([str(src)])
1092             with src.open("r") as fobj:
1093                 self.assertEqual(fobj.read(), "print('hello')")
1094
1095     @event_loop()
1096     def test_cache_multiple_files(self) -> None:
1097         mode = DEFAULT_MODE
1098         with cache_dir() as workspace, patch(
1099             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1100         ):
1101             one = (workspace / "one.py").resolve()
1102             with one.open("w") as fobj:
1103                 fobj.write("print('hello')")
1104             two = (workspace / "two.py").resolve()
1105             with two.open("w") as fobj:
1106                 fobj.write("print('hello')")
1107             black.write_cache({}, [one], mode)
1108             self.invokeBlack([str(workspace)])
1109             with one.open("r") as fobj:
1110                 self.assertEqual(fobj.read(), "print('hello')")
1111             with two.open("r") as fobj:
1112                 self.assertEqual(fobj.read(), 'print("hello")\n')
1113             cache = black.read_cache(mode)
1114             self.assertIn(str(one), cache)
1115             self.assertIn(str(two), cache)
1116
1117     def test_no_cache_when_writeback_diff(self) -> None:
1118         mode = DEFAULT_MODE
1119         with cache_dir() as workspace:
1120             src = (workspace / "test.py").resolve()
1121             with src.open("w") as fobj:
1122                 fobj.write("print('hello')")
1123             with patch("black.read_cache") as read_cache, patch(
1124                 "black.write_cache"
1125             ) as write_cache:
1126                 self.invokeBlack([str(src), "--diff"])
1127                 cache_file = get_cache_file(mode)
1128                 self.assertFalse(cache_file.exists())
1129                 write_cache.assert_not_called()
1130                 read_cache.assert_not_called()
1131
1132     def test_no_cache_when_writeback_color_diff(self) -> None:
1133         mode = DEFAULT_MODE
1134         with cache_dir() as workspace:
1135             src = (workspace / "test.py").resolve()
1136             with src.open("w") as fobj:
1137                 fobj.write("print('hello')")
1138             with patch("black.read_cache") as read_cache, patch(
1139                 "black.write_cache"
1140             ) as write_cache:
1141                 self.invokeBlack([str(src), "--diff", "--color"])
1142                 cache_file = get_cache_file(mode)
1143                 self.assertFalse(cache_file.exists())
1144                 write_cache.assert_not_called()
1145                 read_cache.assert_not_called()
1146
1147     @event_loop()
1148     def test_output_locking_when_writeback_diff(self) -> None:
1149         with cache_dir() as workspace:
1150             for tag in range(0, 4):
1151                 src = (workspace / f"test{tag}.py").resolve()
1152                 with src.open("w") as fobj:
1153                     fobj.write("print('hello')")
1154             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1155                 self.invokeBlack(["--diff", str(workspace)], exit_code=0)
1156                 # this isn't quite doing what we want, but if it _isn't_
1157                 # called then we cannot be using the lock it provides
1158                 mgr.assert_called()
1159
1160     @event_loop()
1161     def test_output_locking_when_writeback_color_diff(self) -> None:
1162         with cache_dir() as workspace:
1163             for tag in range(0, 4):
1164                 src = (workspace / f"test{tag}.py").resolve()
1165                 with src.open("w") as fobj:
1166                     fobj.write("print('hello')")
1167             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1168                 self.invokeBlack(["--diff", "--color", str(workspace)], exit_code=0)
1169                 # this isn't quite doing what we want, but if it _isn't_
1170                 # called then we cannot be using the lock it provides
1171                 mgr.assert_called()
1172
1173     def test_no_cache_when_stdin(self) -> None:
1174         mode = DEFAULT_MODE
1175         with cache_dir():
1176             result = CliRunner().invoke(
1177                 black.main, ["-"], input=BytesIO(b"print('hello')")
1178             )
1179             self.assertEqual(result.exit_code, 0)
1180             cache_file = get_cache_file(mode)
1181             self.assertFalse(cache_file.exists())
1182
1183     def test_read_cache_no_cachefile(self) -> None:
1184         mode = DEFAULT_MODE
1185         with cache_dir():
1186             self.assertEqual(black.read_cache(mode), {})
1187
1188     def test_write_cache_read_cache(self) -> None:
1189         mode = DEFAULT_MODE
1190         with cache_dir() as workspace:
1191             src = (workspace / "test.py").resolve()
1192             src.touch()
1193             black.write_cache({}, [src], mode)
1194             cache = black.read_cache(mode)
1195             self.assertIn(str(src), cache)
1196             self.assertEqual(cache[str(src)], black.get_cache_info(src))
1197
1198     def test_filter_cached(self) -> None:
1199         with TemporaryDirectory() as workspace:
1200             path = Path(workspace)
1201             uncached = (path / "uncached").resolve()
1202             cached = (path / "cached").resolve()
1203             cached_but_changed = (path / "changed").resolve()
1204             uncached.touch()
1205             cached.touch()
1206             cached_but_changed.touch()
1207             cache = {
1208                 str(cached): black.get_cache_info(cached),
1209                 str(cached_but_changed): (0.0, 0),
1210             }
1211             todo, done = black.filter_cached(
1212                 cache, {uncached, cached, cached_but_changed}
1213             )
1214             self.assertEqual(todo, {uncached, cached_but_changed})
1215             self.assertEqual(done, {cached})
1216
1217     def test_write_cache_creates_directory_if_needed(self) -> None:
1218         mode = DEFAULT_MODE
1219         with cache_dir(exists=False) as workspace:
1220             self.assertFalse(workspace.exists())
1221             black.write_cache({}, [], mode)
1222             self.assertTrue(workspace.exists())
1223
1224     @event_loop()
1225     def test_failed_formatting_does_not_get_cached(self) -> None:
1226         mode = DEFAULT_MODE
1227         with cache_dir() as workspace, patch(
1228             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1229         ):
1230             failing = (workspace / "failing.py").resolve()
1231             with failing.open("w") as fobj:
1232                 fobj.write("not actually python")
1233             clean = (workspace / "clean.py").resolve()
1234             with clean.open("w") as fobj:
1235                 fobj.write('print("hello")\n')
1236             self.invokeBlack([str(workspace)], exit_code=123)
1237             cache = black.read_cache(mode)
1238             self.assertNotIn(str(failing), cache)
1239             self.assertIn(str(clean), cache)
1240
1241     def test_write_cache_write_fail(self) -> None:
1242         mode = DEFAULT_MODE
1243         with cache_dir(), patch.object(Path, "open") as mock:
1244             mock.side_effect = OSError
1245             black.write_cache({}, [], mode)
1246
1247     @event_loop()
1248     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
1249     def test_works_in_mono_process_only_environment(self) -> None:
1250         with cache_dir() as workspace:
1251             for f in [
1252                 (workspace / "one.py").resolve(),
1253                 (workspace / "two.py").resolve(),
1254             ]:
1255                 f.write_text('print("hello")\n')
1256             self.invokeBlack([str(workspace)])
1257
1258     @event_loop()
1259     def test_check_diff_use_together(self) -> None:
1260         with cache_dir():
1261             # Files which will be reformatted.
1262             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1263             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1264             # Files which will not be reformatted.
1265             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1266             self.invokeBlack([str(src2), "--diff", "--check"])
1267             # Multi file command.
1268             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1269
1270     def test_no_files(self) -> None:
1271         with cache_dir():
1272             # Without an argument, black exits with error code 0.
1273             self.invokeBlack([])
1274
1275     def test_broken_symlink(self) -> None:
1276         with cache_dir() as workspace:
1277             symlink = workspace / "broken_link.py"
1278             try:
1279                 symlink.symlink_to("nonexistent.py")
1280             except OSError as e:
1281                 self.skipTest(f"Can't create symlinks: {e}")
1282             self.invokeBlack([str(workspace.resolve())])
1283
1284     def test_read_cache_line_lengths(self) -> None:
1285         mode = DEFAULT_MODE
1286         short_mode = replace(DEFAULT_MODE, line_length=1)
1287         with cache_dir() as workspace:
1288             path = (workspace / "file.py").resolve()
1289             path.touch()
1290             black.write_cache({}, [path], mode)
1291             one = black.read_cache(mode)
1292             self.assertIn(str(path), one)
1293             two = black.read_cache(short_mode)
1294             self.assertNotIn(str(path), two)
1295
1296     def test_single_file_force_pyi(self) -> None:
1297         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1298         contents, expected = read_data("force_pyi")
1299         with cache_dir() as workspace:
1300             path = (workspace / "file.py").resolve()
1301             with open(path, "w") as fh:
1302                 fh.write(contents)
1303             self.invokeBlack([str(path), "--pyi"])
1304             with open(path, "r") as fh:
1305                 actual = fh.read()
1306             # verify cache with --pyi is separate
1307             pyi_cache = black.read_cache(pyi_mode)
1308             self.assertIn(str(path), pyi_cache)
1309             normal_cache = black.read_cache(DEFAULT_MODE)
1310             self.assertNotIn(str(path), normal_cache)
1311         self.assertFormatEqual(expected, actual)
1312         black.assert_equivalent(contents, actual)
1313         black.assert_stable(contents, actual, pyi_mode)
1314
1315     @event_loop()
1316     def test_multi_file_force_pyi(self) -> None:
1317         reg_mode = DEFAULT_MODE
1318         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1319         contents, expected = read_data("force_pyi")
1320         with cache_dir() as workspace:
1321             paths = [
1322                 (workspace / "file1.py").resolve(),
1323                 (workspace / "file2.py").resolve(),
1324             ]
1325             for path in paths:
1326                 with open(path, "w") as fh:
1327                     fh.write(contents)
1328             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1329             for path in paths:
1330                 with open(path, "r") as fh:
1331                     actual = fh.read()
1332                 self.assertEqual(actual, expected)
1333             # verify cache with --pyi is separate
1334             pyi_cache = black.read_cache(pyi_mode)
1335             normal_cache = black.read_cache(reg_mode)
1336             for path in paths:
1337                 self.assertIn(str(path), pyi_cache)
1338                 self.assertNotIn(str(path), normal_cache)
1339
1340     def test_pipe_force_pyi(self) -> None:
1341         source, expected = read_data("force_pyi")
1342         result = CliRunner().invoke(
1343             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1344         )
1345         self.assertEqual(result.exit_code, 0)
1346         actual = result.output
1347         self.assertFormatEqual(actual, expected)
1348
1349     def test_single_file_force_py36(self) -> None:
1350         reg_mode = DEFAULT_MODE
1351         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1352         source, expected = read_data("force_py36")
1353         with cache_dir() as workspace:
1354             path = (workspace / "file.py").resolve()
1355             with open(path, "w") as fh:
1356                 fh.write(source)
1357             self.invokeBlack([str(path), *PY36_ARGS])
1358             with open(path, "r") as fh:
1359                 actual = fh.read()
1360             # verify cache with --target-version is separate
1361             py36_cache = black.read_cache(py36_mode)
1362             self.assertIn(str(path), py36_cache)
1363             normal_cache = black.read_cache(reg_mode)
1364             self.assertNotIn(str(path), normal_cache)
1365         self.assertEqual(actual, expected)
1366
1367     @event_loop()
1368     def test_multi_file_force_py36(self) -> None:
1369         reg_mode = DEFAULT_MODE
1370         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1371         source, expected = read_data("force_py36")
1372         with cache_dir() as workspace:
1373             paths = [
1374                 (workspace / "file1.py").resolve(),
1375                 (workspace / "file2.py").resolve(),
1376             ]
1377             for path in paths:
1378                 with open(path, "w") as fh:
1379                     fh.write(source)
1380             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1381             for path in paths:
1382                 with open(path, "r") as fh:
1383                     actual = fh.read()
1384                 self.assertEqual(actual, expected)
1385             # verify cache with --target-version is separate
1386             pyi_cache = black.read_cache(py36_mode)
1387             normal_cache = black.read_cache(reg_mode)
1388             for path in paths:
1389                 self.assertIn(str(path), pyi_cache)
1390                 self.assertNotIn(str(path), normal_cache)
1391
1392     def test_pipe_force_py36(self) -> None:
1393         source, expected = read_data("force_py36")
1394         result = CliRunner().invoke(
1395             black.main,
1396             ["-", "-q", "--target-version=py36"],
1397             input=BytesIO(source.encode("utf8")),
1398         )
1399         self.assertEqual(result.exit_code, 0)
1400         actual = result.output
1401         self.assertFormatEqual(actual, expected)
1402
1403     def test_include_exclude(self) -> None:
1404         path = THIS_DIR / "data" / "include_exclude_tests"
1405         include = re.compile(r"\.pyi?$")
1406         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1407         report = black.Report()
1408         gitignore = PathSpec.from_lines("gitwildmatch", [])
1409         sources: List[Path] = []
1410         expected = [
1411             Path(path / "b/dont_exclude/a.py"),
1412             Path(path / "b/dont_exclude/a.pyi"),
1413         ]
1414         this_abs = THIS_DIR.resolve()
1415         sources.extend(
1416             black.gen_python_files(
1417                 path.iterdir(),
1418                 this_abs,
1419                 include,
1420                 exclude,
1421                 None,
1422                 None,
1423                 report,
1424                 gitignore,
1425             )
1426         )
1427         self.assertEqual(sorted(expected), sorted(sources))
1428
1429     def test_gitingore_used_as_default(self) -> None:
1430         path = Path(THIS_DIR / "data" / "include_exclude_tests")
1431         include = re.compile(r"\.pyi?$")
1432         extend_exclude = re.compile(r"/exclude/")
1433         src = str(path / "b/")
1434         report = black.Report()
1435         expected: List[Path] = [
1436             path / "b/.definitely_exclude/a.py",
1437             path / "b/.definitely_exclude/a.pyi",
1438         ]
1439         sources = list(
1440             black.get_sources(
1441                 ctx=FakeContext(),
1442                 src=(src,),
1443                 quiet=True,
1444                 verbose=False,
1445                 include=include,
1446                 exclude=None,
1447                 extend_exclude=extend_exclude,
1448                 force_exclude=None,
1449                 report=report,
1450                 stdin_filename=None,
1451             )
1452         )
1453         self.assertEqual(sorted(expected), sorted(sources))
1454
1455     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1456     def test_exclude_for_issue_1572(self) -> None:
1457         # Exclude shouldn't touch files that were explicitly given to Black through the
1458         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
1459         # https://github.com/psf/black/issues/1572
1460         path = THIS_DIR / "data" / "include_exclude_tests"
1461         include = ""
1462         exclude = r"/exclude/|a\.py"
1463         src = str(path / "b/exclude/a.py")
1464         report = black.Report()
1465         expected = [Path(path / "b/exclude/a.py")]
1466         sources = list(
1467             black.get_sources(
1468                 ctx=FakeContext(),
1469                 src=(src,),
1470                 quiet=True,
1471                 verbose=False,
1472                 include=re.compile(include),
1473                 exclude=re.compile(exclude),
1474                 extend_exclude=None,
1475                 force_exclude=None,
1476                 report=report,
1477                 stdin_filename=None,
1478             )
1479         )
1480         self.assertEqual(sorted(expected), sorted(sources))
1481
1482     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1483     def test_get_sources_with_stdin(self) -> None:
1484         include = ""
1485         exclude = r"/exclude/|a\.py"
1486         src = "-"
1487         report = black.Report()
1488         expected = [Path("-")]
1489         sources = list(
1490             black.get_sources(
1491                 ctx=FakeContext(),
1492                 src=(src,),
1493                 quiet=True,
1494                 verbose=False,
1495                 include=re.compile(include),
1496                 exclude=re.compile(exclude),
1497                 extend_exclude=None,
1498                 force_exclude=None,
1499                 report=report,
1500                 stdin_filename=None,
1501             )
1502         )
1503         self.assertEqual(sorted(expected), sorted(sources))
1504
1505     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1506     def test_get_sources_with_stdin_filename(self) -> None:
1507         include = ""
1508         exclude = r"/exclude/|a\.py"
1509         src = "-"
1510         report = black.Report()
1511         stdin_filename = str(THIS_DIR / "data/collections.py")
1512         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1513         sources = list(
1514             black.get_sources(
1515                 ctx=FakeContext(),
1516                 src=(src,),
1517                 quiet=True,
1518                 verbose=False,
1519                 include=re.compile(include),
1520                 exclude=re.compile(exclude),
1521                 extend_exclude=None,
1522                 force_exclude=None,
1523                 report=report,
1524                 stdin_filename=stdin_filename,
1525             )
1526         )
1527         self.assertEqual(sorted(expected), sorted(sources))
1528
1529     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1530     def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
1531         # Exclude shouldn't exclude stdin_filename since it is mimicing the
1532         # file being passed directly. This is the same as
1533         # test_exclude_for_issue_1572
1534         path = THIS_DIR / "data" / "include_exclude_tests"
1535         include = ""
1536         exclude = r"/exclude/|a\.py"
1537         src = "-"
1538         report = black.Report()
1539         stdin_filename = str(path / "b/exclude/a.py")
1540         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1541         sources = list(
1542             black.get_sources(
1543                 ctx=FakeContext(),
1544                 src=(src,),
1545                 quiet=True,
1546                 verbose=False,
1547                 include=re.compile(include),
1548                 exclude=re.compile(exclude),
1549                 extend_exclude=None,
1550                 force_exclude=None,
1551                 report=report,
1552                 stdin_filename=stdin_filename,
1553             )
1554         )
1555         self.assertEqual(sorted(expected), sorted(sources))
1556
1557     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1558     def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
1559         # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
1560         # file being passed directly. This is the same as
1561         # test_exclude_for_issue_1572
1562         path = THIS_DIR / "data" / "include_exclude_tests"
1563         include = ""
1564         extend_exclude = r"/exclude/|a\.py"
1565         src = "-"
1566         report = black.Report()
1567         stdin_filename = str(path / "b/exclude/a.py")
1568         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1569         sources = list(
1570             black.get_sources(
1571                 ctx=FakeContext(),
1572                 src=(src,),
1573                 quiet=True,
1574                 verbose=False,
1575                 include=re.compile(include),
1576                 exclude=re.compile(""),
1577                 extend_exclude=re.compile(extend_exclude),
1578                 force_exclude=None,
1579                 report=report,
1580                 stdin_filename=stdin_filename,
1581             )
1582         )
1583         self.assertEqual(sorted(expected), sorted(sources))
1584
1585     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1586     def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
1587         # Force exclude should exclude the file when passing it through
1588         # stdin_filename
1589         path = THIS_DIR / "data" / "include_exclude_tests"
1590         include = ""
1591         force_exclude = r"/exclude/|a\.py"
1592         src = "-"
1593         report = black.Report()
1594         stdin_filename = str(path / "b/exclude/a.py")
1595         sources = list(
1596             black.get_sources(
1597                 ctx=FakeContext(),
1598                 src=(src,),
1599                 quiet=True,
1600                 verbose=False,
1601                 include=re.compile(include),
1602                 exclude=re.compile(""),
1603                 extend_exclude=None,
1604                 force_exclude=re.compile(force_exclude),
1605                 report=report,
1606                 stdin_filename=stdin_filename,
1607             )
1608         )
1609         self.assertEqual([], sorted(sources))
1610
1611     def test_reformat_one_with_stdin(self) -> None:
1612         with patch(
1613             "black.format_stdin_to_stdout",
1614             return_value=lambda *args, **kwargs: black.Changed.YES,
1615         ) as fsts:
1616             report = MagicMock()
1617             path = Path("-")
1618             black.reformat_one(
1619                 path,
1620                 fast=True,
1621                 write_back=black.WriteBack.YES,
1622                 mode=DEFAULT_MODE,
1623                 report=report,
1624             )
1625             fsts.assert_called_once()
1626             report.done.assert_called_with(path, black.Changed.YES)
1627
1628     def test_reformat_one_with_stdin_filename(self) -> None:
1629         with patch(
1630             "black.format_stdin_to_stdout",
1631             return_value=lambda *args, **kwargs: black.Changed.YES,
1632         ) as fsts:
1633             report = MagicMock()
1634             p = "foo.py"
1635             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1636             expected = Path(p)
1637             black.reformat_one(
1638                 path,
1639                 fast=True,
1640                 write_back=black.WriteBack.YES,
1641                 mode=DEFAULT_MODE,
1642                 report=report,
1643             )
1644             fsts.assert_called_once_with(
1645                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1646             )
1647             # __BLACK_STDIN_FILENAME__ should have been stripped
1648             report.done.assert_called_with(expected, black.Changed.YES)
1649
1650     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1651         with patch(
1652             "black.format_stdin_to_stdout",
1653             return_value=lambda *args, **kwargs: black.Changed.YES,
1654         ) as fsts:
1655             report = MagicMock()
1656             p = "foo.pyi"
1657             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1658             expected = Path(p)
1659             black.reformat_one(
1660                 path,
1661                 fast=True,
1662                 write_back=black.WriteBack.YES,
1663                 mode=DEFAULT_MODE,
1664                 report=report,
1665             )
1666             fsts.assert_called_once_with(
1667                 fast=True,
1668                 write_back=black.WriteBack.YES,
1669                 mode=replace(DEFAULT_MODE, is_pyi=True),
1670             )
1671             # __BLACK_STDIN_FILENAME__ should have been stripped
1672             report.done.assert_called_with(expected, black.Changed.YES)
1673
1674     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1675         with patch(
1676             "black.format_stdin_to_stdout",
1677             return_value=lambda *args, **kwargs: black.Changed.YES,
1678         ) as fsts:
1679             report = MagicMock()
1680             # Even with an existing file, since we are forcing stdin, black
1681             # should output to stdout and not modify the file inplace
1682             p = Path(str(THIS_DIR / "data/collections.py"))
1683             # Make sure is_file actually returns True
1684             self.assertTrue(p.is_file())
1685             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1686             expected = Path(p)
1687             black.reformat_one(
1688                 path,
1689                 fast=True,
1690                 write_back=black.WriteBack.YES,
1691                 mode=DEFAULT_MODE,
1692                 report=report,
1693             )
1694             fsts.assert_called_once()
1695             # __BLACK_STDIN_FILENAME__ should have been stripped
1696             report.done.assert_called_with(expected, black.Changed.YES)
1697
1698     def test_gitignore_exclude(self) -> None:
1699         path = THIS_DIR / "data" / "include_exclude_tests"
1700         include = re.compile(r"\.pyi?$")
1701         exclude = re.compile(r"")
1702         report = black.Report()
1703         gitignore = PathSpec.from_lines(
1704             "gitwildmatch", ["exclude/", ".definitely_exclude"]
1705         )
1706         sources: List[Path] = []
1707         expected = [
1708             Path(path / "b/dont_exclude/a.py"),
1709             Path(path / "b/dont_exclude/a.pyi"),
1710         ]
1711         this_abs = THIS_DIR.resolve()
1712         sources.extend(
1713             black.gen_python_files(
1714                 path.iterdir(),
1715                 this_abs,
1716                 include,
1717                 exclude,
1718                 None,
1719                 None,
1720                 report,
1721                 gitignore,
1722             )
1723         )
1724         self.assertEqual(sorted(expected), sorted(sources))
1725
1726     def test_empty_include(self) -> None:
1727         path = THIS_DIR / "data" / "include_exclude_tests"
1728         report = black.Report()
1729         gitignore = PathSpec.from_lines("gitwildmatch", [])
1730         empty = re.compile(r"")
1731         sources: List[Path] = []
1732         expected = [
1733             Path(path / "b/exclude/a.pie"),
1734             Path(path / "b/exclude/a.py"),
1735             Path(path / "b/exclude/a.pyi"),
1736             Path(path / "b/dont_exclude/a.pie"),
1737             Path(path / "b/dont_exclude/a.py"),
1738             Path(path / "b/dont_exclude/a.pyi"),
1739             Path(path / "b/.definitely_exclude/a.pie"),
1740             Path(path / "b/.definitely_exclude/a.py"),
1741             Path(path / "b/.definitely_exclude/a.pyi"),
1742             Path(path / ".gitignore"),
1743             Path(path / "pyproject.toml"),
1744         ]
1745         this_abs = THIS_DIR.resolve()
1746         sources.extend(
1747             black.gen_python_files(
1748                 path.iterdir(),
1749                 this_abs,
1750                 empty,
1751                 re.compile(black.DEFAULT_EXCLUDES),
1752                 None,
1753                 None,
1754                 report,
1755                 gitignore,
1756             )
1757         )
1758         self.assertEqual(sorted(expected), sorted(sources))
1759
1760     def test_extend_exclude(self) -> None:
1761         path = THIS_DIR / "data" / "include_exclude_tests"
1762         report = black.Report()
1763         gitignore = PathSpec.from_lines("gitwildmatch", [])
1764         sources: List[Path] = []
1765         expected = [
1766             Path(path / "b/exclude/a.py"),
1767             Path(path / "b/dont_exclude/a.py"),
1768         ]
1769         this_abs = THIS_DIR.resolve()
1770         sources.extend(
1771             black.gen_python_files(
1772                 path.iterdir(),
1773                 this_abs,
1774                 re.compile(black.DEFAULT_INCLUDES),
1775                 re.compile(r"\.pyi$"),
1776                 re.compile(r"\.definitely_exclude"),
1777                 None,
1778                 report,
1779                 gitignore,
1780             )
1781         )
1782         self.assertEqual(sorted(expected), sorted(sources))
1783
1784     def test_invalid_cli_regex(self) -> None:
1785         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1786             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1787
1788     def test_preserves_line_endings(self) -> None:
1789         with TemporaryDirectory() as workspace:
1790             test_file = Path(workspace) / "test.py"
1791             for nl in ["\n", "\r\n"]:
1792                 contents = nl.join(["def f(  ):", "    pass"])
1793                 test_file.write_bytes(contents.encode())
1794                 ff(test_file, write_back=black.WriteBack.YES)
1795                 updated_contents: bytes = test_file.read_bytes()
1796                 self.assertIn(nl.encode(), updated_contents)
1797                 if nl == "\n":
1798                     self.assertNotIn(b"\r\n", updated_contents)
1799
1800     def test_preserves_line_endings_via_stdin(self) -> None:
1801         for nl in ["\n", "\r\n"]:
1802             contents = nl.join(["def f(  ):", "    pass"])
1803             runner = BlackRunner()
1804             result = runner.invoke(
1805                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1806             )
1807             self.assertEqual(result.exit_code, 0)
1808             output = runner.stdout_bytes
1809             self.assertIn(nl.encode("utf8"), output)
1810             if nl == "\n":
1811                 self.assertNotIn(b"\r\n", output)
1812
1813     def test_assert_equivalent_different_asts(self) -> None:
1814         with self.assertRaises(AssertionError):
1815             black.assert_equivalent("{}", "None")
1816
1817     def test_symlink_out_of_root_directory(self) -> None:
1818         path = MagicMock()
1819         root = THIS_DIR.resolve()
1820         child = MagicMock()
1821         include = re.compile(black.DEFAULT_INCLUDES)
1822         exclude = re.compile(black.DEFAULT_EXCLUDES)
1823         report = black.Report()
1824         gitignore = PathSpec.from_lines("gitwildmatch", [])
1825         # `child` should behave like a symlink which resolved path is clearly
1826         # outside of the `root` directory.
1827         path.iterdir.return_value = [child]
1828         child.resolve.return_value = Path("/a/b/c")
1829         child.as_posix.return_value = "/a/b/c"
1830         child.is_symlink.return_value = True
1831         try:
1832             list(
1833                 black.gen_python_files(
1834                     path.iterdir(),
1835                     root,
1836                     include,
1837                     exclude,
1838                     None,
1839                     None,
1840                     report,
1841                     gitignore,
1842                 )
1843             )
1844         except ValueError as ve:
1845             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1846         path.iterdir.assert_called_once()
1847         child.resolve.assert_called_once()
1848         child.is_symlink.assert_called_once()
1849         # `child` should behave like a strange file which resolved path is clearly
1850         # outside of the `root` directory.
1851         child.is_symlink.return_value = False
1852         with self.assertRaises(ValueError):
1853             list(
1854                 black.gen_python_files(
1855                     path.iterdir(),
1856                     root,
1857                     include,
1858                     exclude,
1859                     None,
1860                     None,
1861                     report,
1862                     gitignore,
1863                 )
1864             )
1865         path.iterdir.assert_called()
1866         self.assertEqual(path.iterdir.call_count, 2)
1867         child.resolve.assert_called()
1868         self.assertEqual(child.resolve.call_count, 2)
1869         child.is_symlink.assert_called()
1870         self.assertEqual(child.is_symlink.call_count, 2)
1871
1872     def test_shhh_click(self) -> None:
1873         try:
1874             from click import _unicodefun  # type: ignore
1875         except ModuleNotFoundError:
1876             self.skipTest("Incompatible Click version")
1877         if not hasattr(_unicodefun, "_verify_python3_env"):
1878             self.skipTest("Incompatible Click version")
1879         # First, let's see if Click is crashing with a preferred ASCII charset.
1880         with patch("locale.getpreferredencoding") as gpe:
1881             gpe.return_value = "ASCII"
1882             with self.assertRaises(RuntimeError):
1883                 _unicodefun._verify_python3_env()
1884         # Now, let's silence Click...
1885         black.patch_click()
1886         # ...and confirm it's silent.
1887         with patch("locale.getpreferredencoding") as gpe:
1888             gpe.return_value = "ASCII"
1889             try:
1890                 _unicodefun._verify_python3_env()
1891             except RuntimeError as re:
1892                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1893
1894     def test_root_logger_not_used_directly(self) -> None:
1895         def fail(*args: Any, **kwargs: Any) -> None:
1896             self.fail("Record created with root logger")
1897
1898         with patch.multiple(
1899             logging.root,
1900             debug=fail,
1901             info=fail,
1902             warning=fail,
1903             error=fail,
1904             critical=fail,
1905             log=fail,
1906         ):
1907             ff(THIS_FILE)
1908
1909     def test_invalid_config_return_code(self) -> None:
1910         tmp_file = Path(black.dump_to_file())
1911         try:
1912             tmp_config = Path(black.dump_to_file())
1913             tmp_config.unlink()
1914             args = ["--config", str(tmp_config), str(tmp_file)]
1915             self.invokeBlack(args, exit_code=2, ignore_config=False)
1916         finally:
1917             tmp_file.unlink()
1918
1919     def test_parse_pyproject_toml(self) -> None:
1920         test_toml_file = THIS_DIR / "test.toml"
1921         config = black.parse_pyproject_toml(str(test_toml_file))
1922         self.assertEqual(config["verbose"], 1)
1923         self.assertEqual(config["check"], "no")
1924         self.assertEqual(config["diff"], "y")
1925         self.assertEqual(config["color"], True)
1926         self.assertEqual(config["line_length"], 79)
1927         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1928         self.assertEqual(config["exclude"], r"\.pyi?$")
1929         self.assertEqual(config["include"], r"\.py?$")
1930
1931     def test_read_pyproject_toml(self) -> None:
1932         test_toml_file = THIS_DIR / "test.toml"
1933         fake_ctx = FakeContext()
1934         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1935         config = fake_ctx.default_map
1936         self.assertEqual(config["verbose"], "1")
1937         self.assertEqual(config["check"], "no")
1938         self.assertEqual(config["diff"], "y")
1939         self.assertEqual(config["color"], "True")
1940         self.assertEqual(config["line_length"], "79")
1941         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1942         self.assertEqual(config["exclude"], r"\.pyi?$")
1943         self.assertEqual(config["include"], r"\.py?$")
1944
1945     def test_find_project_root(self) -> None:
1946         with TemporaryDirectory() as workspace:
1947             root = Path(workspace)
1948             test_dir = root / "test"
1949             test_dir.mkdir()
1950
1951             src_dir = root / "src"
1952             src_dir.mkdir()
1953
1954             root_pyproject = root / "pyproject.toml"
1955             root_pyproject.touch()
1956             src_pyproject = src_dir / "pyproject.toml"
1957             src_pyproject.touch()
1958             src_python = src_dir / "foo.py"
1959             src_python.touch()
1960
1961             self.assertEqual(
1962                 black.find_project_root((src_dir, test_dir)), root.resolve()
1963             )
1964             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
1965             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
1966
1967     @patch(
1968         "black.files.find_user_pyproject_toml",
1969         black.files.find_user_pyproject_toml.__wrapped__,
1970     )
1971     def test_find_user_pyproject_toml_linux(self) -> None:
1972         if system() == "Windows":
1973             return
1974
1975         # Test if XDG_CONFIG_HOME is checked
1976         with TemporaryDirectory() as workspace:
1977             tmp_user_config = Path(workspace) / "black"
1978             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1979                 self.assertEqual(
1980                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
1981                 )
1982
1983         # Test fallback for XDG_CONFIG_HOME
1984         with patch.dict("os.environ"):
1985             os.environ.pop("XDG_CONFIG_HOME", None)
1986             fallback_user_config = Path("~/.config").expanduser() / "black"
1987             self.assertEqual(
1988                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
1989             )
1990
1991     def test_find_user_pyproject_toml_windows(self) -> None:
1992         if system() != "Windows":
1993             return
1994
1995         user_config_path = Path.home() / ".black"
1996         self.assertEqual(
1997             black.files.find_user_pyproject_toml(), user_config_path.resolve()
1998         )
1999
2000     def test_bpo_33660_workaround(self) -> None:
2001         if system() == "Windows":
2002             return
2003
2004         # https://bugs.python.org/issue33660
2005
2006         old_cwd = Path.cwd()
2007         try:
2008             root = Path("/")
2009             os.chdir(str(root))
2010             path = Path("workspace") / "project"
2011             report = black.Report(verbose=True)
2012             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
2013             self.assertEqual(normalized_path, "workspace/project")
2014         finally:
2015             os.chdir(str(old_cwd))
2016
2017     def test_newline_comment_interaction(self) -> None:
2018         source = "class A:\\\r\n# type: ignore\n pass\n"
2019         output = black.format_str(source, mode=DEFAULT_MODE)
2020         black.assert_stable(source, output, mode=DEFAULT_MODE)
2021
2022     def test_bpo_2142_workaround(self) -> None:
2023
2024         # https://bugs.python.org/issue2142
2025
2026         source, _ = read_data("missing_final_newline.py")
2027         # read_data adds a trailing newline
2028         source = source.rstrip()
2029         expected, _ = read_data("missing_final_newline.diff")
2030         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
2031         diff_header = re.compile(
2032             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
2033             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
2034         )
2035         try:
2036             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
2037             self.assertEqual(result.exit_code, 0)
2038         finally:
2039             os.unlink(tmp_file)
2040         actual = result.output
2041         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
2042         self.assertEqual(actual, expected)
2043
2044     @pytest.mark.python2
2045     def test_docstring_reformat_for_py27(self) -> None:
2046         """
2047         Check that stripping trailing whitespace from Python 2 docstrings
2048         doesn't trigger a "not equivalent to source" error
2049         """
2050         source = (
2051             b'def foo():\r\n    """Testing\r\n    Testing """\r\n    print "Foo"\r\n'
2052         )
2053         expected = 'def foo():\n    """Testing\n    Testing"""\n    print "Foo"\n'
2054
2055         result = CliRunner().invoke(
2056             black.main,
2057             ["-", "-q", "--target-version=py27"],
2058             input=BytesIO(source),
2059         )
2060
2061         self.assertEqual(result.exit_code, 0)
2062         actual = result.output
2063         self.assertFormatEqual(actual, expected)
2064
2065
# Keep the lines of Black's own source in memory for `tracefunc` to look up.
with open(black.__file__, encoding="utf-8") as _bf:
    black_source_lines = list(_bf)
2068
2069
def tracefunc(frame: types.FrameType, event: str, arg: Any) -> Callable:
    """Show function calls `from black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    Only "call" events whose code file name contains `black/__init__.py` are
    printed, as `<line number>:<source line>` indented by stack depth. Any
    other event or file is ignored. The function always returns itself.
    """
    if event != "call":
        return tracefunc

    filename = frame.f_code.co_filename
    if "black/__init__.py" not in filename:
        # `black_source_lines` only describes black/__init__.py. Indexing it
        # with a line number from another file would raise IndexError whenever
        # that file is longer than Black's source.
        return tracefunc

    # 19 is presumably the stack depth of the test harness -- TODO confirm.
    stack = len(inspect.stack()) - 19
    stack *= 2
    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1
    funcname = black_source_lines[func_sig_lineno].strip()
    last_index = len(black_source_lines) - 1
    # Skip decorator lines to reach the `def` line, without running off the end.
    while funcname.startswith("@") and func_sig_lineno < last_index:
        func_sig_lineno += 1
        funcname = black_source_lines[func_sig_lineno].strip()
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc
2090
2091
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main(module="test_black")