git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Add optional uvloop import (#2258)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import multiprocessing
3 import asyncio
4 import logging
5 from concurrent.futures import ThreadPoolExecutor
6 from contextlib import contextmanager
7 from dataclasses import replace
8 import inspect
9 from io import BytesIO
10 import os
11 from pathlib import Path
12 from platform import system
13 import regex as re
14 import sys
15 from tempfile import TemporaryDirectory
16 import types
17 from typing import (
18     Any,
19     Callable,
20     Dict,
21     List,
22     Iterator,
23     TypeVar,
24 )
25 import pytest
26 import unittest
27 from unittest.mock import patch, MagicMock
28
29 import click
30 from click import unstyle
31 from click.testing import CliRunner
32
33 import black
34 from black import Feature, TargetVersion
35 from black.cache import get_cache_file
36 from black.debug import DebugVisitor
37 from black.report import Report
38 import black.files
39
40 from pathspec import PathSpec
41
42 # Import other test classes
43 from tests.util import (
44     THIS_DIR,
45     read_data,
46     DETERMINISTIC_HEADER,
47     BlackBaseTestCase,
48     DEFAULT_MODE,
49     fs,
50     ff,
51     dump_to_stderr,
52 )
53
54
THIS_FILE = Path(__file__)
# The Python 3.6+ target versions exercised by several tests below.
PY36_VERSIONS = {
    TargetVersion.PY36,
    TargetVersion.PY37,
    TargetVersion.PY38,
    TargetVersion.PY39,
}
# Matching ``--target-version`` CLI flags. The set above has no guaranteed
# iteration order across runs, so sort by name to make the argument list
# deterministic and reproducible.
PY36_ARGS = [
    f"--target-version={version.name.lower()}"
    for version in sorted(PY36_VERSIONS, key=lambda v: v.name)
]
T = TypeVar("T")
R = TypeVar("R")
65
66
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point ``black.cache.CACHE_DIR`` at a throwaway directory for a test.

    Args:
        exists: when True, the patched cache directory is the temporary
            directory itself; when False, it is a ``new`` subdirectory of
            it that has not been created.

    Yields:
        The path ``black.cache.CACHE_DIR`` is patched to while the context
        is active.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.cache.CACHE_DIR", target):
            yield target
75
76
@contextmanager
def event_loop() -> Iterator[None]:
    """Make a brand-new asyncio event loop current for the ``with`` body.

    The loop comes from the current event loop policy and is closed on exit,
    even if the body raises. It is not unset afterwards, so the closed loop
    remains the current loop.
    """
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        fresh_loop.close()
87
88
class FakeContext(click.Context):
    """Minimal click Context stand-in for helpers that require one.

    ``click.Context.__init__`` is intentionally not called; the only state
    set up is an empty ``default_map``.
    """

    def __init__(self) -> None:
        defaults: Dict[str, Any] = {}
        self.default_map = defaults
94
95
class FakeParameter(click.Parameter):
    """Minimal click Parameter stand-in for helpers that require one."""

    def __init__(self) -> None:
        """Deliberately skip ``click.Parameter.__init__``; no state is set."""
101
102
class BlackRunner(CliRunner):
    """CliRunner that keeps Black's stdout and stderr as separate streams."""

    def __init__(self) -> None:
        # mix_stderr=False keeps stderr out of the captured stdout.
        super(BlackRunner, self).__init__(mix_stderr=False)
108
109
110 class BlackTestCase(BlackBaseTestCase):
111     def invokeBlack(
112         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
113     ) -> None:
114         runner = BlackRunner()
115         if ignore_config:
116             args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
117         result = runner.invoke(black.main, args)
118         self.assertEqual(
119             result.exit_code,
120             exit_code,
121             msg=(
122                 f"Failed with args: {args}\n"
123                 f"stdout: {result.stdout_bytes.decode()!r}\n"
124                 f"stderr: {result.stderr_bytes.decode()!r}\n"
125                 f"exception: {result.exception}"
126             ),
127         )
128
129     @patch("black.dump_to_file", dump_to_stderr)
130     def test_empty(self) -> None:
131         source = expected = ""
132         actual = fs(source)
133         self.assertFormatEqual(expected, actual)
134         black.assert_equivalent(source, actual)
135         black.assert_stable(source, actual, DEFAULT_MODE)
136
137     def test_empty_ff(self) -> None:
138         expected = ""
139         tmp_file = Path(black.dump_to_file())
140         try:
141             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
142             with open(tmp_file, encoding="utf8") as f:
143                 actual = f.read()
144         finally:
145             os.unlink(tmp_file)
146         self.assertFormatEqual(expected, actual)
147
148     def test_piping(self) -> None:
149         source, expected = read_data("src/black/__init__", data=False)
150         result = BlackRunner().invoke(
151             black.main,
152             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
153             input=BytesIO(source.encode("utf8")),
154         )
155         self.assertEqual(result.exit_code, 0)
156         self.assertFormatEqual(expected, result.output)
157         if source != result.output:
158             black.assert_equivalent(source, result.output)
159             black.assert_stable(source, result.output, DEFAULT_MODE)
160
161     def test_piping_diff(self) -> None:
162         diff_header = re.compile(
163             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
164             r"\+\d\d\d\d"
165         )
166         source, _ = read_data("expression.py")
167         expected, _ = read_data("expression.diff")
168         config = THIS_DIR / "data" / "empty_pyproject.toml"
169         args = [
170             "-",
171             "--fast",
172             f"--line-length={black.DEFAULT_LINE_LENGTH}",
173             "--diff",
174             f"--config={config}",
175         ]
176         result = BlackRunner().invoke(
177             black.main, args, input=BytesIO(source.encode("utf8"))
178         )
179         self.assertEqual(result.exit_code, 0)
180         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
181         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
182         self.assertEqual(expected, actual)
183
184     def test_piping_diff_with_color(self) -> None:
185         source, _ = read_data("expression.py")
186         config = THIS_DIR / "data" / "empty_pyproject.toml"
187         args = [
188             "-",
189             "--fast",
190             f"--line-length={black.DEFAULT_LINE_LENGTH}",
191             "--diff",
192             "--color",
193             f"--config={config}",
194         ]
195         result = BlackRunner().invoke(
196             black.main, args, input=BytesIO(source.encode("utf8"))
197         )
198         actual = result.output
199         # Again, the contents are checked in a different test, so only look for colors.
200         self.assertIn("\033[1;37m", actual)
201         self.assertIn("\033[36m", actual)
202         self.assertIn("\033[32m", actual)
203         self.assertIn("\033[31m", actual)
204         self.assertIn("\033[0m", actual)
205
206     @patch("black.dump_to_file", dump_to_stderr)
207     def _test_wip(self) -> None:
208         source, expected = read_data("wip")
209         sys.settrace(tracefunc)
210         mode = replace(
211             DEFAULT_MODE,
212             experimental_string_processing=False,
213             target_versions={black.TargetVersion.PY38},
214         )
215         actual = fs(source, mode=mode)
216         sys.settrace(None)
217         self.assertFormatEqual(expected, actual)
218         black.assert_equivalent(source, actual)
219         black.assert_stable(source, actual, black.FileMode())
220
221     @unittest.expectedFailure
222     @patch("black.dump_to_file", dump_to_stderr)
223     def test_trailing_comma_optional_parens_stability1(self) -> None:
224         source, _expected = read_data("trailing_comma_optional_parens1")
225         actual = fs(source)
226         black.assert_stable(source, actual, DEFAULT_MODE)
227
228     @unittest.expectedFailure
229     @patch("black.dump_to_file", dump_to_stderr)
230     def test_trailing_comma_optional_parens_stability2(self) -> None:
231         source, _expected = read_data("trailing_comma_optional_parens2")
232         actual = fs(source)
233         black.assert_stable(source, actual, DEFAULT_MODE)
234
235     @unittest.expectedFailure
236     @patch("black.dump_to_file", dump_to_stderr)
237     def test_trailing_comma_optional_parens_stability3(self) -> None:
238         source, _expected = read_data("trailing_comma_optional_parens3")
239         actual = fs(source)
240         black.assert_stable(source, actual, DEFAULT_MODE)
241
242     @patch("black.dump_to_file", dump_to_stderr)
243     def test_trailing_comma_optional_parens_stability1_pass2(self) -> None:
244         source, _expected = read_data("trailing_comma_optional_parens1")
245         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
246         black.assert_stable(source, actual, DEFAULT_MODE)
247
248     @patch("black.dump_to_file", dump_to_stderr)
249     def test_trailing_comma_optional_parens_stability2_pass2(self) -> None:
250         source, _expected = read_data("trailing_comma_optional_parens2")
251         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
252         black.assert_stable(source, actual, DEFAULT_MODE)
253
254     @patch("black.dump_to_file", dump_to_stderr)
255     def test_trailing_comma_optional_parens_stability3_pass2(self) -> None:
256         source, _expected = read_data("trailing_comma_optional_parens3")
257         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
258         black.assert_stable(source, actual, DEFAULT_MODE)
259
260     @patch("black.dump_to_file", dump_to_stderr)
261     def test_pep_572(self) -> None:
262         source, expected = read_data("pep_572")
263         actual = fs(source)
264         self.assertFormatEqual(expected, actual)
265         black.assert_stable(source, actual, DEFAULT_MODE)
266         if sys.version_info >= (3, 8):
267             black.assert_equivalent(source, actual)
268
269     @patch("black.dump_to_file", dump_to_stderr)
270     def test_pep_572_remove_parens(self) -> None:
271         source, expected = read_data("pep_572_remove_parens")
272         actual = fs(source)
273         self.assertFormatEqual(expected, actual)
274         black.assert_stable(source, actual, DEFAULT_MODE)
275         if sys.version_info >= (3, 8):
276             black.assert_equivalent(source, actual)
277
278     @patch("black.dump_to_file", dump_to_stderr)
279     def test_pep_572_do_not_remove_parens(self) -> None:
280         source, expected = read_data("pep_572_do_not_remove_parens")
281         # the AST safety checks will fail, but that's expected, just make sure no
282         # parentheses are touched
283         actual = black.format_str(source, mode=DEFAULT_MODE)
284         self.assertFormatEqual(expected, actual)
285
286     def test_pep_572_version_detection(self) -> None:
287         source, _ = read_data("pep_572")
288         root = black.lib2to3_parse(source)
289         features = black.get_features_used(root)
290         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
291         versions = black.detect_target_versions(root)
292         self.assertIn(black.TargetVersion.PY38, versions)
293
294     def test_expression_ff(self) -> None:
295         source, expected = read_data("expression")
296         tmp_file = Path(black.dump_to_file(source))
297         try:
298             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
299             with open(tmp_file, encoding="utf8") as f:
300                 actual = f.read()
301         finally:
302             os.unlink(tmp_file)
303         self.assertFormatEqual(expected, actual)
304         with patch("black.dump_to_file", dump_to_stderr):
305             black.assert_equivalent(source, actual)
306             black.assert_stable(source, actual, DEFAULT_MODE)
307
308     def test_expression_diff(self) -> None:
309         source, _ = read_data("expression.py")
310         config = THIS_DIR / "data" / "empty_pyproject.toml"
311         expected, _ = read_data("expression.diff")
312         tmp_file = Path(black.dump_to_file(source))
313         diff_header = re.compile(
314             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
315             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
316         )
317         try:
318             result = BlackRunner().invoke(
319                 black.main, ["--diff", str(tmp_file), f"--config={config}"]
320             )
321             self.assertEqual(result.exit_code, 0)
322         finally:
323             os.unlink(tmp_file)
324         actual = result.output
325         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
326         if expected != actual:
327             dump = black.dump_to_file(actual)
328             msg = (
329                 "Expected diff isn't equal to the actual. If you made changes to"
330                 " expression.py and this is an anticipated difference, overwrite"
331                 f" tests/data/expression.diff with {dump}"
332             )
333             self.assertEqual(expected, actual, msg)
334
335     def test_expression_diff_with_color(self) -> None:
336         source, _ = read_data("expression.py")
337         config = THIS_DIR / "data" / "empty_pyproject.toml"
338         expected, _ = read_data("expression.diff")
339         tmp_file = Path(black.dump_to_file(source))
340         try:
341             result = BlackRunner().invoke(
342                 black.main, ["--diff", "--color", str(tmp_file), f"--config={config}"]
343             )
344         finally:
345             os.unlink(tmp_file)
346         actual = result.output
347         # We check the contents of the diff in `test_expression_diff`. All
348         # we need to check here is that color codes exist in the result.
349         self.assertIn("\033[1;37m", actual)
350         self.assertIn("\033[36m", actual)
351         self.assertIn("\033[32m", actual)
352         self.assertIn("\033[31m", actual)
353         self.assertIn("\033[0m", actual)
354
355     @patch("black.dump_to_file", dump_to_stderr)
356     def test_pep_570(self) -> None:
357         source, expected = read_data("pep_570")
358         actual = fs(source)
359         self.assertFormatEqual(expected, actual)
360         black.assert_stable(source, actual, DEFAULT_MODE)
361         if sys.version_info >= (3, 8):
362             black.assert_equivalent(source, actual)
363
364     def test_detect_pos_only_arguments(self) -> None:
365         source, _ = read_data("pep_570")
366         root = black.lib2to3_parse(source)
367         features = black.get_features_used(root)
368         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
369         versions = black.detect_target_versions(root)
370         self.assertIn(black.TargetVersion.PY38, versions)
371
372     @patch("black.dump_to_file", dump_to_stderr)
373     def test_string_quotes(self) -> None:
374         source, expected = read_data("string_quotes")
375         mode = black.Mode(experimental_string_processing=True)
376         actual = fs(source, mode=mode)
377         self.assertFormatEqual(expected, actual)
378         black.assert_equivalent(source, actual)
379         black.assert_stable(source, actual, mode)
380         mode = replace(mode, string_normalization=False)
381         not_normalized = fs(source, mode=mode)
382         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
383         black.assert_equivalent(source, not_normalized)
384         black.assert_stable(source, not_normalized, mode=mode)
385
386     @patch("black.dump_to_file", dump_to_stderr)
387     def test_docstring_no_string_normalization(self) -> None:
388         """Like test_docstring but with string normalization off."""
389         source, expected = read_data("docstring_no_string_normalization")
390         mode = replace(DEFAULT_MODE, string_normalization=False)
391         actual = fs(source, mode=mode)
392         self.assertFormatEqual(expected, actual)
393         black.assert_equivalent(source, actual)
394         black.assert_stable(source, actual, mode)
395
396     def test_long_strings_flag_disabled(self) -> None:
397         """Tests for turning off the string processing logic."""
398         source, expected = read_data("long_strings_flag_disabled")
399         mode = replace(DEFAULT_MODE, experimental_string_processing=False)
400         actual = fs(source, mode=mode)
401         self.assertFormatEqual(expected, actual)
402         black.assert_stable(expected, actual, mode)
403
404     @patch("black.dump_to_file", dump_to_stderr)
405     def test_numeric_literals(self) -> None:
406         source, expected = read_data("numeric_literals")
407         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
408         actual = fs(source, mode=mode)
409         self.assertFormatEqual(expected, actual)
410         black.assert_equivalent(source, actual)
411         black.assert_stable(source, actual, mode)
412
413     @patch("black.dump_to_file", dump_to_stderr)
414     def test_numeric_literals_ignoring_underscores(self) -> None:
415         source, expected = read_data("numeric_literals_skip_underscores")
416         mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
417         actual = fs(source, mode=mode)
418         self.assertFormatEqual(expected, actual)
419         black.assert_equivalent(source, actual)
420         black.assert_stable(source, actual, mode)
421
422     def test_skip_magic_trailing_comma(self) -> None:
423         source, _ = read_data("expression.py")
424         expected, _ = read_data("expression_skip_magic_trailing_comma.diff")
425         tmp_file = Path(black.dump_to_file(source))
426         diff_header = re.compile(
427             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
428             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
429         )
430         try:
431             result = BlackRunner().invoke(black.main, ["-C", "--diff", str(tmp_file)])
432             self.assertEqual(result.exit_code, 0)
433         finally:
434             os.unlink(tmp_file)
435         actual = result.output
436         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
437         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
438         if expected != actual:
439             dump = black.dump_to_file(actual)
440             msg = (
441                 "Expected diff isn't equal to the actual. If you made changes to"
442                 " expression.py and this is an anticipated difference, overwrite"
443                 f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}"
444             )
445             self.assertEqual(expected, actual, msg)
446
447     @pytest.mark.no_python2
448     def test_python2_should_fail_without_optional_install(self) -> None:
449         if sys.version_info < (3, 8):
450             self.skipTest(
451                 "Python 3.6 and 3.7 will install typed-ast to work and as such will be"
452                 " able to parse Python 2 syntax without explicitly specifying the"
453                 " python2 extra"
454             )
455
456         source = "x = 1234l"
457         tmp_file = Path(black.dump_to_file(source))
458         try:
459             runner = BlackRunner()
460             result = runner.invoke(black.main, [str(tmp_file)])
461             self.assertEqual(result.exit_code, 123)
462         finally:
463             os.unlink(tmp_file)
464         actual = (
465             result.stderr_bytes.decode()
466             .replace("\n", "")
467             .replace("\\n", "")
468             .replace("\\r", "")
469             .replace("\r", "")
470         )
471         msg = (
472             "The requested source code has invalid Python 3 syntax."
473             "If you are trying to format Python 2 files please reinstall Black"
474             " with the 'python2' extra: `python3 -m pip install black[python2]`."
475         )
476         self.assertIn(msg, actual)
477
478     @pytest.mark.python2
479     @patch("black.dump_to_file", dump_to_stderr)
480     def test_python2_print_function(self) -> None:
481         source, expected = read_data("python2_print_function")
482         mode = replace(DEFAULT_MODE, target_versions={TargetVersion.PY27})
483         actual = fs(source, mode=mode)
484         self.assertFormatEqual(expected, actual)
485         black.assert_equivalent(source, actual)
486         black.assert_stable(source, actual, mode)
487
488     @patch("black.dump_to_file", dump_to_stderr)
489     def test_stub(self) -> None:
490         mode = replace(DEFAULT_MODE, is_pyi=True)
491         source, expected = read_data("stub.pyi")
492         actual = fs(source, mode=mode)
493         self.assertFormatEqual(expected, actual)
494         black.assert_stable(source, actual, mode)
495
496     @patch("black.dump_to_file", dump_to_stderr)
497     def test_async_as_identifier(self) -> None:
498         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
499         source, expected = read_data("async_as_identifier")
500         actual = fs(source)
501         self.assertFormatEqual(expected, actual)
502         major, minor = sys.version_info[:2]
503         if major < 3 or (major <= 3 and minor < 7):
504             black.assert_equivalent(source, actual)
505         black.assert_stable(source, actual, DEFAULT_MODE)
506         # ensure black can parse this when the target is 3.6
507         self.invokeBlack([str(source_path), "--target-version", "py36"])
508         # but not on 3.7, because async/await is no longer an identifier
509         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
510
511     @patch("black.dump_to_file", dump_to_stderr)
512     def test_python37(self) -> None:
513         source_path = (THIS_DIR / "data" / "python37.py").resolve()
514         source, expected = read_data("python37")
515         actual = fs(source)
516         self.assertFormatEqual(expected, actual)
517         major, minor = sys.version_info[:2]
518         if major > 3 or (major == 3 and minor >= 7):
519             black.assert_equivalent(source, actual)
520         black.assert_stable(source, actual, DEFAULT_MODE)
521         # ensure black can parse this when the target is 3.7
522         self.invokeBlack([str(source_path), "--target-version", "py37"])
523         # but not on 3.6, because we use async as a reserved keyword
524         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
525
526     @patch("black.dump_to_file", dump_to_stderr)
527     def test_python38(self) -> None:
528         source, expected = read_data("python38")
529         actual = fs(source)
530         self.assertFormatEqual(expected, actual)
531         major, minor = sys.version_info[:2]
532         if major > 3 or (major == 3 and minor >= 8):
533             black.assert_equivalent(source, actual)
534         black.assert_stable(source, actual, DEFAULT_MODE)
535
536     @patch("black.dump_to_file", dump_to_stderr)
537     def test_python39(self) -> None:
538         source, expected = read_data("python39")
539         actual = fs(source)
540         self.assertFormatEqual(expected, actual)
541         major, minor = sys.version_info[:2]
542         if major > 3 or (major == 3 and minor >= 9):
543             black.assert_equivalent(source, actual)
544         black.assert_stable(source, actual, DEFAULT_MODE)
545
546     def test_tab_comment_indentation(self) -> None:
547         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
548         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
549         self.assertFormatEqual(contents_spc, fs(contents_spc))
550         self.assertFormatEqual(contents_spc, fs(contents_tab))
551
552         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
553         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
554         self.assertFormatEqual(contents_spc, fs(contents_spc))
555         self.assertFormatEqual(contents_spc, fs(contents_tab))
556
557         # mixed tabs and spaces (valid Python 2 code)
558         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
559         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
560         self.assertFormatEqual(contents_spc, fs(contents_spc))
561         self.assertFormatEqual(contents_spc, fs(contents_tab))
562
563         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
564         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
565         self.assertFormatEqual(contents_spc, fs(contents_spc))
566         self.assertFormatEqual(contents_spc, fs(contents_tab))
567
    def test_report_verbose(self) -> None:
        """Check a verbose Report's per-file messages, summary and return code.

        Feeds a sequence of outcomes (unchanged, reformatted, cached, failed,
        ignored) into ``Report(verbose=True)``. After each step it asserts what
        reached the patched ``black.output._out`` / ``_err`` functions, the
        unstyled summary string, and ``return_code``.
        """
        report = Report(verbose=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        # Capture the messages Report emits via black.output._out / _err.
        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.output._out", out), patch("black.output._err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached result is counted as unchanged.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, having reformatted files turns the return code into 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures go to the error stream and raise the return code to 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Reporting f3 again adds to the reformatted count; the unchanged
            # count is not decremented.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are announced on the output stream but are not
            # counted in the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check mode and diff mode switch the summary to the
            # conditional "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
669
670     def test_report_quiet(self) -> None:
671         report = Report(quiet=True)
672         out_lines = []
673         err_lines = []
674
675         def out(msg: str, **kwargs: Any) -> None:
676             out_lines.append(msg)
677
678         def err(msg: str, **kwargs: Any) -> None:
679             err_lines.append(msg)
680
681         with patch("black.output._out", out), patch("black.output._err", err):
682             report.done(Path("f1"), black.Changed.NO)
683             self.assertEqual(len(out_lines), 0)
684             self.assertEqual(len(err_lines), 0)
685             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
686             self.assertEqual(report.return_code, 0)
687             report.done(Path("f2"), black.Changed.YES)
688             self.assertEqual(len(out_lines), 0)
689             self.assertEqual(len(err_lines), 0)
690             self.assertEqual(
691                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
692             )
693             report.done(Path("f3"), black.Changed.CACHED)
694             self.assertEqual(len(out_lines), 0)
695             self.assertEqual(len(err_lines), 0)
696             self.assertEqual(
697                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
698             )
699             self.assertEqual(report.return_code, 0)
700             report.check = True
701             self.assertEqual(report.return_code, 1)
702             report.check = False
703             report.failed(Path("e1"), "boom")
704             self.assertEqual(len(out_lines), 0)
705             self.assertEqual(len(err_lines), 1)
706             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
707             self.assertEqual(
708                 unstyle(str(report)),
709                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
710                 " reformat.",
711             )
712             self.assertEqual(report.return_code, 123)
713             report.done(Path("f3"), black.Changed.YES)
714             self.assertEqual(len(out_lines), 0)
715             self.assertEqual(len(err_lines), 1)
716             self.assertEqual(
717                 unstyle(str(report)),
718                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
719                 " reformat.",
720             )
721             self.assertEqual(report.return_code, 123)
722             report.failed(Path("e2"), "boom")
723             self.assertEqual(len(out_lines), 0)
724             self.assertEqual(len(err_lines), 2)
725             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
726             self.assertEqual(
727                 unstyle(str(report)),
728                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
729                 " reformat.",
730             )
731             self.assertEqual(report.return_code, 123)
732             report.path_ignored(Path("wat"), "no match")
733             self.assertEqual(len(out_lines), 0)
734             self.assertEqual(len(err_lines), 2)
735             self.assertEqual(
736                 unstyle(str(report)),
737                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
738                 " reformat.",
739             )
740             self.assertEqual(report.return_code, 123)
741             report.done(Path("f4"), black.Changed.NO)
742             self.assertEqual(len(out_lines), 0)
743             self.assertEqual(len(err_lines), 2)
744             self.assertEqual(
745                 unstyle(str(report)),
746                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
747                 " reformat.",
748             )
749             self.assertEqual(report.return_code, 123)
750             report.check = True
751             self.assertEqual(
752                 unstyle(str(report)),
753                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
754                 " would fail to reformat.",
755             )
756             report.check = False
757             report.diff = True
758             self.assertEqual(
759                 unstyle(str(report)),
760                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
761                 " would fail to reformat.",
762             )
763
764     def test_report_normal(self) -> None:
765         report = black.Report()
766         out_lines = []
767         err_lines = []
768
769         def out(msg: str, **kwargs: Any) -> None:
770             out_lines.append(msg)
771
772         def err(msg: str, **kwargs: Any) -> None:
773             err_lines.append(msg)
774
775         with patch("black.output._out", out), patch("black.output._err", err):
776             report.done(Path("f1"), black.Changed.NO)
777             self.assertEqual(len(out_lines), 0)
778             self.assertEqual(len(err_lines), 0)
779             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
780             self.assertEqual(report.return_code, 0)
781             report.done(Path("f2"), black.Changed.YES)
782             self.assertEqual(len(out_lines), 1)
783             self.assertEqual(len(err_lines), 0)
784             self.assertEqual(out_lines[-1], "reformatted f2")
785             self.assertEqual(
786                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
787             )
788             report.done(Path("f3"), black.Changed.CACHED)
789             self.assertEqual(len(out_lines), 1)
790             self.assertEqual(len(err_lines), 0)
791             self.assertEqual(out_lines[-1], "reformatted f2")
792             self.assertEqual(
793                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
794             )
795             self.assertEqual(report.return_code, 0)
796             report.check = True
797             self.assertEqual(report.return_code, 1)
798             report.check = False
799             report.failed(Path("e1"), "boom")
800             self.assertEqual(len(out_lines), 1)
801             self.assertEqual(len(err_lines), 1)
802             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
803             self.assertEqual(
804                 unstyle(str(report)),
805                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
806                 " reformat.",
807             )
808             self.assertEqual(report.return_code, 123)
809             report.done(Path("f3"), black.Changed.YES)
810             self.assertEqual(len(out_lines), 2)
811             self.assertEqual(len(err_lines), 1)
812             self.assertEqual(out_lines[-1], "reformatted f3")
813             self.assertEqual(
814                 unstyle(str(report)),
815                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
816                 " reformat.",
817             )
818             self.assertEqual(report.return_code, 123)
819             report.failed(Path("e2"), "boom")
820             self.assertEqual(len(out_lines), 2)
821             self.assertEqual(len(err_lines), 2)
822             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
823             self.assertEqual(
824                 unstyle(str(report)),
825                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
826                 " reformat.",
827             )
828             self.assertEqual(report.return_code, 123)
829             report.path_ignored(Path("wat"), "no match")
830             self.assertEqual(len(out_lines), 2)
831             self.assertEqual(len(err_lines), 2)
832             self.assertEqual(
833                 unstyle(str(report)),
834                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
835                 " reformat.",
836             )
837             self.assertEqual(report.return_code, 123)
838             report.done(Path("f4"), black.Changed.NO)
839             self.assertEqual(len(out_lines), 2)
840             self.assertEqual(len(err_lines), 2)
841             self.assertEqual(
842                 unstyle(str(report)),
843                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
844                 " reformat.",
845             )
846             self.assertEqual(report.return_code, 123)
847             report.check = True
848             self.assertEqual(
849                 unstyle(str(report)),
850                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
851                 " would fail to reformat.",
852             )
853             report.check = False
854             report.diff = True
855             self.assertEqual(
856                 unstyle(str(report)),
857                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
858                 " would fail to reformat.",
859             )
860
861     def test_lib2to3_parse(self) -> None:
862         with self.assertRaises(black.InvalidInput):
863             black.lib2to3_parse("invalid syntax")
864
865         straddling = "x + y"
866         black.lib2to3_parse(straddling)
867         black.lib2to3_parse(straddling, {TargetVersion.PY27})
868         black.lib2to3_parse(straddling, {TargetVersion.PY36})
869         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
870
871         py2_only = "print x"
872         black.lib2to3_parse(py2_only)
873         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
874         with self.assertRaises(black.InvalidInput):
875             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
876         with self.assertRaises(black.InvalidInput):
877             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
878
879         py3_only = "exec(x, end=y)"
880         black.lib2to3_parse(py3_only)
881         with self.assertRaises(black.InvalidInput):
882             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
883         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
884         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
885
886     def test_get_features_used_decorator(self) -> None:
887         # Test the feature detection of new decorator syntax
888         # since this makes some test cases of test_get_features_used()
889         # fails if it fails, this is tested first so that a useful case
890         # is identified
891         simples, relaxed = read_data("decorators")
892         # skip explanation comments at the top of the file
893         for simple_test in simples.split("##")[1:]:
894             node = black.lib2to3_parse(simple_test)
895             decorator = str(node.children[0].children[0]).strip()
896             self.assertNotIn(
897                 Feature.RELAXED_DECORATORS,
898                 black.get_features_used(node),
899                 msg=(
900                     f"decorator '{decorator}' follows python<=3.8 syntax"
901                     "but is detected as 3.9+"
902                     # f"The full node is\n{node!r}"
903                 ),
904             )
905         # skip the '# output' comment at the top of the output part
906         for relaxed_test in relaxed.split("##")[1:]:
907             node = black.lib2to3_parse(relaxed_test)
908             decorator = str(node.children[0].children[0]).strip()
909             self.assertIn(
910                 Feature.RELAXED_DECORATORS,
911                 black.get_features_used(node),
912                 msg=(
913                     f"decorator '{decorator}' uses python3.9+ syntax"
914                     "but is detected as python<=3.8"
915                     # f"The full node is\n{node!r}"
916                 ),
917             )
918
919     def test_get_features_used(self) -> None:
920         node = black.lib2to3_parse("def f(*, arg): ...\n")
921         self.assertEqual(black.get_features_used(node), set())
922         node = black.lib2to3_parse("def f(*, arg,): ...\n")
923         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
924         node = black.lib2to3_parse("f(*arg,)\n")
925         self.assertEqual(
926             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
927         )
928         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
929         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
930         node = black.lib2to3_parse("123_456\n")
931         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
932         node = black.lib2to3_parse("123456\n")
933         self.assertEqual(black.get_features_used(node), set())
934         source, expected = read_data("function")
935         node = black.lib2to3_parse(source)
936         expected_features = {
937             Feature.TRAILING_COMMA_IN_CALL,
938             Feature.TRAILING_COMMA_IN_DEF,
939             Feature.F_STRINGS,
940         }
941         self.assertEqual(black.get_features_used(node), expected_features)
942         node = black.lib2to3_parse(expected)
943         self.assertEqual(black.get_features_used(node), expected_features)
944         source, expected = read_data("expression")
945         node = black.lib2to3_parse(source)
946         self.assertEqual(black.get_features_used(node), set())
947         node = black.lib2to3_parse(expected)
948         self.assertEqual(black.get_features_used(node), set())
949
950     def test_get_future_imports(self) -> None:
951         node = black.lib2to3_parse("\n")
952         self.assertEqual(set(), black.get_future_imports(node))
953         node = black.lib2to3_parse("from __future__ import black\n")
954         self.assertEqual({"black"}, black.get_future_imports(node))
955         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
956         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
957         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
958         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
959         node = black.lib2to3_parse(
960             "from __future__ import multiple\nfrom __future__ import imports\n"
961         )
962         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
963         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
964         self.assertEqual({"black"}, black.get_future_imports(node))
965         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
966         self.assertEqual({"black"}, black.get_future_imports(node))
967         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
968         self.assertEqual(set(), black.get_future_imports(node))
969         node = black.lib2to3_parse("from some.module import black\n")
970         self.assertEqual(set(), black.get_future_imports(node))
971         node = black.lib2to3_parse(
972             "from __future__ import unicode_literals as _unicode_literals"
973         )
974         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
975         node = black.lib2to3_parse(
976             "from __future__ import unicode_literals as _lol, print"
977         )
978         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
979
980     def test_debug_visitor(self) -> None:
981         source, _ = read_data("debug_visitor.py")
982         expected, _ = read_data("debug_visitor.out")
983         out_lines = []
984         err_lines = []
985
986         def out(msg: str, **kwargs: Any) -> None:
987             out_lines.append(msg)
988
989         def err(msg: str, **kwargs: Any) -> None:
990             err_lines.append(msg)
991
992         with patch("black.debug.out", out):
993             DebugVisitor.show(source)
994         actual = "\n".join(out_lines) + "\n"
995         log_name = ""
996         if expected != actual:
997             log_name = black.dump_to_file(*out_lines)
998         self.assertEqual(
999             expected,
1000             actual,
1001             f"AST print out is different. Actual version dumped to {log_name}",
1002         )
1003
1004     def test_format_file_contents(self) -> None:
1005         empty = ""
1006         mode = DEFAULT_MODE
1007         with self.assertRaises(black.NothingChanged):
1008             black.format_file_contents(empty, mode=mode, fast=False)
1009         just_nl = "\n"
1010         with self.assertRaises(black.NothingChanged):
1011             black.format_file_contents(just_nl, mode=mode, fast=False)
1012         same = "j = [1, 2, 3]\n"
1013         with self.assertRaises(black.NothingChanged):
1014             black.format_file_contents(same, mode=mode, fast=False)
1015         different = "j = [1,2,3]"
1016         expected = same
1017         actual = black.format_file_contents(different, mode=mode, fast=False)
1018         self.assertEqual(expected, actual)
1019         invalid = "return if you can"
1020         with self.assertRaises(black.InvalidInput) as e:
1021             black.format_file_contents(invalid, mode=mode, fast=False)
1022         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1023
1024     def test_endmarker(self) -> None:
1025         n = black.lib2to3_parse("\n")
1026         self.assertEqual(n.type, black.syms.file_input)
1027         self.assertEqual(len(n.children), 1)
1028         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1029
1030     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1031     def test_assertFormatEqual(self) -> None:
1032         out_lines = []
1033         err_lines = []
1034
1035         def out(msg: str, **kwargs: Any) -> None:
1036             out_lines.append(msg)
1037
1038         def err(msg: str, **kwargs: Any) -> None:
1039             err_lines.append(msg)
1040
1041         with patch("black.output._out", out), patch("black.output._err", err):
1042             with self.assertRaises(AssertionError):
1043                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1044
1045         out_str = "".join(out_lines)
1046         self.assertTrue("Expected tree:" in out_str)
1047         self.assertTrue("Actual tree:" in out_str)
1048         self.assertEqual("".join(err_lines), "")
1049
1050     def test_cache_broken_file(self) -> None:
1051         mode = DEFAULT_MODE
1052         with cache_dir() as workspace:
1053             cache_file = get_cache_file(mode)
1054             with cache_file.open("w") as fobj:
1055                 fobj.write("this is not a pickle")
1056             self.assertEqual(black.read_cache(mode), {})
1057             src = (workspace / "test.py").resolve()
1058             with src.open("w") as fobj:
1059                 fobj.write("print('hello')")
1060             self.invokeBlack([str(src)])
1061             cache = black.read_cache(mode)
1062             self.assertIn(str(src), cache)
1063
1064     def test_cache_single_file_already_cached(self) -> None:
1065         mode = DEFAULT_MODE
1066         with cache_dir() as workspace:
1067             src = (workspace / "test.py").resolve()
1068             with src.open("w") as fobj:
1069                 fobj.write("print('hello')")
1070             black.write_cache({}, [src], mode)
1071             self.invokeBlack([str(src)])
1072             with src.open("r") as fobj:
1073                 self.assertEqual(fobj.read(), "print('hello')")
1074
1075     @event_loop()
1076     def test_cache_multiple_files(self) -> None:
1077         mode = DEFAULT_MODE
1078         with cache_dir() as workspace, patch(
1079             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1080         ):
1081             one = (workspace / "one.py").resolve()
1082             with one.open("w") as fobj:
1083                 fobj.write("print('hello')")
1084             two = (workspace / "two.py").resolve()
1085             with two.open("w") as fobj:
1086                 fobj.write("print('hello')")
1087             black.write_cache({}, [one], mode)
1088             self.invokeBlack([str(workspace)])
1089             with one.open("r") as fobj:
1090                 self.assertEqual(fobj.read(), "print('hello')")
1091             with two.open("r") as fobj:
1092                 self.assertEqual(fobj.read(), 'print("hello")\n')
1093             cache = black.read_cache(mode)
1094             self.assertIn(str(one), cache)
1095             self.assertIn(str(two), cache)
1096
1097     def test_no_cache_when_writeback_diff(self) -> None:
1098         mode = DEFAULT_MODE
1099         with cache_dir() as workspace:
1100             src = (workspace / "test.py").resolve()
1101             with src.open("w") as fobj:
1102                 fobj.write("print('hello')")
1103             with patch("black.read_cache") as read_cache, patch(
1104                 "black.write_cache"
1105             ) as write_cache:
1106                 self.invokeBlack([str(src), "--diff"])
1107                 cache_file = get_cache_file(mode)
1108                 self.assertFalse(cache_file.exists())
1109                 write_cache.assert_not_called()
1110                 read_cache.assert_not_called()
1111
1112     def test_no_cache_when_writeback_color_diff(self) -> None:
1113         mode = DEFAULT_MODE
1114         with cache_dir() as workspace:
1115             src = (workspace / "test.py").resolve()
1116             with src.open("w") as fobj:
1117                 fobj.write("print('hello')")
1118             with patch("black.read_cache") as read_cache, patch(
1119                 "black.write_cache"
1120             ) as write_cache:
1121                 self.invokeBlack([str(src), "--diff", "--color"])
1122                 cache_file = get_cache_file(mode)
1123                 self.assertFalse(cache_file.exists())
1124                 write_cache.assert_not_called()
1125                 read_cache.assert_not_called()
1126
1127     @event_loop()
1128     def test_output_locking_when_writeback_diff(self) -> None:
1129         with cache_dir() as workspace:
1130             for tag in range(0, 4):
1131                 src = (workspace / f"test{tag}.py").resolve()
1132                 with src.open("w") as fobj:
1133                     fobj.write("print('hello')")
1134             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1135                 self.invokeBlack(["--diff", str(workspace)], exit_code=0)
1136                 # this isn't quite doing what we want, but if it _isn't_
1137                 # called then we cannot be using the lock it provides
1138                 mgr.assert_called()
1139
1140     @event_loop()
1141     def test_output_locking_when_writeback_color_diff(self) -> None:
1142         with cache_dir() as workspace:
1143             for tag in range(0, 4):
1144                 src = (workspace / f"test{tag}.py").resolve()
1145                 with src.open("w") as fobj:
1146                     fobj.write("print('hello')")
1147             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1148                 self.invokeBlack(["--diff", "--color", str(workspace)], exit_code=0)
1149                 # this isn't quite doing what we want, but if it _isn't_
1150                 # called then we cannot be using the lock it provides
1151                 mgr.assert_called()
1152
1153     def test_no_cache_when_stdin(self) -> None:
1154         mode = DEFAULT_MODE
1155         with cache_dir():
1156             result = CliRunner().invoke(
1157                 black.main, ["-"], input=BytesIO(b"print('hello')")
1158             )
1159             self.assertEqual(result.exit_code, 0)
1160             cache_file = get_cache_file(mode)
1161             self.assertFalse(cache_file.exists())
1162
1163     def test_read_cache_no_cachefile(self) -> None:
1164         mode = DEFAULT_MODE
1165         with cache_dir():
1166             self.assertEqual(black.read_cache(mode), {})
1167
1168     def test_write_cache_read_cache(self) -> None:
1169         mode = DEFAULT_MODE
1170         with cache_dir() as workspace:
1171             src = (workspace / "test.py").resolve()
1172             src.touch()
1173             black.write_cache({}, [src], mode)
1174             cache = black.read_cache(mode)
1175             self.assertIn(str(src), cache)
1176             self.assertEqual(cache[str(src)], black.get_cache_info(src))
1177
1178     def test_filter_cached(self) -> None:
1179         with TemporaryDirectory() as workspace:
1180             path = Path(workspace)
1181             uncached = (path / "uncached").resolve()
1182             cached = (path / "cached").resolve()
1183             cached_but_changed = (path / "changed").resolve()
1184             uncached.touch()
1185             cached.touch()
1186             cached_but_changed.touch()
1187             cache = {
1188                 str(cached): black.get_cache_info(cached),
1189                 str(cached_but_changed): (0.0, 0),
1190             }
1191             todo, done = black.filter_cached(
1192                 cache, {uncached, cached, cached_but_changed}
1193             )
1194             self.assertEqual(todo, {uncached, cached_but_changed})
1195             self.assertEqual(done, {cached})
1196
1197     def test_write_cache_creates_directory_if_needed(self) -> None:
1198         mode = DEFAULT_MODE
1199         with cache_dir(exists=False) as workspace:
1200             self.assertFalse(workspace.exists())
1201             black.write_cache({}, [], mode)
1202             self.assertTrue(workspace.exists())
1203
1204     @event_loop()
1205     def test_failed_formatting_does_not_get_cached(self) -> None:
1206         mode = DEFAULT_MODE
1207         with cache_dir() as workspace, patch(
1208             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1209         ):
1210             failing = (workspace / "failing.py").resolve()
1211             with failing.open("w") as fobj:
1212                 fobj.write("not actually python")
1213             clean = (workspace / "clean.py").resolve()
1214             with clean.open("w") as fobj:
1215                 fobj.write('print("hello")\n')
1216             self.invokeBlack([str(workspace)], exit_code=123)
1217             cache = black.read_cache(mode)
1218             self.assertNotIn(str(failing), cache)
1219             self.assertIn(str(clean), cache)
1220
1221     def test_write_cache_write_fail(self) -> None:
1222         mode = DEFAULT_MODE
1223         with cache_dir(), patch.object(Path, "open") as mock:
1224             mock.side_effect = OSError
1225             black.write_cache({}, [], mode)
1226
1227     @event_loop()
1228     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
1229     def test_works_in_mono_process_only_environment(self) -> None:
1230         with cache_dir() as workspace:
1231             for f in [
1232                 (workspace / "one.py").resolve(),
1233                 (workspace / "two.py").resolve(),
1234             ]:
1235                 f.write_text('print("hello")\n')
1236             self.invokeBlack([str(workspace)])
1237
1238     @event_loop()
1239     def test_check_diff_use_together(self) -> None:
1240         with cache_dir():
1241             # Files which will be reformatted.
1242             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1243             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1244             # Files which will not be reformatted.
1245             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1246             self.invokeBlack([str(src2), "--diff", "--check"])
1247             # Multi file command.
1248             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1249
1250     def test_no_files(self) -> None:
1251         with cache_dir():
1252             # Without an argument, black exits with error code 0.
1253             self.invokeBlack([])
1254
1255     def test_broken_symlink(self) -> None:
1256         with cache_dir() as workspace:
1257             symlink = workspace / "broken_link.py"
1258             try:
1259                 symlink.symlink_to("nonexistent.py")
1260             except OSError as e:
1261                 self.skipTest(f"Can't create symlinks: {e}")
1262             self.invokeBlack([str(workspace.resolve())])
1263
1264     def test_read_cache_line_lengths(self) -> None:
1265         mode = DEFAULT_MODE
1266         short_mode = replace(DEFAULT_MODE, line_length=1)
1267         with cache_dir() as workspace:
1268             path = (workspace / "file.py").resolve()
1269             path.touch()
1270             black.write_cache({}, [path], mode)
1271             one = black.read_cache(mode)
1272             self.assertIn(str(path), one)
1273             two = black.read_cache(short_mode)
1274             self.assertNotIn(str(path), two)
1275
1276     def test_single_file_force_pyi(self) -> None:
1277         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1278         contents, expected = read_data("force_pyi")
1279         with cache_dir() as workspace:
1280             path = (workspace / "file.py").resolve()
1281             with open(path, "w") as fh:
1282                 fh.write(contents)
1283             self.invokeBlack([str(path), "--pyi"])
1284             with open(path, "r") as fh:
1285                 actual = fh.read()
1286             # verify cache with --pyi is separate
1287             pyi_cache = black.read_cache(pyi_mode)
1288             self.assertIn(str(path), pyi_cache)
1289             normal_cache = black.read_cache(DEFAULT_MODE)
1290             self.assertNotIn(str(path), normal_cache)
1291         self.assertFormatEqual(expected, actual)
1292         black.assert_equivalent(contents, actual)
1293         black.assert_stable(contents, actual, pyi_mode)
1294
1295     @event_loop()
1296     def test_multi_file_force_pyi(self) -> None:
1297         reg_mode = DEFAULT_MODE
1298         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1299         contents, expected = read_data("force_pyi")
1300         with cache_dir() as workspace:
1301             paths = [
1302                 (workspace / "file1.py").resolve(),
1303                 (workspace / "file2.py").resolve(),
1304             ]
1305             for path in paths:
1306                 with open(path, "w") as fh:
1307                     fh.write(contents)
1308             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1309             for path in paths:
1310                 with open(path, "r") as fh:
1311                     actual = fh.read()
1312                 self.assertEqual(actual, expected)
1313             # verify cache with --pyi is separate
1314             pyi_cache = black.read_cache(pyi_mode)
1315             normal_cache = black.read_cache(reg_mode)
1316             for path in paths:
1317                 self.assertIn(str(path), pyi_cache)
1318                 self.assertNotIn(str(path), normal_cache)
1319
1320     def test_pipe_force_pyi(self) -> None:
1321         source, expected = read_data("force_pyi")
1322         result = CliRunner().invoke(
1323             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1324         )
1325         self.assertEqual(result.exit_code, 0)
1326         actual = result.output
1327         self.assertFormatEqual(actual, expected)
1328
1329     def test_single_file_force_py36(self) -> None:
1330         reg_mode = DEFAULT_MODE
1331         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1332         source, expected = read_data("force_py36")
1333         with cache_dir() as workspace:
1334             path = (workspace / "file.py").resolve()
1335             with open(path, "w") as fh:
1336                 fh.write(source)
1337             self.invokeBlack([str(path), *PY36_ARGS])
1338             with open(path, "r") as fh:
1339                 actual = fh.read()
1340             # verify cache with --target-version is separate
1341             py36_cache = black.read_cache(py36_mode)
1342             self.assertIn(str(path), py36_cache)
1343             normal_cache = black.read_cache(reg_mode)
1344             self.assertNotIn(str(path), normal_cache)
1345         self.assertEqual(actual, expected)
1346
1347     @event_loop()
1348     def test_multi_file_force_py36(self) -> None:
1349         reg_mode = DEFAULT_MODE
1350         py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
1351         source, expected = read_data("force_py36")
1352         with cache_dir() as workspace:
1353             paths = [
1354                 (workspace / "file1.py").resolve(),
1355                 (workspace / "file2.py").resolve(),
1356             ]
1357             for path in paths:
1358                 with open(path, "w") as fh:
1359                     fh.write(source)
1360             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1361             for path in paths:
1362                 with open(path, "r") as fh:
1363                     actual = fh.read()
1364                 self.assertEqual(actual, expected)
1365             # verify cache with --target-version is separate
1366             pyi_cache = black.read_cache(py36_mode)
1367             normal_cache = black.read_cache(reg_mode)
1368             for path in paths:
1369                 self.assertIn(str(path), pyi_cache)
1370                 self.assertNotIn(str(path), normal_cache)
1371
1372     def test_pipe_force_py36(self) -> None:
1373         source, expected = read_data("force_py36")
1374         result = CliRunner().invoke(
1375             black.main,
1376             ["-", "-q", "--target-version=py36"],
1377             input=BytesIO(source.encode("utf8")),
1378         )
1379         self.assertEqual(result.exit_code, 0)
1380         actual = result.output
1381         self.assertFormatEqual(actual, expected)
1382
1383     def test_include_exclude(self) -> None:
1384         path = THIS_DIR / "data" / "include_exclude_tests"
1385         include = re.compile(r"\.pyi?$")
1386         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1387         report = black.Report()
1388         gitignore = PathSpec.from_lines("gitwildmatch", [])
1389         sources: List[Path] = []
1390         expected = [
1391             Path(path / "b/dont_exclude/a.py"),
1392             Path(path / "b/dont_exclude/a.pyi"),
1393         ]
1394         this_abs = THIS_DIR.resolve()
1395         sources.extend(
1396             black.gen_python_files(
1397                 path.iterdir(),
1398                 this_abs,
1399                 include,
1400                 exclude,
1401                 None,
1402                 None,
1403                 report,
1404                 gitignore,
1405             )
1406         )
1407         self.assertEqual(sorted(expected), sorted(sources))
1408
1409     def test_gitignore_used_as_default(self) -> None:
1410         path = Path(THIS_DIR / "data" / "include_exclude_tests")
1411         include = re.compile(r"\.pyi?$")
1412         extend_exclude = re.compile(r"/exclude/")
1413         src = str(path / "b/")
1414         report = black.Report()
1415         expected: List[Path] = [
1416             path / "b/.definitely_exclude/a.py",
1417             path / "b/.definitely_exclude/a.pyi",
1418         ]
1419         sources = list(
1420             black.get_sources(
1421                 ctx=FakeContext(),
1422                 src=(src,),
1423                 quiet=True,
1424                 verbose=False,
1425                 include=include,
1426                 exclude=None,
1427                 extend_exclude=extend_exclude,
1428                 force_exclude=None,
1429                 report=report,
1430                 stdin_filename=None,
1431             )
1432         )
1433         self.assertEqual(sorted(expected), sorted(sources))
1434
1435     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1436     def test_exclude_for_issue_1572(self) -> None:
1437         # Exclude shouldn't touch files that were explicitly given to Black through the
1438         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
1439         # https://github.com/psf/black/issues/1572
1440         path = THIS_DIR / "data" / "include_exclude_tests"
1441         include = ""
1442         exclude = r"/exclude/|a\.py"
1443         src = str(path / "b/exclude/a.py")
1444         report = black.Report()
1445         expected = [Path(path / "b/exclude/a.py")]
1446         sources = list(
1447             black.get_sources(
1448                 ctx=FakeContext(),
1449                 src=(src,),
1450                 quiet=True,
1451                 verbose=False,
1452                 include=re.compile(include),
1453                 exclude=re.compile(exclude),
1454                 extend_exclude=None,
1455                 force_exclude=None,
1456                 report=report,
1457                 stdin_filename=None,
1458             )
1459         )
1460         self.assertEqual(sorted(expected), sorted(sources))
1461
1462     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1463     def test_get_sources_with_stdin(self) -> None:
1464         include = ""
1465         exclude = r"/exclude/|a\.py"
1466         src = "-"
1467         report = black.Report()
1468         expected = [Path("-")]
1469         sources = list(
1470             black.get_sources(
1471                 ctx=FakeContext(),
1472                 src=(src,),
1473                 quiet=True,
1474                 verbose=False,
1475                 include=re.compile(include),
1476                 exclude=re.compile(exclude),
1477                 extend_exclude=None,
1478                 force_exclude=None,
1479                 report=report,
1480                 stdin_filename=None,
1481             )
1482         )
1483         self.assertEqual(sorted(expected), sorted(sources))
1484
1485     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1486     def test_get_sources_with_stdin_filename(self) -> None:
1487         include = ""
1488         exclude = r"/exclude/|a\.py"
1489         src = "-"
1490         report = black.Report()
1491         stdin_filename = str(THIS_DIR / "data/collections.py")
1492         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1493         sources = list(
1494             black.get_sources(
1495                 ctx=FakeContext(),
1496                 src=(src,),
1497                 quiet=True,
1498                 verbose=False,
1499                 include=re.compile(include),
1500                 exclude=re.compile(exclude),
1501                 extend_exclude=None,
1502                 force_exclude=None,
1503                 report=report,
1504                 stdin_filename=stdin_filename,
1505             )
1506         )
1507         self.assertEqual(sorted(expected), sorted(sources))
1508
1509     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1510     def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
1511         # Exclude shouldn't exclude stdin_filename since it is mimicking the
1512         # file being passed directly. This is the same as
1513         # test_exclude_for_issue_1572
1514         path = THIS_DIR / "data" / "include_exclude_tests"
1515         include = ""
1516         exclude = r"/exclude/|a\.py"
1517         src = "-"
1518         report = black.Report()
1519         stdin_filename = str(path / "b/exclude/a.py")
1520         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1521         sources = list(
1522             black.get_sources(
1523                 ctx=FakeContext(),
1524                 src=(src,),
1525                 quiet=True,
1526                 verbose=False,
1527                 include=re.compile(include),
1528                 exclude=re.compile(exclude),
1529                 extend_exclude=None,
1530                 force_exclude=None,
1531                 report=report,
1532                 stdin_filename=stdin_filename,
1533             )
1534         )
1535         self.assertEqual(sorted(expected), sorted(sources))
1536
1537     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1538     def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
1539         # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
1540         # file being passed directly. This is the same as
1541         # test_exclude_for_issue_1572
1542         path = THIS_DIR / "data" / "include_exclude_tests"
1543         include = ""
1544         extend_exclude = r"/exclude/|a\.py"
1545         src = "-"
1546         report = black.Report()
1547         stdin_filename = str(path / "b/exclude/a.py")
1548         expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
1549         sources = list(
1550             black.get_sources(
1551                 ctx=FakeContext(),
1552                 src=(src,),
1553                 quiet=True,
1554                 verbose=False,
1555                 include=re.compile(include),
1556                 exclude=re.compile(""),
1557                 extend_exclude=re.compile(extend_exclude),
1558                 force_exclude=None,
1559                 report=report,
1560                 stdin_filename=stdin_filename,
1561             )
1562         )
1563         self.assertEqual(sorted(expected), sorted(sources))
1564
1565     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1566     def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
1567         # Force exclude should exclude the file when passing it through
1568         # stdin_filename
1569         path = THIS_DIR / "data" / "include_exclude_tests"
1570         include = ""
1571         force_exclude = r"/exclude/|a\.py"
1572         src = "-"
1573         report = black.Report()
1574         stdin_filename = str(path / "b/exclude/a.py")
1575         sources = list(
1576             black.get_sources(
1577                 ctx=FakeContext(),
1578                 src=(src,),
1579                 quiet=True,
1580                 verbose=False,
1581                 include=re.compile(include),
1582                 exclude=re.compile(""),
1583                 extend_exclude=None,
1584                 force_exclude=re.compile(force_exclude),
1585                 report=report,
1586                 stdin_filename=stdin_filename,
1587             )
1588         )
1589         self.assertEqual([], sorted(sources))
1590
1591     def test_reformat_one_with_stdin(self) -> None:
1592         with patch(
1593             "black.format_stdin_to_stdout",
1594             return_value=lambda *args, **kwargs: black.Changed.YES,
1595         ) as fsts:
1596             report = MagicMock()
1597             path = Path("-")
1598             black.reformat_one(
1599                 path,
1600                 fast=True,
1601                 write_back=black.WriteBack.YES,
1602                 mode=DEFAULT_MODE,
1603                 report=report,
1604             )
1605             fsts.assert_called_once()
1606             report.done.assert_called_with(path, black.Changed.YES)
1607
1608     def test_reformat_one_with_stdin_filename(self) -> None:
1609         with patch(
1610             "black.format_stdin_to_stdout",
1611             return_value=lambda *args, **kwargs: black.Changed.YES,
1612         ) as fsts:
1613             report = MagicMock()
1614             p = "foo.py"
1615             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1616             expected = Path(p)
1617             black.reformat_one(
1618                 path,
1619                 fast=True,
1620                 write_back=black.WriteBack.YES,
1621                 mode=DEFAULT_MODE,
1622                 report=report,
1623             )
1624             fsts.assert_called_once_with(
1625                 fast=True, write_back=black.WriteBack.YES, mode=DEFAULT_MODE
1626             )
1627             # __BLACK_STDIN_FILENAME__ should have been stripped
1628             report.done.assert_called_with(expected, black.Changed.YES)
1629
1630     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
1631         with patch(
1632             "black.format_stdin_to_stdout",
1633             return_value=lambda *args, **kwargs: black.Changed.YES,
1634         ) as fsts:
1635             report = MagicMock()
1636             p = "foo.pyi"
1637             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1638             expected = Path(p)
1639             black.reformat_one(
1640                 path,
1641                 fast=True,
1642                 write_back=black.WriteBack.YES,
1643                 mode=DEFAULT_MODE,
1644                 report=report,
1645             )
1646             fsts.assert_called_once_with(
1647                 fast=True,
1648                 write_back=black.WriteBack.YES,
1649                 mode=replace(DEFAULT_MODE, is_pyi=True),
1650             )
1651             # __BLACK_STDIN_FILENAME__ should have been stripped
1652             report.done.assert_called_with(expected, black.Changed.YES)
1653
1654     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
1655         with patch(
1656             "black.format_stdin_to_stdout",
1657             return_value=lambda *args, **kwargs: black.Changed.YES,
1658         ) as fsts:
1659             report = MagicMock()
1660             # Even with an existing file, since we are forcing stdin, black
1661             # should output to stdout and not modify the file inplace
1662             p = Path(str(THIS_DIR / "data/collections.py"))
1663             # Make sure is_file actually returns True
1664             self.assertTrue(p.is_file())
1665             path = Path(f"__BLACK_STDIN_FILENAME__{p}")
1666             expected = Path(p)
1667             black.reformat_one(
1668                 path,
1669                 fast=True,
1670                 write_back=black.WriteBack.YES,
1671                 mode=DEFAULT_MODE,
1672                 report=report,
1673             )
1674             fsts.assert_called_once()
1675             # __BLACK_STDIN_FILENAME__ should have been stripped
1676             report.done.assert_called_with(expected, black.Changed.YES)
1677
1678     def test_gitignore_exclude(self) -> None:
1679         path = THIS_DIR / "data" / "include_exclude_tests"
1680         include = re.compile(r"\.pyi?$")
1681         exclude = re.compile(r"")
1682         report = black.Report()
1683         gitignore = PathSpec.from_lines(
1684             "gitwildmatch", ["exclude/", ".definitely_exclude"]
1685         )
1686         sources: List[Path] = []
1687         expected = [
1688             Path(path / "b/dont_exclude/a.py"),
1689             Path(path / "b/dont_exclude/a.pyi"),
1690         ]
1691         this_abs = THIS_DIR.resolve()
1692         sources.extend(
1693             black.gen_python_files(
1694                 path.iterdir(),
1695                 this_abs,
1696                 include,
1697                 exclude,
1698                 None,
1699                 None,
1700                 report,
1701                 gitignore,
1702             )
1703         )
1704         self.assertEqual(sorted(expected), sorted(sources))
1705
1706     def test_nested_gitignore(self) -> None:
1707         path = Path(THIS_DIR / "data" / "nested_gitignore_tests")
1708         include = re.compile(r"\.pyi?$")
1709         exclude = re.compile(r"")
1710         root_gitignore = black.files.get_gitignore(path)
1711         report = black.Report()
1712         expected: List[Path] = [
1713             Path(path / "x.py"),
1714             Path(path / "root/b.py"),
1715             Path(path / "root/c.py"),
1716             Path(path / "root/child/c.py"),
1717         ]
1718         this_abs = THIS_DIR.resolve()
1719         sources = list(
1720             black.gen_python_files(
1721                 path.iterdir(),
1722                 this_abs,
1723                 include,
1724                 exclude,
1725                 None,
1726                 None,
1727                 report,
1728                 root_gitignore,
1729             )
1730         )
1731         self.assertEqual(sorted(expected), sorted(sources))
1732
1733     def test_empty_include(self) -> None:
1734         path = THIS_DIR / "data" / "include_exclude_tests"
1735         report = black.Report()
1736         gitignore = PathSpec.from_lines("gitwildmatch", [])
1737         empty = re.compile(r"")
1738         sources: List[Path] = []
1739         expected = [
1740             Path(path / "b/exclude/a.pie"),
1741             Path(path / "b/exclude/a.py"),
1742             Path(path / "b/exclude/a.pyi"),
1743             Path(path / "b/dont_exclude/a.pie"),
1744             Path(path / "b/dont_exclude/a.py"),
1745             Path(path / "b/dont_exclude/a.pyi"),
1746             Path(path / "b/.definitely_exclude/a.pie"),
1747             Path(path / "b/.definitely_exclude/a.py"),
1748             Path(path / "b/.definitely_exclude/a.pyi"),
1749             Path(path / ".gitignore"),
1750             Path(path / "pyproject.toml"),
1751         ]
1752         this_abs = THIS_DIR.resolve()
1753         sources.extend(
1754             black.gen_python_files(
1755                 path.iterdir(),
1756                 this_abs,
1757                 empty,
1758                 re.compile(black.DEFAULT_EXCLUDES),
1759                 None,
1760                 None,
1761                 report,
1762                 gitignore,
1763             )
1764         )
1765         self.assertEqual(sorted(expected), sorted(sources))
1766
1767     def test_extend_exclude(self) -> None:
1768         path = THIS_DIR / "data" / "include_exclude_tests"
1769         report = black.Report()
1770         gitignore = PathSpec.from_lines("gitwildmatch", [])
1771         sources: List[Path] = []
1772         expected = [
1773             Path(path / "b/exclude/a.py"),
1774             Path(path / "b/dont_exclude/a.py"),
1775         ]
1776         this_abs = THIS_DIR.resolve()
1777         sources.extend(
1778             black.gen_python_files(
1779                 path.iterdir(),
1780                 this_abs,
1781                 re.compile(black.DEFAULT_INCLUDES),
1782                 re.compile(r"\.pyi$"),
1783                 re.compile(r"\.definitely_exclude"),
1784                 None,
1785                 report,
1786                 gitignore,
1787             )
1788         )
1789         self.assertEqual(sorted(expected), sorted(sources))
1790
1791     def test_invalid_cli_regex(self) -> None:
1792         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
1793             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1794
1795     def test_preserves_line_endings(self) -> None:
1796         with TemporaryDirectory() as workspace:
1797             test_file = Path(workspace) / "test.py"
1798             for nl in ["\n", "\r\n"]:
1799                 contents = nl.join(["def f(  ):", "    pass"])
1800                 test_file.write_bytes(contents.encode())
1801                 ff(test_file, write_back=black.WriteBack.YES)
1802                 updated_contents: bytes = test_file.read_bytes()
1803                 self.assertIn(nl.encode(), updated_contents)
1804                 if nl == "\n":
1805                     self.assertNotIn(b"\r\n", updated_contents)
1806
1807     def test_preserves_line_endings_via_stdin(self) -> None:
1808         for nl in ["\n", "\r\n"]:
1809             contents = nl.join(["def f(  ):", "    pass"])
1810             runner = BlackRunner()
1811             result = runner.invoke(
1812                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1813             )
1814             self.assertEqual(result.exit_code, 0)
1815             output = result.stdout_bytes
1816             self.assertIn(nl.encode("utf8"), output)
1817             if nl == "\n":
1818                 self.assertNotIn(b"\r\n", output)
1819
1820     def test_assert_equivalent_different_asts(self) -> None:
1821         with self.assertRaises(AssertionError):
1822             black.assert_equivalent("{}", "None")
1823
1824     def test_symlink_out_of_root_directory(self) -> None:
1825         path = MagicMock()
1826         root = THIS_DIR.resolve()
1827         child = MagicMock()
1828         include = re.compile(black.DEFAULT_INCLUDES)
1829         exclude = re.compile(black.DEFAULT_EXCLUDES)
1830         report = black.Report()
1831         gitignore = PathSpec.from_lines("gitwildmatch", [])
1832         # `child` should behave like a symlink which resolved path is clearly
1833         # outside of the `root` directory.
1834         path.iterdir.return_value = [child]
1835         child.resolve.return_value = Path("/a/b/c")
1836         child.as_posix.return_value = "/a/b/c"
1837         child.is_symlink.return_value = True
1838         try:
1839             list(
1840                 black.gen_python_files(
1841                     path.iterdir(),
1842                     root,
1843                     include,
1844                     exclude,
1845                     None,
1846                     None,
1847                     report,
1848                     gitignore,
1849                 )
1850             )
1851         except ValueError as ve:
1852             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1853         path.iterdir.assert_called_once()
1854         child.resolve.assert_called_once()
1855         child.is_symlink.assert_called_once()
1856         # `child` should behave like a strange file which resolved path is clearly
1857         # outside of the `root` directory.
1858         child.is_symlink.return_value = False
1859         with self.assertRaises(ValueError):
1860             list(
1861                 black.gen_python_files(
1862                     path.iterdir(),
1863                     root,
1864                     include,
1865                     exclude,
1866                     None,
1867                     None,
1868                     report,
1869                     gitignore,
1870                 )
1871             )
1872         path.iterdir.assert_called()
1873         self.assertEqual(path.iterdir.call_count, 2)
1874         child.resolve.assert_called()
1875         self.assertEqual(child.resolve.call_count, 2)
1876         child.is_symlink.assert_called()
1877         self.assertEqual(child.is_symlink.call_count, 2)
1878
1879     def test_shhh_click(self) -> None:
1880         try:
1881             from click import _unicodefun  # type: ignore
1882         except ModuleNotFoundError:
1883             self.skipTest("Incompatible Click version")
1884         if not hasattr(_unicodefun, "_verify_python3_env"):
1885             self.skipTest("Incompatible Click version")
1886         # First, let's see if Click is crashing with a preferred ASCII charset.
1887         with patch("locale.getpreferredencoding") as gpe:
1888             gpe.return_value = "ASCII"
1889             with self.assertRaises(RuntimeError):
1890                 _unicodefun._verify_python3_env()
1891         # Now, let's silence Click...
1892         black.patch_click()
1893         # ...and confirm it's silent.
1894         with patch("locale.getpreferredencoding") as gpe:
1895             gpe.return_value = "ASCII"
1896             try:
1897                 _unicodefun._verify_python3_env()
1898             except RuntimeError as re:
1899                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1900
1901     def test_root_logger_not_used_directly(self) -> None:
1902         def fail(*args: Any, **kwargs: Any) -> None:
1903             self.fail("Record created with root logger")
1904
1905         with patch.multiple(
1906             logging.root,
1907             debug=fail,
1908             info=fail,
1909             warning=fail,
1910             error=fail,
1911             critical=fail,
1912             log=fail,
1913         ):
1914             ff(THIS_DIR / "util.py")
1915
1916     def test_invalid_config_return_code(self) -> None:
1917         tmp_file = Path(black.dump_to_file())
1918         try:
1919             tmp_config = Path(black.dump_to_file())
1920             tmp_config.unlink()
1921             args = ["--config", str(tmp_config), str(tmp_file)]
1922             self.invokeBlack(args, exit_code=2, ignore_config=False)
1923         finally:
1924             tmp_file.unlink()
1925
1926     def test_parse_pyproject_toml(self) -> None:
1927         test_toml_file = THIS_DIR / "test.toml"
1928         config = black.parse_pyproject_toml(str(test_toml_file))
1929         self.assertEqual(config["verbose"], 1)
1930         self.assertEqual(config["check"], "no")
1931         self.assertEqual(config["diff"], "y")
1932         self.assertEqual(config["color"], True)
1933         self.assertEqual(config["line_length"], 79)
1934         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1935         self.assertEqual(config["exclude"], r"\.pyi?$")
1936         self.assertEqual(config["include"], r"\.py?$")
1937
1938     def test_read_pyproject_toml(self) -> None:
1939         test_toml_file = THIS_DIR / "test.toml"
1940         fake_ctx = FakeContext()
1941         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1942         config = fake_ctx.default_map
1943         self.assertEqual(config["verbose"], "1")
1944         self.assertEqual(config["check"], "no")
1945         self.assertEqual(config["diff"], "y")
1946         self.assertEqual(config["color"], "True")
1947         self.assertEqual(config["line_length"], "79")
1948         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1949         self.assertEqual(config["exclude"], r"\.pyi?$")
1950         self.assertEqual(config["include"], r"\.py?$")
1951
1952     def test_find_project_root(self) -> None:
1953         with TemporaryDirectory() as workspace:
1954             root = Path(workspace)
1955             test_dir = root / "test"
1956             test_dir.mkdir()
1957
1958             src_dir = root / "src"
1959             src_dir.mkdir()
1960
1961             root_pyproject = root / "pyproject.toml"
1962             root_pyproject.touch()
1963             src_pyproject = src_dir / "pyproject.toml"
1964             src_pyproject.touch()
1965             src_python = src_dir / "foo.py"
1966             src_python.touch()
1967
1968             self.assertEqual(
1969                 black.find_project_root((src_dir, test_dir)), root.resolve()
1970             )
1971             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
1972             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
1973
1974     @patch(
1975         "black.files.find_user_pyproject_toml",
1976         black.files.find_user_pyproject_toml.__wrapped__,
1977     )
1978     def test_find_user_pyproject_toml_linux(self) -> None:
1979         if system() == "Windows":
1980             return
1981
1982         # Test if XDG_CONFIG_HOME is checked
1983         with TemporaryDirectory() as workspace:
1984             tmp_user_config = Path(workspace) / "black"
1985             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
1986                 self.assertEqual(
1987                     black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
1988                 )
1989
1990         # Test fallback for XDG_CONFIG_HOME
1991         with patch.dict("os.environ"):
1992             os.environ.pop("XDG_CONFIG_HOME", None)
1993             fallback_user_config = Path("~/.config").expanduser() / "black"
1994             self.assertEqual(
1995                 black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
1996             )
1997
1998     def test_find_user_pyproject_toml_windows(self) -> None:
1999         if system() != "Windows":
2000             return
2001
2002         user_config_path = Path.home() / ".black"
2003         self.assertEqual(
2004             black.files.find_user_pyproject_toml(), user_config_path.resolve()
2005         )
2006
2007     def test_bpo_33660_workaround(self) -> None:
2008         if system() == "Windows":
2009             return
2010
2011         # https://bugs.python.org/issue33660
2012
2013         old_cwd = Path.cwd()
2014         try:
2015             root = Path("/")
2016             os.chdir(str(root))
2017             path = Path("workspace") / "project"
2018             report = black.Report(verbose=True)
2019             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
2020             self.assertEqual(normalized_path, "workspace/project")
2021         finally:
2022             os.chdir(str(old_cwd))
2023
2024     def test_newline_comment_interaction(self) -> None:
2025         source = "class A:\\\r\n# type: ignore\n pass\n"
2026         output = black.format_str(source, mode=DEFAULT_MODE)
2027         black.assert_stable(source, output, mode=DEFAULT_MODE)
2028
2029     def test_bpo_2142_workaround(self) -> None:
2030
2031         # https://bugs.python.org/issue2142
2032
2033         source, _ = read_data("missing_final_newline.py")
2034         # read_data adds a trailing newline
2035         source = source.rstrip()
2036         expected, _ = read_data("missing_final_newline.diff")
2037         tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False))
2038         diff_header = re.compile(
2039             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
2040             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
2041         )
2042         try:
2043             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
2044             self.assertEqual(result.exit_code, 0)
2045         finally:
2046             os.unlink(tmp_file)
2047         actual = result.output
2048         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
2049         self.assertEqual(actual, expected)
2050
2051     @pytest.mark.python2
2052     def test_docstring_reformat_for_py27(self) -> None:
2053         """
2054         Check that stripping trailing whitespace from Python 2 docstrings
2055         doesn't trigger a "not equivalent to source" error
2056         """
2057         source = (
2058             b'def foo():\r\n    """Testing\r\n    Testing """\r\n    print "Foo"\r\n'
2059         )
2060         expected = 'def foo():\n    """Testing\n    Testing"""\n    print "Foo"\n'
2061
2062         result = CliRunner().invoke(
2063             black.main,
2064             ["-", "-q", "--target-version=py27"],
2065             input=BytesIO(source),
2066         )
2067
2068         self.assertEqual(result.exit_code, 0)
2069         actual = result.output
2070         self.assertFormatEqual(actual, expected)
2071
2072
# Lines of black's own source module (line endings kept), read once at import
# time so `tracefunc` can show the signature line of each traced call.
with open(black.__file__, "r", encoding="utf-8") as _bf:
    black_source_lines = list(_bf)
2075
2076
def tracefunc(frame: types.FrameType, event: str, arg: Any) -> Callable:
    """Show function calls `from black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    Only "call" events whose code lives in ``black/__init__.py`` are printed,
    as ``<indent><lineno>:<stripped source line of the def>``. The function
    always returns itself so tracing continues.
    """
    if event != "call":
        return tracefunc

    # Filter on the file *before* indexing `black_source_lines`: the line
    # number of a frame from any other module is meaningless as an index into
    # black's source and may be out of range. Normalizing backslashes lets the
    # check also match Windows paths.
    filename = frame.f_code.co_filename
    if "black/__init__.py" not in filename.replace("\\", "/"):
        return tracefunc

    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1
    last_index = len(black_source_lines) - 1
    if not 0 <= func_sig_lineno <= last_index:
        # Some other black/__init__.py (or a stale copy); nothing to show.
        return tracefunc

    # Indentation by call depth. 19 is presumably the number of frames added
    # by the test harness -- TODO confirm; it only affects indentation.
    # Computed only for black frames since inspect.stack() is expensive.
    stack = (len(inspect.stack()) - 19) * 2
    funcname = black_source_lines[func_sig_lineno].strip()
    # If the reported line is a decorator, advance to the `def` line without
    # running past the end of the file.
    while funcname.startswith("@") and func_sig_lineno < last_index:
        func_sig_lineno += 1
        funcname = black_source_lines[func_sig_lineno].strip()
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc
2097
2098
if __name__ == "__main__":
    # Allow running this file directly; unittest collects the tests from the
    # module named "test_black".
    unittest.main(module="test_black")