git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

e928dc0498457259d88c00429b5b085cb1a718fc
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import multiprocessing
3 import asyncio
4 import logging
5 from concurrent.futures import ThreadPoolExecutor
6 from contextlib import contextmanager
7 from dataclasses import replace
8 from functools import partial
9 import inspect
10 from io import BytesIO, TextIOWrapper
11 import os
12 from pathlib import Path
13 from platform import system
14 import regex as re
15 import sys
16 from tempfile import TemporaryDirectory
17 import types
18 from typing import (
19     Any,
20     BinaryIO,
21     Callable,
22     Dict,
23     Generator,
24     List,
25     Tuple,
26     Iterator,
27     TypeVar,
28 )
29 import unittest
30 from unittest.mock import patch, MagicMock
31
32 import click
33 from click import unstyle
34 from click.testing import CliRunner
35
36 import black
37 from black import Feature, TargetVersion
38
39 try:
40     import blackd
41     from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop
42     from aiohttp import web
43 except ImportError:
44     has_blackd_deps = False
45 else:
46     has_blackd_deps = True
47
48 from pathspec import PathSpec
49
50 # Import other test classes
51 from .test_primer import PrimerCLITests  # noqa: F401
52
53
# Locations of this test module, its directory and the project checkout.
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
PROJECT_ROOT = THIS_DIR.parent

# Mode shared by most tests, plus shortcuts bound to it: ``fs`` formats a
# string, ``ff`` formats a file in place with ``fast=True``.
DEFAULT_MODE = black.FileMode(experimental_string_processing=True)
fs = partial(black.format_str, mode=DEFAULT_MODE)
ff = partial(black.format_file_in_place, mode=DEFAULT_MODE, fast=True)

# Placeholder substituted for the timestamped part of diff headers.
DETERMINISTIC_HEADER = "[Deterministic header]"
# Marker that ``read_data`` removes from fixture lines. It is built from two
# pieces on purpose: ``test_self`` reads this very file through ``read_data``,
# so the complete marker must never appear literally in this source.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"

# One ``--target-version`` flag per version in ``black.PY36_VERSIONS``.
PY36_ARGS = [f"--target-version={v.name.lower()}" for v in black.PY36_VERSIONS]

T = TypeVar("T")
R = TypeVar("R")
67
68
def dump_to_stderr(*output: str) -> str:
    """Return ``output`` joined by newlines, with a newline added on each side.

    Tests patch ``black.dump_to_file`` with this function so that diagnostics
    come back as a string; despite the name, nothing is written to stderr.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
71
72
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """Split a fixture file into its source part and expected-output part.

    A ``name`` without a ``.py``/``.pyi``/``.out``/``.diff`` suffix gets ``.py``
    appended. The file is looked up in the ``data`` directory next to this
    module when ``data`` is true, otherwise relative to ``PROJECT_ROOT``.
    ``EMPTY_LINE`` markers are removed from every line. Lines before a line
    reading ``# output`` form the source; lines after it form the expected
    output. Without such a line the whole file counts as already formatted.
    Both returned strings are stripped and end with exactly one newline.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    root = THIS_DIR / "data" if data else PROJECT_ROOT
    with open(root / name, "r", encoding="utf8") as fixture:
        raw_lines = fixture.readlines()

    sections: Tuple[List[str], List[str]] = ([], [])
    current = 0
    for raw in raw_lines:
        cleaned = raw.replace(EMPTY_LINE, "")
        if cleaned.rstrip() == "# output":
            current = 1
            continue
        sections[current].append(cleaned)

    source_lines, output_lines = sections
    if source_lines and not output_lines:
        output_lines = list(source_lines)
    return "".join(source_lines).strip() + "\n", "".join(output_lines).strip() + "\n"
94
95
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch ``black.CACHE_DIR`` to a throwaway directory for the block.

    With ``exists=False`` the patched path is a ``new`` subdirectory of the
    temporary directory, which is not created here. The patched path is
    yielded to the caller.
    """
    with TemporaryDirectory() as workspace:
        target = Path(workspace) if exists else Path(workspace) / "new"
        with patch("black.CACHE_DIR", target):
            yield target
104
105
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a fresh event loop as the current loop for the block.

    The loop comes from the active event loop policy and is closed on exit.
    It is not uninstalled, so it stays the (closed) current loop afterwards.
    """
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        fresh_loop.close()
116
117
@contextmanager
def skip_if_exception(e: str) -> Iterator[None]:
    """Skip the running test if the block raises an exception named ``e``.

    Only the exception's class name is compared. Any other exception is
    re-raised unchanged; a block that raises nothing is unaffected.

    Raises:
        unittest.SkipTest: when the block raised an exception of class ``e``.
    """
    try:
        yield
    except Exception as exc:
        if exc.__class__.__name__ != e:
            raise
        # ``unittest.skip(...)`` merely builds a decorator, so calling it here
        # would swallow the exception and let the test pass. Raising SkipTest
        # is what actually marks the test as skipped.
        raise unittest.SkipTest(f"Encountered expected exception {exc}, skipping")
127
128
class FakeContext(click.Context):
    """Minimal stand-in for ``click.Context`` when a function needs one.

    Only ``default_map`` (an empty dict) is set; the base-class constructor is
    intentionally not called.
    """

    def __init__(self) -> None:
        empty_defaults: Dict[str, Any] = {}
        self.default_map = empty_defaults
134
135
class FakeParameter(click.Parameter):
    """Minimal stand-in for ``click.Parameter`` when a function needs one.

    The constructor does nothing, so no base-class state is initialized.
    """

    def __init__(self) -> None:
        """Deliberately skip ``click.Parameter.__init__``."""
141
142
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    After an invocation, the captured streams are available as raw bytes in
    ``stdout_bytes`` and ``stderr_bytes``.

    This is a hack that can be removed once we depend on Click 7.x
    """

    def __init__(self) -> None:
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Like ``CliRunner.isolation``, but capture stderr separately."""
        with super().isolation(*args, **kwargs) as output:
            hold_stderr = sys.stderr
            stderr_wrapper = TextIOWrapper(self.stderrbuf, encoding=self.charset)
            sys.stderr = stderr_wrapper
            try:
                yield output
            finally:
                # Restore the real stderr first, so a failure while collecting
                # output below cannot leak the redirection into later tests.
                sys.stderr = hold_stderr
                # TextIOWrapper buffers text internally; flush so that writes
                # which did not flush themselves reach ``stderrbuf``.
                stderr_wrapper.flush()
                self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                self.stderr_bytes = self.stderrbuf.getvalue()
166
167
168 class BlackTestCase(unittest.TestCase):
169     maxDiff = None
170     _diffThreshold = 2 ** 20
171
172     def assertFormatEqual(self, expected: str, actual: str) -> None:
173         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
174             bdv: black.DebugVisitor[Any]
175             black.out("Expected tree:", fg="green")
176             try:
177                 exp_node = black.lib2to3_parse(expected)
178                 bdv = black.DebugVisitor()
179                 list(bdv.visit(exp_node))
180             except Exception as ve:
181                 black.err(str(ve))
182             black.out("Actual tree:", fg="red")
183             try:
184                 exp_node = black.lib2to3_parse(actual)
185                 bdv = black.DebugVisitor()
186                 list(bdv.visit(exp_node))
187             except Exception as ve:
188                 black.err(str(ve))
189         self.assertMultiLineEqual(expected, actual)
190
191     def invokeBlack(
192         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
193     ) -> None:
194         runner = BlackRunner()
195         if ignore_config:
196             args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
197         result = runner.invoke(black.main, args)
198         self.assertEqual(
199             result.exit_code,
200             exit_code,
201             msg=(
202                 f"Failed with args: {args}\n"
203                 f"stdout: {runner.stdout_bytes.decode()!r}\n"
204                 f"stderr: {runner.stderr_bytes.decode()!r}\n"
205                 f"exception: {result.exception}"
206             ),
207         )
208
209     @patch("black.dump_to_file", dump_to_stderr)
210     def checkSourceFile(self, name: str, mode: black.FileMode = DEFAULT_MODE) -> None:
211         path = THIS_DIR.parent / name
212         source, expected = read_data(str(path), data=False)
213         actual = fs(source, mode=mode)
214         self.assertFormatEqual(expected, actual)
215         black.assert_equivalent(source, actual)
216         black.assert_stable(source, actual, mode)
217         self.assertFalse(ff(path))
218
219     @patch("black.dump_to_file", dump_to_stderr)
220     def test_empty(self) -> None:
221         source = expected = ""
222         actual = fs(source)
223         self.assertFormatEqual(expected, actual)
224         black.assert_equivalent(source, actual)
225         black.assert_stable(source, actual, DEFAULT_MODE)
226
227     def test_empty_ff(self) -> None:
228         expected = ""
229         tmp_file = Path(black.dump_to_file())
230         try:
231             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
232             with open(tmp_file, encoding="utf8") as f:
233                 actual = f.read()
234         finally:
235             os.unlink(tmp_file)
236         self.assertFormatEqual(expected, actual)
237
238     def test_self(self) -> None:
239         self.checkSourceFile("tests/test_black.py")
240
241     def test_black(self) -> None:
242         self.checkSourceFile("src/black/__init__.py")
243
244     def test_pygram(self) -> None:
245         self.checkSourceFile("src/blib2to3/pygram.py")
246
247     def test_pytree(self) -> None:
248         self.checkSourceFile("src/blib2to3/pytree.py")
249
250     def test_conv(self) -> None:
251         self.checkSourceFile("src/blib2to3/pgen2/conv.py")
252
253     def test_driver(self) -> None:
254         self.checkSourceFile("src/blib2to3/pgen2/driver.py")
255
256     def test_grammar(self) -> None:
257         self.checkSourceFile("src/blib2to3/pgen2/grammar.py")
258
259     def test_literals(self) -> None:
260         self.checkSourceFile("src/blib2to3/pgen2/literals.py")
261
262     def test_parse(self) -> None:
263         self.checkSourceFile("src/blib2to3/pgen2/parse.py")
264
265     def test_pgen(self) -> None:
266         self.checkSourceFile("src/blib2to3/pgen2/pgen.py")
267
268     def test_tokenize(self) -> None:
269         self.checkSourceFile("src/blib2to3/pgen2/tokenize.py")
270
271     def test_token(self) -> None:
272         self.checkSourceFile("src/blib2to3/pgen2/token.py")
273
274     def test_setup(self) -> None:
275         self.checkSourceFile("setup.py")
276
277     def test_piping(self) -> None:
278         source, expected = read_data("src/black/__init__", data=False)
279         result = BlackRunner().invoke(
280             black.main,
281             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
282             input=BytesIO(source.encode("utf8")),
283         )
284         self.assertEqual(result.exit_code, 0)
285         self.assertFormatEqual(expected, result.output)
286         black.assert_equivalent(source, result.output)
287         black.assert_stable(source, result.output, DEFAULT_MODE)
288
289     def test_piping_diff(self) -> None:
290         diff_header = re.compile(
291             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
292             r"\+\d\d\d\d"
293         )
294         source, _ = read_data("expression.py")
295         expected, _ = read_data("expression.diff")
296         config = THIS_DIR / "data" / "empty_pyproject.toml"
297         args = [
298             "-",
299             "--fast",
300             f"--line-length={black.DEFAULT_LINE_LENGTH}",
301             "--diff",
302             f"--config={config}",
303         ]
304         result = BlackRunner().invoke(
305             black.main, args, input=BytesIO(source.encode("utf8"))
306         )
307         self.assertEqual(result.exit_code, 0)
308         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
309         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
310         self.assertEqual(expected, actual)
311
312     def test_piping_diff_with_color(self) -> None:
313         source, _ = read_data("expression.py")
314         config = THIS_DIR / "data" / "empty_pyproject.toml"
315         args = [
316             "-",
317             "--fast",
318             f"--line-length={black.DEFAULT_LINE_LENGTH}",
319             "--diff",
320             "--color",
321             f"--config={config}",
322         ]
323         result = BlackRunner().invoke(
324             black.main, args, input=BytesIO(source.encode("utf8"))
325         )
326         actual = result.output
327         # Again, the contents are checked in a different test, so only look for colors.
328         self.assertIn("\033[1;37m", actual)
329         self.assertIn("\033[36m", actual)
330         self.assertIn("\033[32m", actual)
331         self.assertIn("\033[31m", actual)
332         self.assertIn("\033[0m", actual)
333
334     @patch("black.dump_to_file", dump_to_stderr)
335     def test_function(self) -> None:
336         source, expected = read_data("function")
337         actual = fs(source)
338         self.assertFormatEqual(expected, actual)
339         black.assert_equivalent(source, actual)
340         black.assert_stable(source, actual, DEFAULT_MODE)
341
342     @patch("black.dump_to_file", dump_to_stderr)
343     def test_function2(self) -> None:
344         source, expected = read_data("function2")
345         actual = fs(source)
346         self.assertFormatEqual(expected, actual)
347         black.assert_equivalent(source, actual)
348         black.assert_stable(source, actual, DEFAULT_MODE)
349
350     @patch("black.dump_to_file", dump_to_stderr)
351     def _test_wip(self) -> None:
352         source, expected = read_data("wip")
353         sys.settrace(tracefunc)
354         mode = replace(
355             DEFAULT_MODE,
356             experimental_string_processing=False,
357             target_versions={black.TargetVersion.PY38},
358         )
359         actual = fs(source, mode=mode)
360         sys.settrace(None)
361         self.assertFormatEqual(expected, actual)
362         black.assert_equivalent(source, actual)
363         black.assert_stable(source, actual, black.FileMode())
364
365     @patch("black.dump_to_file", dump_to_stderr)
366     def test_function_trailing_comma(self) -> None:
367         source, expected = read_data("function_trailing_comma")
368         actual = fs(source)
369         self.assertFormatEqual(expected, actual)
370         black.assert_equivalent(source, actual)
371         black.assert_stable(source, actual, DEFAULT_MODE)
372
373     @unittest.expectedFailure
374     @patch("black.dump_to_file", dump_to_stderr)
375     def test_trailing_comma_optional_parens_stability1(self) -> None:
376         source, _expected = read_data("trailing_comma_optional_parens1")
377         actual = fs(source)
378         black.assert_stable(source, actual, DEFAULT_MODE)
379
380     @unittest.expectedFailure
381     @patch("black.dump_to_file", dump_to_stderr)
382     def test_trailing_comma_optional_parens_stability2(self) -> None:
383         source, _expected = read_data("trailing_comma_optional_parens2")
384         actual = fs(source)
385         black.assert_stable(source, actual, DEFAULT_MODE)
386
387     @unittest.expectedFailure
388     @patch("black.dump_to_file", dump_to_stderr)
389     def test_trailing_comma_optional_parens_stability3(self) -> None:
390         source, _expected = read_data("trailing_comma_optional_parens3")
391         actual = fs(source)
392         black.assert_stable(source, actual, DEFAULT_MODE)
393
394     @patch("black.dump_to_file", dump_to_stderr)
395     def test_expression(self) -> None:
396         source, expected = read_data("expression")
397         actual = fs(source)
398         self.assertFormatEqual(expected, actual)
399         black.assert_equivalent(source, actual)
400         black.assert_stable(source, actual, DEFAULT_MODE)
401
402     @patch("black.dump_to_file", dump_to_stderr)
403     def test_pep_572(self) -> None:
404         source, expected = read_data("pep_572")
405         actual = fs(source)
406         self.assertFormatEqual(expected, actual)
407         black.assert_stable(source, actual, DEFAULT_MODE)
408         if sys.version_info >= (3, 8):
409             black.assert_equivalent(source, actual)
410
411     def test_pep_572_version_detection(self) -> None:
412         source, _ = read_data("pep_572")
413         root = black.lib2to3_parse(source)
414         features = black.get_features_used(root)
415         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
416         versions = black.detect_target_versions(root)
417         self.assertIn(black.TargetVersion.PY38, versions)
418
419     def test_expression_ff(self) -> None:
420         source, expected = read_data("expression")
421         tmp_file = Path(black.dump_to_file(source))
422         try:
423             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
424             with open(tmp_file, encoding="utf8") as f:
425                 actual = f.read()
426         finally:
427             os.unlink(tmp_file)
428         self.assertFormatEqual(expected, actual)
429         with patch("black.dump_to_file", dump_to_stderr):
430             black.assert_equivalent(source, actual)
431             black.assert_stable(source, actual, DEFAULT_MODE)
432
433     def test_expression_diff(self) -> None:
434         source, _ = read_data("expression.py")
435         expected, _ = read_data("expression.diff")
436         tmp_file = Path(black.dump_to_file(source))
437         diff_header = re.compile(
438             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
439             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
440         )
441         try:
442             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
443             self.assertEqual(result.exit_code, 0)
444         finally:
445             os.unlink(tmp_file)
446         actual = result.output
447         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
448         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
449         if expected != actual:
450             dump = black.dump_to_file(actual)
451             msg = (
452                 "Expected diff isn't equal to the actual. If you made changes to"
453                 " expression.py and this is an anticipated difference, overwrite"
454                 f" tests/data/expression.diff with {dump}"
455             )
456             self.assertEqual(expected, actual, msg)
457
458     def test_expression_diff_with_color(self) -> None:
459         source, _ = read_data("expression.py")
460         expected, _ = read_data("expression.diff")
461         tmp_file = Path(black.dump_to_file(source))
462         try:
463             result = BlackRunner().invoke(
464                 black.main, ["--diff", "--color", str(tmp_file)]
465             )
466         finally:
467             os.unlink(tmp_file)
468         actual = result.output
469         # We check the contents of the diff in `test_expression_diff`. All
470         # we need to check here is that color codes exist in the result.
471         self.assertIn("\033[1;37m", actual)
472         self.assertIn("\033[36m", actual)
473         self.assertIn("\033[32m", actual)
474         self.assertIn("\033[31m", actual)
475         self.assertIn("\033[0m", actual)
476
477     @patch("black.dump_to_file", dump_to_stderr)
478     def test_fstring(self) -> None:
479         source, expected = read_data("fstring")
480         actual = fs(source)
481         self.assertFormatEqual(expected, actual)
482         black.assert_equivalent(source, actual)
483         black.assert_stable(source, actual, DEFAULT_MODE)
484
485     @patch("black.dump_to_file", dump_to_stderr)
486     def test_pep_570(self) -> None:
487         source, expected = read_data("pep_570")
488         actual = fs(source)
489         self.assertFormatEqual(expected, actual)
490         black.assert_stable(source, actual, DEFAULT_MODE)
491         if sys.version_info >= (3, 8):
492             black.assert_equivalent(source, actual)
493
494     def test_detect_pos_only_arguments(self) -> None:
495         source, _ = read_data("pep_570")
496         root = black.lib2to3_parse(source)
497         features = black.get_features_used(root)
498         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
499         versions = black.detect_target_versions(root)
500         self.assertIn(black.TargetVersion.PY38, versions)
501
502     @patch("black.dump_to_file", dump_to_stderr)
503     def test_string_quotes(self) -> None:
504         source, expected = read_data("string_quotes")
505         actual = fs(source)
506         self.assertFormatEqual(expected, actual)
507         black.assert_equivalent(source, actual)
508         black.assert_stable(source, actual, DEFAULT_MODE)
509         mode = replace(DEFAULT_MODE, string_normalization=False)
510         not_normalized = fs(source, mode=mode)
511         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
512         black.assert_equivalent(source, not_normalized)
513         black.assert_stable(source, not_normalized, mode=mode)
514
515     @patch("black.dump_to_file", dump_to_stderr)
516     def test_docstring(self) -> None:
517         source, expected = read_data("docstring")
518         actual = fs(source)
519         self.assertFormatEqual(expected, actual)
520         black.assert_equivalent(source, actual)
521         black.assert_stable(source, actual, DEFAULT_MODE)
522
523     @patch("black.dump_to_file", dump_to_stderr)
524     def test_docstring_no_string_normalization(self) -> None:
525         """Like test_docstring but with string normalization off."""
526         source, expected = read_data("docstring_no_string_normalization")
527         mode = replace(DEFAULT_MODE, string_normalization=False)
528         actual = fs(source, mode=mode)
529         self.assertFormatEqual(expected, actual)
530         black.assert_equivalent(source, actual)
531         black.assert_stable(source, actual, mode)
532
533     def test_long_strings(self) -> None:
534         """Tests for splitting long strings."""
535         source, expected = read_data("long_strings")
536         actual = fs(source)
537         self.assertFormatEqual(expected, actual)
538         black.assert_equivalent(source, actual)
539         black.assert_stable(source, actual, DEFAULT_MODE)
540
541     def test_long_strings_flag_disabled(self) -> None:
542         """Tests for turning off the string processing logic."""
543         source, expected = read_data("long_strings_flag_disabled")
544         mode = replace(DEFAULT_MODE, experimental_string_processing=False)
545         actual = fs(source, mode=mode)
546         self.assertFormatEqual(expected, actual)
547         black.assert_stable(expected, actual, mode)
548
549     @patch("black.dump_to_file", dump_to_stderr)
550     def test_long_strings__edge_case(self) -> None:
551         """Edge-case tests for splitting long strings."""
552         source, expected = read_data("long_strings__edge_case")
553         actual = fs(source)
554         self.assertFormatEqual(expected, actual)
555         black.assert_equivalent(source, actual)
556         black.assert_stable(source, actual, DEFAULT_MODE)
557
558     @patch("black.dump_to_file", dump_to_stderr)
559     def test_long_strings__regression(self) -> None:
560         """Regression tests for splitting long strings."""
561         source, expected = read_data("long_strings__regression")
562         actual = fs(source)
563         self.assertFormatEqual(expected, actual)
564         black.assert_equivalent(source, actual)
565         black.assert_stable(source, actual, DEFAULT_MODE)
566
567     @patch("black.dump_to_file", dump_to_stderr)
568     def test_slices(self) -> None:
569         source, expected = read_data("slices")
570         actual = fs(source)
571         self.assertFormatEqual(expected, actual)
572         black.assert_equivalent(source, actual)
573         black.assert_stable(source, actual, DEFAULT_MODE)
574
575     @patch("black.dump_to_file", dump_to_stderr)
576     def test_percent_precedence(self) -> None:
577         source, expected = read_data("percent_precedence")
578         actual = fs(source)
579         self.assertFormatEqual(expected, actual)
580         black.assert_equivalent(source, actual)
581         black.assert_stable(source, actual, DEFAULT_MODE)
582
583     @patch("black.dump_to_file", dump_to_stderr)
584     def test_comments(self) -> None:
585         source, expected = read_data("comments")
586         actual = fs(source)
587         self.assertFormatEqual(expected, actual)
588         black.assert_equivalent(source, actual)
589         black.assert_stable(source, actual, DEFAULT_MODE)
590
591     @patch("black.dump_to_file", dump_to_stderr)
592     def test_comments2(self) -> None:
593         source, expected = read_data("comments2")
594         actual = fs(source)
595         self.assertFormatEqual(expected, actual)
596         black.assert_equivalent(source, actual)
597         black.assert_stable(source, actual, DEFAULT_MODE)
598
599     @patch("black.dump_to_file", dump_to_stderr)
600     def test_comments3(self) -> None:
601         source, expected = read_data("comments3")
602         actual = fs(source)
603         self.assertFormatEqual(expected, actual)
604         black.assert_equivalent(source, actual)
605         black.assert_stable(source, actual, DEFAULT_MODE)
606
607     @patch("black.dump_to_file", dump_to_stderr)
608     def test_comments4(self) -> None:
609         source, expected = read_data("comments4")
610         actual = fs(source)
611         self.assertFormatEqual(expected, actual)
612         black.assert_equivalent(source, actual)
613         black.assert_stable(source, actual, DEFAULT_MODE)
614
615     @patch("black.dump_to_file", dump_to_stderr)
616     def test_comments5(self) -> None:
617         source, expected = read_data("comments5")
618         actual = fs(source)
619         self.assertFormatEqual(expected, actual)
620         black.assert_equivalent(source, actual)
621         black.assert_stable(source, actual, DEFAULT_MODE)
622
623     @patch("black.dump_to_file", dump_to_stderr)
624     def test_comments6(self) -> None:
625         source, expected = read_data("comments6")
626         actual = fs(source)
627         self.assertFormatEqual(expected, actual)
628         black.assert_equivalent(source, actual)
629         black.assert_stable(source, actual, DEFAULT_MODE)
630
631     @patch("black.dump_to_file", dump_to_stderr)
632     def test_comments7(self) -> None:
633         source, expected = read_data("comments7")
634         mode = replace(DEFAULT_MODE, target_versions={black.TargetVersion.PY38})
635         actual = fs(source, mode=mode)
636         self.assertFormatEqual(expected, actual)
637         black.assert_equivalent(source, actual)
638         black.assert_stable(source, actual, DEFAULT_MODE)
639
640     @patch("black.dump_to_file", dump_to_stderr)
641     def test_comment_after_escaped_newline(self) -> None:
642         source, expected = read_data("comment_after_escaped_newline")
643         actual = fs(source)
644         self.assertFormatEqual(expected, actual)
645         black.assert_equivalent(source, actual)
646         black.assert_stable(source, actual, DEFAULT_MODE)
647
648     @patch("black.dump_to_file", dump_to_stderr)
649     def test_cantfit(self) -> None:
650         source, expected = read_data("cantfit")
651         actual = fs(source)
652         self.assertFormatEqual(expected, actual)
653         black.assert_equivalent(source, actual)
654         black.assert_stable(source, actual, DEFAULT_MODE)
655
656     @patch("black.dump_to_file", dump_to_stderr)
657     def test_import_spacing(self) -> None:
658         source, expected = read_data("import_spacing")
659         actual = fs(source)
660         self.assertFormatEqual(expected, actual)
661         black.assert_equivalent(source, actual)
662         black.assert_stable(source, actual, DEFAULT_MODE)
663
664     @patch("black.dump_to_file", dump_to_stderr)
665     def test_composition(self) -> None:
666         source, expected = read_data("composition")
667         actual = fs(source)
668         self.assertFormatEqual(expected, actual)
669         black.assert_equivalent(source, actual)
670         black.assert_stable(source, actual, DEFAULT_MODE)
671
672     @patch("black.dump_to_file", dump_to_stderr)
673     def test_composition_no_trailing_comma(self) -> None:
674         source, expected = read_data("composition_no_trailing_comma")
675         mode = replace(DEFAULT_MODE, target_versions={black.TargetVersion.PY38})
676         actual = fs(source, mode=mode)
677         self.assertFormatEqual(expected, actual)
678         black.assert_equivalent(source, actual)
679         black.assert_stable(source, actual, DEFAULT_MODE)
680
681     @patch("black.dump_to_file", dump_to_stderr)
682     def test_empty_lines(self) -> None:
683         source, expected = read_data("empty_lines")
684         actual = fs(source)
685         self.assertFormatEqual(expected, actual)
686         black.assert_equivalent(source, actual)
687         black.assert_stable(source, actual, DEFAULT_MODE)
688
689     @patch("black.dump_to_file", dump_to_stderr)
690     def test_remove_parens(self) -> None:
691         source, expected = read_data("remove_parens")
692         actual = fs(source)
693         self.assertFormatEqual(expected, actual)
694         black.assert_equivalent(source, actual)
695         black.assert_stable(source, actual, DEFAULT_MODE)
696
697     @patch("black.dump_to_file", dump_to_stderr)
698     def test_string_prefixes(self) -> None:
699         source, expected = read_data("string_prefixes")
700         actual = fs(source)
701         self.assertFormatEqual(expected, actual)
702         black.assert_equivalent(source, actual)
703         black.assert_stable(source, actual, DEFAULT_MODE)
704
705     @patch("black.dump_to_file", dump_to_stderr)
706     def test_numeric_literals(self) -> None:
707         source, expected = read_data("numeric_literals")
708         mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
709         actual = fs(source, mode=mode)
710         self.assertFormatEqual(expected, actual)
711         black.assert_equivalent(source, actual)
712         black.assert_stable(source, actual, mode)
713
714     @patch("black.dump_to_file", dump_to_stderr)
715     def test_numeric_literals_ignoring_underscores(self) -> None:
716         source, expected = read_data("numeric_literals_skip_underscores")
717         mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
718         actual = fs(source, mode=mode)
719         self.assertFormatEqual(expected, actual)
720         black.assert_equivalent(source, actual)
721         black.assert_stable(source, actual, mode)
722
723     @patch("black.dump_to_file", dump_to_stderr)
724     def test_numeric_literals_py2(self) -> None:
725         source, expected = read_data("numeric_literals_py2")
726         actual = fs(source)
727         self.assertFormatEqual(expected, actual)
728         black.assert_stable(source, actual, DEFAULT_MODE)
729
730     @patch("black.dump_to_file", dump_to_stderr)
731     def test_python2(self) -> None:
732         source, expected = read_data("python2")
733         actual = fs(source)
734         self.assertFormatEqual(expected, actual)
735         black.assert_equivalent(source, actual)
736         black.assert_stable(source, actual, DEFAULT_MODE)
737
738     @patch("black.dump_to_file", dump_to_stderr)
739     def test_python2_print_function(self) -> None:
740         source, expected = read_data("python2_print_function")
741         mode = replace(DEFAULT_MODE, target_versions={TargetVersion.PY27})
742         actual = fs(source, mode=mode)
743         self.assertFormatEqual(expected, actual)
744         black.assert_equivalent(source, actual)
745         black.assert_stable(source, actual, mode)
746
747     @patch("black.dump_to_file", dump_to_stderr)
748     def test_python2_unicode_literals(self) -> None:
749         source, expected = read_data("python2_unicode_literals")
750         actual = fs(source)
751         self.assertFormatEqual(expected, actual)
752         black.assert_equivalent(source, actual)
753         black.assert_stable(source, actual, DEFAULT_MODE)
754
755     @patch("black.dump_to_file", dump_to_stderr)
756     def test_stub(self) -> None:
757         mode = replace(DEFAULT_MODE, is_pyi=True)
758         source, expected = read_data("stub.pyi")
759         actual = fs(source, mode=mode)
760         self.assertFormatEqual(expected, actual)
761         black.assert_stable(source, actual, mode)
762
763     @patch("black.dump_to_file", dump_to_stderr)
764     def test_async_as_identifier(self) -> None:
765         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
766         source, expected = read_data("async_as_identifier")
767         actual = fs(source)
768         self.assertFormatEqual(expected, actual)
769         major, minor = sys.version_info[:2]
770         if major < 3 or (major <= 3 and minor < 7):
771             black.assert_equivalent(source, actual)
772         black.assert_stable(source, actual, DEFAULT_MODE)
773         # ensure black can parse this when the target is 3.6
774         self.invokeBlack([str(source_path), "--target-version", "py36"])
775         # but not on 3.7, because async/await is no longer an identifier
776         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
777
778     @patch("black.dump_to_file", dump_to_stderr)
779     def test_python37(self) -> None:
780         source_path = (THIS_DIR / "data" / "python37.py").resolve()
781         source, expected = read_data("python37")
782         actual = fs(source)
783         self.assertFormatEqual(expected, actual)
784         major, minor = sys.version_info[:2]
785         if major > 3 or (major == 3 and minor >= 7):
786             black.assert_equivalent(source, actual)
787         black.assert_stable(source, actual, DEFAULT_MODE)
788         # ensure black can parse this when the target is 3.7
789         self.invokeBlack([str(source_path), "--target-version", "py37"])
790         # but not on 3.6, because we use async as a reserved keyword
791         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
792
793     @patch("black.dump_to_file", dump_to_stderr)
794     def test_python38(self) -> None:
795         source, expected = read_data("python38")
796         actual = fs(source)
797         self.assertFormatEqual(expected, actual)
798         major, minor = sys.version_info[:2]
799         if major > 3 or (major == 3 and minor >= 8):
800             black.assert_equivalent(source, actual)
801         black.assert_stable(source, actual, DEFAULT_MODE)
802
803     @patch("black.dump_to_file", dump_to_stderr)
804     def test_fmtonoff(self) -> None:
805         source, expected = read_data("fmtonoff")
806         actual = fs(source)
807         self.assertFormatEqual(expected, actual)
808         black.assert_equivalent(source, actual)
809         black.assert_stable(source, actual, DEFAULT_MODE)
810
811     @patch("black.dump_to_file", dump_to_stderr)
812     def test_fmtonoff2(self) -> None:
813         source, expected = read_data("fmtonoff2")
814         actual = fs(source)
815         self.assertFormatEqual(expected, actual)
816         black.assert_equivalent(source, actual)
817         black.assert_stable(source, actual, DEFAULT_MODE)
818
819     @patch("black.dump_to_file", dump_to_stderr)
820     def test_fmtonoff3(self) -> None:
821         source, expected = read_data("fmtonoff3")
822         actual = fs(source)
823         self.assertFormatEqual(expected, actual)
824         black.assert_equivalent(source, actual)
825         black.assert_stable(source, actual, DEFAULT_MODE)
826
827     @patch("black.dump_to_file", dump_to_stderr)
828     def test_fmtonoff4(self) -> None:
829         source, expected = read_data("fmtonoff4")
830         actual = fs(source)
831         self.assertFormatEqual(expected, actual)
832         black.assert_equivalent(source, actual)
833         black.assert_stable(source, actual, DEFAULT_MODE)
834
835     @patch("black.dump_to_file", dump_to_stderr)
836     def test_remove_empty_parentheses_after_class(self) -> None:
837         source, expected = read_data("class_blank_parentheses")
838         actual = fs(source)
839         self.assertFormatEqual(expected, actual)
840         black.assert_equivalent(source, actual)
841         black.assert_stable(source, actual, DEFAULT_MODE)
842
843     @patch("black.dump_to_file", dump_to_stderr)
844     def test_new_line_between_class_and_code(self) -> None:
845         source, expected = read_data("class_methods_new_line")
846         actual = fs(source)
847         self.assertFormatEqual(expected, actual)
848         black.assert_equivalent(source, actual)
849         black.assert_stable(source, actual, DEFAULT_MODE)
850
851     @patch("black.dump_to_file", dump_to_stderr)
852     def test_bracket_match(self) -> None:
853         source, expected = read_data("bracketmatch")
854         actual = fs(source)
855         self.assertFormatEqual(expected, actual)
856         black.assert_equivalent(source, actual)
857         black.assert_stable(source, actual, DEFAULT_MODE)
858
859     @patch("black.dump_to_file", dump_to_stderr)
860     def test_tuple_assign(self) -> None:
861         source, expected = read_data("tupleassign")
862         actual = fs(source)
863         self.assertFormatEqual(expected, actual)
864         black.assert_equivalent(source, actual)
865         black.assert_stable(source, actual, DEFAULT_MODE)
866
867     @patch("black.dump_to_file", dump_to_stderr)
868     def test_beginning_backslash(self) -> None:
869         source, expected = read_data("beginning_backslash")
870         actual = fs(source)
871         self.assertFormatEqual(expected, actual)
872         black.assert_equivalent(source, actual)
873         black.assert_stable(source, actual, DEFAULT_MODE)
874
875     def test_tab_comment_indentation(self) -> None:
876         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
877         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
878         self.assertFormatEqual(contents_spc, fs(contents_spc))
879         self.assertFormatEqual(contents_spc, fs(contents_tab))
880
881         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
882         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
883         self.assertFormatEqual(contents_spc, fs(contents_spc))
884         self.assertFormatEqual(contents_spc, fs(contents_tab))
885
886         # mixed tabs and spaces (valid Python 2 code)
887         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
888         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
889         self.assertFormatEqual(contents_spc, fs(contents_spc))
890         self.assertFormatEqual(contents_spc, fs(contents_tab))
891
892         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
893         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
894         self.assertFormatEqual(contents_spc, fs(contents_spc))
895         self.assertFormatEqual(contents_spc, fs(contents_tab))
896
    def test_report_verbose(self) -> None:
        """Exercise ``black.Report`` with ``verbose=True``.

        Verbose mode emits one ``black.out`` line per unchanged, reformatted,
        cached and ignored path. Failures go to ``black.err``. The steps below
        build on each other: every count, summary string and ``return_code``
        check depends on all of the earlier calls.
        """
        report = black.Report(verbose=True)
        # Messages captured from the patched black.out / black.err.
        out_lines = []
        err_lines = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A CACHED result counts toward "left unchanged" in the summary.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled (one file reformatted so far), return_code is 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure goes to err and raises return_code to 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Reporting f3 again is counted again; the report counts events,
            # not distinct paths.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are reported in verbose mode but are not part of
            # the summary counts.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check mode and diff mode switch the summary to "would be".
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
998
    def test_report_quiet(self) -> None:
        """Exercise ``black.Report`` with ``quiet=True``.

        Quiet mode writes nothing to ``black.out``. Failures are still written
        to ``black.err``. The summary string and ``return_code`` must match the
        verbose and normal variants. The steps are cumulative, so their order
        matters.
        """
        report = black.Report(quiet=True)
        # Messages captured from the patched black.out / black.err.
        out_lines = []
        err_lines = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled (one file reformatted so far), return_code is 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Errors are still written to err even in quiet mode.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths produce no output and leave the summary unchanged.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check mode and diff mode switch the summary to "would be".
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
1092
    def test_report_normal(self) -> None:
        """Exercise ``black.Report`` with default verbosity.

        Only reformatted paths produce a ``black.out`` line. Failures go to
        ``black.err``. Unchanged, cached and ignored paths produce no output.
        The steps are cumulative, so their order matters.
        """
        report = black.Report()
        # Messages captured from the patched black.out / black.err.
        out_lines = []
        err_lines = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            # Still the f2 message: the CACHED result emitted nothing.
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled (one file reformatted so far), return_code is 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 1 file failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths produce no output and leave the summary unchanged.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, 2 files failed to"
                " reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Both check mode and diff mode switch the summary to "would be".
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
            report.check = False
            report.diff = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, 2 files"
                " would fail to reformat.",
            )
1189
1190     def test_lib2to3_parse(self) -> None:
1191         with self.assertRaises(black.InvalidInput):
1192             black.lib2to3_parse("invalid syntax")
1193
1194         straddling = "x + y"
1195         black.lib2to3_parse(straddling)
1196         black.lib2to3_parse(straddling, {TargetVersion.PY27})
1197         black.lib2to3_parse(straddling, {TargetVersion.PY36})
1198         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
1199
1200         py2_only = "print x"
1201         black.lib2to3_parse(py2_only)
1202         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
1203         with self.assertRaises(black.InvalidInput):
1204             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
1205         with self.assertRaises(black.InvalidInput):
1206             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
1207
1208         py3_only = "exec(x, end=y)"
1209         black.lib2to3_parse(py3_only)
1210         with self.assertRaises(black.InvalidInput):
1211             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
1212         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
1213         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
1214
1215     def test_get_features_used(self) -> None:
1216         node = black.lib2to3_parse("def f(*, arg): ...\n")
1217         self.assertEqual(black.get_features_used(node), set())
1218         node = black.lib2to3_parse("def f(*, arg,): ...\n")
1219         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
1220         node = black.lib2to3_parse("f(*arg,)\n")
1221         self.assertEqual(
1222             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
1223         )
1224         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
1225         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
1226         node = black.lib2to3_parse("123_456\n")
1227         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
1228         node = black.lib2to3_parse("123456\n")
1229         self.assertEqual(black.get_features_used(node), set())
1230         source, expected = read_data("function")
1231         node = black.lib2to3_parse(source)
1232         expected_features = {
1233             Feature.TRAILING_COMMA_IN_CALL,
1234             Feature.TRAILING_COMMA_IN_DEF,
1235             Feature.F_STRINGS,
1236         }
1237         self.assertEqual(black.get_features_used(node), expected_features)
1238         node = black.lib2to3_parse(expected)
1239         self.assertEqual(black.get_features_used(node), expected_features)
1240         source, expected = read_data("expression")
1241         node = black.lib2to3_parse(source)
1242         self.assertEqual(black.get_features_used(node), set())
1243         node = black.lib2to3_parse(expected)
1244         self.assertEqual(black.get_features_used(node), set())
1245
1246     def test_get_future_imports(self) -> None:
1247         node = black.lib2to3_parse("\n")
1248         self.assertEqual(set(), black.get_future_imports(node))
1249         node = black.lib2to3_parse("from __future__ import black\n")
1250         self.assertEqual({"black"}, black.get_future_imports(node))
1251         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
1252         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
1253         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
1254         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
1255         node = black.lib2to3_parse(
1256             "from __future__ import multiple\nfrom __future__ import imports\n"
1257         )
1258         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
1259         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
1260         self.assertEqual({"black"}, black.get_future_imports(node))
1261         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
1262         self.assertEqual({"black"}, black.get_future_imports(node))
1263         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
1264         self.assertEqual(set(), black.get_future_imports(node))
1265         node = black.lib2to3_parse("from some.module import black\n")
1266         self.assertEqual(set(), black.get_future_imports(node))
1267         node = black.lib2to3_parse(
1268             "from __future__ import unicode_literals as _unicode_literals"
1269         )
1270         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
1271         node = black.lib2to3_parse(
1272             "from __future__ import unicode_literals as _lol, print"
1273         )
1274         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
1275
1276     def test_debug_visitor(self) -> None:
1277         source, _ = read_data("debug_visitor.py")
1278         expected, _ = read_data("debug_visitor.out")
1279         out_lines = []
1280         err_lines = []
1281
1282         def out(msg: str, **kwargs: Any) -> None:
1283             out_lines.append(msg)
1284
1285         def err(msg: str, **kwargs: Any) -> None:
1286             err_lines.append(msg)
1287
1288         with patch("black.out", out), patch("black.err", err):
1289             black.DebugVisitor.show(source)
1290         actual = "\n".join(out_lines) + "\n"
1291         log_name = ""
1292         if expected != actual:
1293             log_name = black.dump_to_file(*out_lines)
1294         self.assertEqual(
1295             expected,
1296             actual,
1297             f"AST print out is different. Actual version dumped to {log_name}",
1298         )
1299
1300     def test_format_file_contents(self) -> None:
1301         empty = ""
1302         mode = DEFAULT_MODE
1303         with self.assertRaises(black.NothingChanged):
1304             black.format_file_contents(empty, mode=mode, fast=False)
1305         just_nl = "\n"
1306         with self.assertRaises(black.NothingChanged):
1307             black.format_file_contents(just_nl, mode=mode, fast=False)
1308         same = "j = [1, 2, 3]\n"
1309         with self.assertRaises(black.NothingChanged):
1310             black.format_file_contents(same, mode=mode, fast=False)
1311         different = "j = [1,2,3]"
1312         expected = same
1313         actual = black.format_file_contents(different, mode=mode, fast=False)
1314         self.assertEqual(expected, actual)
1315         invalid = "return if you can"
1316         with self.assertRaises(black.InvalidInput) as e:
1317             black.format_file_contents(invalid, mode=mode, fast=False)
1318         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1319
1320     def test_endmarker(self) -> None:
1321         n = black.lib2to3_parse("\n")
1322         self.assertEqual(n.type, black.syms.file_input)
1323         self.assertEqual(len(n.children), 1)
1324         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1325
1326     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1327     def test_assertFormatEqual(self) -> None:
1328         out_lines = []
1329         err_lines = []
1330
1331         def out(msg: str, **kwargs: Any) -> None:
1332             out_lines.append(msg)
1333
1334         def err(msg: str, **kwargs: Any) -> None:
1335             err_lines.append(msg)
1336
1337         with patch("black.out", out), patch("black.err", err):
1338             with self.assertRaises(AssertionError):
1339                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1340
1341         out_str = "".join(out_lines)
1342         self.assertTrue("Expected tree:" in out_str)
1343         self.assertTrue("Actual tree:" in out_str)
1344         self.assertEqual("".join(err_lines), "")
1345
1346     def test_cache_broken_file(self) -> None:
1347         mode = DEFAULT_MODE
1348         with cache_dir() as workspace:
1349             cache_file = black.get_cache_file(mode)
1350             with cache_file.open("w") as fobj:
1351                 fobj.write("this is not a pickle")
1352             self.assertEqual(black.read_cache(mode), {})
1353             src = (workspace / "test.py").resolve()
1354             with src.open("w") as fobj:
1355                 fobj.write("print('hello')")
1356             self.invokeBlack([str(src)])
1357             cache = black.read_cache(mode)
1358             self.assertIn(src, cache)
1359
1360     def test_cache_single_file_already_cached(self) -> None:
1361         mode = DEFAULT_MODE
1362         with cache_dir() as workspace:
1363             src = (workspace / "test.py").resolve()
1364             with src.open("w") as fobj:
1365                 fobj.write("print('hello')")
1366             black.write_cache({}, [src], mode)
1367             self.invokeBlack([str(src)])
1368             with src.open("r") as fobj:
1369                 self.assertEqual(fobj.read(), "print('hello')")
1370
1371     @event_loop()
1372     def test_cache_multiple_files(self) -> None:
1373         mode = DEFAULT_MODE
1374         with cache_dir() as workspace, patch(
1375             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1376         ):
1377             one = (workspace / "one.py").resolve()
1378             with one.open("w") as fobj:
1379                 fobj.write("print('hello')")
1380             two = (workspace / "two.py").resolve()
1381             with two.open("w") as fobj:
1382                 fobj.write("print('hello')")
1383             black.write_cache({}, [one], mode)
1384             self.invokeBlack([str(workspace)])
1385             with one.open("r") as fobj:
1386                 self.assertEqual(fobj.read(), "print('hello')")
1387             with two.open("r") as fobj:
1388                 self.assertEqual(fobj.read(), 'print("hello")\n')
1389             cache = black.read_cache(mode)
1390             self.assertIn(one, cache)
1391             self.assertIn(two, cache)
1392
1393     def test_no_cache_when_writeback_diff(self) -> None:
1394         mode = DEFAULT_MODE
1395         with cache_dir() as workspace:
1396             src = (workspace / "test.py").resolve()
1397             with src.open("w") as fobj:
1398                 fobj.write("print('hello')")
1399             with patch("black.read_cache") as read_cache, patch(
1400                 "black.write_cache"
1401             ) as write_cache:
1402                 self.invokeBlack([str(src), "--diff"])
1403                 cache_file = black.get_cache_file(mode)
1404                 self.assertFalse(cache_file.exists())
1405                 write_cache.assert_not_called()
1406                 read_cache.assert_not_called()
1407
1408     def test_no_cache_when_writeback_color_diff(self) -> None:
1409         mode = DEFAULT_MODE
1410         with cache_dir() as workspace:
1411             src = (workspace / "test.py").resolve()
1412             with src.open("w") as fobj:
1413                 fobj.write("print('hello')")
1414             with patch("black.read_cache") as read_cache, patch(
1415                 "black.write_cache"
1416             ) as write_cache:
1417                 self.invokeBlack([str(src), "--diff", "--color"])
1418                 cache_file = black.get_cache_file(mode)
1419                 self.assertFalse(cache_file.exists())
1420                 write_cache.assert_not_called()
1421                 read_cache.assert_not_called()
1422
1423     @event_loop()
1424     def test_output_locking_when_writeback_diff(self) -> None:
1425         with cache_dir() as workspace:
1426             for tag in range(0, 4):
1427                 src = (workspace / f"test{tag}.py").resolve()
1428                 with src.open("w") as fobj:
1429                     fobj.write("print('hello')")
1430             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1431                 self.invokeBlack(["--diff", str(workspace)], exit_code=0)
1432                 # this isn't quite doing what we want, but if it _isn't_
1433                 # called then we cannot be using the lock it provides
1434                 mgr.assert_called()
1435
1436     @event_loop()
1437     def test_output_locking_when_writeback_color_diff(self) -> None:
1438         with cache_dir() as workspace:
1439             for tag in range(0, 4):
1440                 src = (workspace / f"test{tag}.py").resolve()
1441                 with src.open("w") as fobj:
1442                     fobj.write("print('hello')")
1443             with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
1444                 self.invokeBlack(["--diff", "--color", str(workspace)], exit_code=0)
1445                 # this isn't quite doing what we want, but if it _isn't_
1446                 # called then we cannot be using the lock it provides
1447                 mgr.assert_called()
1448
1449     def test_no_cache_when_stdin(self) -> None:
1450         mode = DEFAULT_MODE
1451         with cache_dir():
1452             result = CliRunner().invoke(
1453                 black.main, ["-"], input=BytesIO(b"print('hello')")
1454             )
1455             self.assertEqual(result.exit_code, 0)
1456             cache_file = black.get_cache_file(mode)
1457             self.assertFalse(cache_file.exists())
1458
1459     def test_read_cache_no_cachefile(self) -> None:
1460         mode = DEFAULT_MODE
1461         with cache_dir():
1462             self.assertEqual(black.read_cache(mode), {})
1463
1464     def test_write_cache_read_cache(self) -> None:
1465         mode = DEFAULT_MODE
1466         with cache_dir() as workspace:
1467             src = (workspace / "test.py").resolve()
1468             src.touch()
1469             black.write_cache({}, [src], mode)
1470             cache = black.read_cache(mode)
1471             self.assertIn(src, cache)
1472             self.assertEqual(cache[src], black.get_cache_info(src))
1473
1474     def test_filter_cached(self) -> None:
1475         with TemporaryDirectory() as workspace:
1476             path = Path(workspace)
1477             uncached = (path / "uncached").resolve()
1478             cached = (path / "cached").resolve()
1479             cached_but_changed = (path / "changed").resolve()
1480             uncached.touch()
1481             cached.touch()
1482             cached_but_changed.touch()
1483             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
1484             todo, done = black.filter_cached(
1485                 cache, {uncached, cached, cached_but_changed}
1486             )
1487             self.assertEqual(todo, {uncached, cached_but_changed})
1488             self.assertEqual(done, {cached})
1489
1490     def test_write_cache_creates_directory_if_needed(self) -> None:
1491         mode = DEFAULT_MODE
1492         with cache_dir(exists=False) as workspace:
1493             self.assertFalse(workspace.exists())
1494             black.write_cache({}, [], mode)
1495             self.assertTrue(workspace.exists())
1496
1497     @event_loop()
1498     def test_failed_formatting_does_not_get_cached(self) -> None:
1499         mode = DEFAULT_MODE
1500         with cache_dir() as workspace, patch(
1501             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1502         ):
1503             failing = (workspace / "failing.py").resolve()
1504             with failing.open("w") as fobj:
1505                 fobj.write("not actually python")
1506             clean = (workspace / "clean.py").resolve()
1507             with clean.open("w") as fobj:
1508                 fobj.write('print("hello")\n')
1509             self.invokeBlack([str(workspace)], exit_code=123)
1510             cache = black.read_cache(mode)
1511             self.assertNotIn(failing, cache)
1512             self.assertIn(clean, cache)
1513
1514     def test_write_cache_write_fail(self) -> None:
1515         mode = DEFAULT_MODE
1516         with cache_dir(), patch.object(Path, "open") as mock:
1517             mock.side_effect = OSError
1518             black.write_cache({}, [], mode)
1519
1520     @event_loop()
1521     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
1522     def test_works_in_mono_process_only_environment(self) -> None:
1523         with cache_dir() as workspace:
1524             for f in [
1525                 (workspace / "one.py").resolve(),
1526                 (workspace / "two.py").resolve(),
1527             ]:
1528                 f.write_text('print("hello")\n')
1529             self.invokeBlack([str(workspace)])
1530
1531     @event_loop()
1532     def test_check_diff_use_together(self) -> None:
1533         with cache_dir():
1534             # Files which will be reformatted.
1535             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1536             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1537             # Files which will not be reformatted.
1538             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1539             self.invokeBlack([str(src2), "--diff", "--check"])
1540             # Multi file command.
1541             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1542
1543     def test_no_files(self) -> None:
1544         with cache_dir():
1545             # Without an argument, black exits with error code 0.
1546             self.invokeBlack([])
1547
1548     def test_broken_symlink(self) -> None:
1549         with cache_dir() as workspace:
1550             symlink = workspace / "broken_link.py"
1551             try:
1552                 symlink.symlink_to("nonexistent.py")
1553             except OSError as e:
1554                 self.skipTest(f"Can't create symlinks: {e}")
1555             self.invokeBlack([str(workspace.resolve())])
1556
1557     def test_read_cache_line_lengths(self) -> None:
1558         mode = DEFAULT_MODE
1559         short_mode = replace(DEFAULT_MODE, line_length=1)
1560         with cache_dir() as workspace:
1561             path = (workspace / "file.py").resolve()
1562             path.touch()
1563             black.write_cache({}, [path], mode)
1564             one = black.read_cache(mode)
1565             self.assertIn(path, one)
1566             two = black.read_cache(short_mode)
1567             self.assertNotIn(path, two)
1568
1569     def test_tricky_unicode_symbols(self) -> None:
1570         source, expected = read_data("tricky_unicode_symbols")
1571         actual = fs(source)
1572         self.assertFormatEqual(expected, actual)
1573         black.assert_equivalent(source, actual)
1574         black.assert_stable(source, actual, DEFAULT_MODE)
1575
1576     def test_single_file_force_pyi(self) -> None:
1577         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1578         contents, expected = read_data("force_pyi")
1579         with cache_dir() as workspace:
1580             path = (workspace / "file.py").resolve()
1581             with open(path, "w") as fh:
1582                 fh.write(contents)
1583             self.invokeBlack([str(path), "--pyi"])
1584             with open(path, "r") as fh:
1585                 actual = fh.read()
1586             # verify cache with --pyi is separate
1587             pyi_cache = black.read_cache(pyi_mode)
1588             self.assertIn(path, pyi_cache)
1589             normal_cache = black.read_cache(DEFAULT_MODE)
1590             self.assertNotIn(path, normal_cache)
1591         self.assertFormatEqual(expected, actual)
1592         black.assert_equivalent(contents, actual)
1593         black.assert_stable(contents, actual, pyi_mode)
1594
1595     @event_loop()
1596     def test_multi_file_force_pyi(self) -> None:
1597         reg_mode = DEFAULT_MODE
1598         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1599         contents, expected = read_data("force_pyi")
1600         with cache_dir() as workspace:
1601             paths = [
1602                 (workspace / "file1.py").resolve(),
1603                 (workspace / "file2.py").resolve(),
1604             ]
1605             for path in paths:
1606                 with open(path, "w") as fh:
1607                     fh.write(contents)
1608             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1609             for path in paths:
1610                 with open(path, "r") as fh:
1611                     actual = fh.read()
1612                 self.assertEqual(actual, expected)
1613             # verify cache with --pyi is separate
1614             pyi_cache = black.read_cache(pyi_mode)
1615             normal_cache = black.read_cache(reg_mode)
1616             for path in paths:
1617                 self.assertIn(path, pyi_cache)
1618                 self.assertNotIn(path, normal_cache)
1619
1620     def test_pipe_force_pyi(self) -> None:
1621         source, expected = read_data("force_pyi")
1622         result = CliRunner().invoke(
1623             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1624         )
1625         self.assertEqual(result.exit_code, 0)
1626         actual = result.output
1627         self.assertFormatEqual(actual, expected)
1628
1629     def test_single_file_force_py36(self) -> None:
1630         reg_mode = DEFAULT_MODE
1631         py36_mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
1632         source, expected = read_data("force_py36")
1633         with cache_dir() as workspace:
1634             path = (workspace / "file.py").resolve()
1635             with open(path, "w") as fh:
1636                 fh.write(source)
1637             self.invokeBlack([str(path), *PY36_ARGS])
1638             with open(path, "r") as fh:
1639                 actual = fh.read()
1640             # verify cache with --target-version is separate
1641             py36_cache = black.read_cache(py36_mode)
1642             self.assertIn(path, py36_cache)
1643             normal_cache = black.read_cache(reg_mode)
1644             self.assertNotIn(path, normal_cache)
1645         self.assertEqual(actual, expected)
1646
1647     @event_loop()
1648     def test_multi_file_force_py36(self) -> None:
1649         reg_mode = DEFAULT_MODE
1650         py36_mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
1651         source, expected = read_data("force_py36")
1652         with cache_dir() as workspace:
1653             paths = [
1654                 (workspace / "file1.py").resolve(),
1655                 (workspace / "file2.py").resolve(),
1656             ]
1657             for path in paths:
1658                 with open(path, "w") as fh:
1659                     fh.write(source)
1660             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1661             for path in paths:
1662                 with open(path, "r") as fh:
1663                     actual = fh.read()
1664                 self.assertEqual(actual, expected)
1665             # verify cache with --target-version is separate
1666             pyi_cache = black.read_cache(py36_mode)
1667             normal_cache = black.read_cache(reg_mode)
1668             for path in paths:
1669                 self.assertIn(path, pyi_cache)
1670                 self.assertNotIn(path, normal_cache)
1671
1672     def test_collections(self) -> None:
1673         source, expected = read_data("collections")
1674         actual = fs(source)
1675         self.assertFormatEqual(expected, actual)
1676         black.assert_equivalent(source, actual)
1677         black.assert_stable(source, actual, DEFAULT_MODE)
1678
1679     def test_pipe_force_py36(self) -> None:
1680         source, expected = read_data("force_py36")
1681         result = CliRunner().invoke(
1682             black.main,
1683             ["-", "-q", "--target-version=py36"],
1684             input=BytesIO(source.encode("utf8")),
1685         )
1686         self.assertEqual(result.exit_code, 0)
1687         actual = result.output
1688         self.assertFormatEqual(actual, expected)
1689
1690     def test_include_exclude(self) -> None:
1691         path = THIS_DIR / "data" / "include_exclude_tests"
1692         include = re.compile(r"\.pyi?$")
1693         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1694         report = black.Report()
1695         gitignore = PathSpec.from_lines("gitwildmatch", [])
1696         sources: List[Path] = []
1697         expected = [
1698             Path(path / "b/dont_exclude/a.py"),
1699             Path(path / "b/dont_exclude/a.pyi"),
1700         ]
1701         this_abs = THIS_DIR.resolve()
1702         sources.extend(
1703             black.gen_python_files(
1704                 path.iterdir(), this_abs, include, exclude, None, report, gitignore
1705             )
1706         )
1707         self.assertEqual(sorted(expected), sorted(sources))
1708
1709     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1710     def test_exclude_for_issue_1572(self) -> None:
1711         # Exclude shouldn't touch files that were explicitly given to Black through the
1712         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
1713         # https://github.com/psf/black/issues/1572
1714         path = THIS_DIR / "data" / "include_exclude_tests"
1715         include = ""
1716         exclude = r"/exclude/|a\.py"
1717         src = str(path / "b/exclude/a.py")
1718         report = black.Report()
1719         expected = [Path(path / "b/exclude/a.py")]
1720         sources = list(
1721             black.get_sources(
1722                 ctx=FakeContext(),
1723                 src=(src,),
1724                 quiet=True,
1725                 verbose=False,
1726                 include=include,
1727                 exclude=exclude,
1728                 force_exclude=None,
1729                 report=report,
1730             )
1731         )
1732         self.assertEqual(sorted(expected), sorted(sources))
1733
1734     def test_gitignore_exclude(self) -> None:
1735         path = THIS_DIR / "data" / "include_exclude_tests"
1736         include = re.compile(r"\.pyi?$")
1737         exclude = re.compile(r"")
1738         report = black.Report()
1739         gitignore = PathSpec.from_lines(
1740             "gitwildmatch", ["exclude/", ".definitely_exclude"]
1741         )
1742         sources: List[Path] = []
1743         expected = [
1744             Path(path / "b/dont_exclude/a.py"),
1745             Path(path / "b/dont_exclude/a.pyi"),
1746         ]
1747         this_abs = THIS_DIR.resolve()
1748         sources.extend(
1749             black.gen_python_files(
1750                 path.iterdir(), this_abs, include, exclude, None, report, gitignore
1751             )
1752         )
1753         self.assertEqual(sorted(expected), sorted(sources))
1754
1755     def test_empty_include(self) -> None:
1756         path = THIS_DIR / "data" / "include_exclude_tests"
1757         report = black.Report()
1758         gitignore = PathSpec.from_lines("gitwildmatch", [])
1759         empty = re.compile(r"")
1760         sources: List[Path] = []
1761         expected = [
1762             Path(path / "b/exclude/a.pie"),
1763             Path(path / "b/exclude/a.py"),
1764             Path(path / "b/exclude/a.pyi"),
1765             Path(path / "b/dont_exclude/a.pie"),
1766             Path(path / "b/dont_exclude/a.py"),
1767             Path(path / "b/dont_exclude/a.pyi"),
1768             Path(path / "b/.definitely_exclude/a.pie"),
1769             Path(path / "b/.definitely_exclude/a.py"),
1770             Path(path / "b/.definitely_exclude/a.pyi"),
1771         ]
1772         this_abs = THIS_DIR.resolve()
1773         sources.extend(
1774             black.gen_python_files(
1775                 path.iterdir(),
1776                 this_abs,
1777                 empty,
1778                 re.compile(black.DEFAULT_EXCLUDES),
1779                 None,
1780                 report,
1781                 gitignore,
1782             )
1783         )
1784         self.assertEqual(sorted(expected), sorted(sources))
1785
1786     def test_empty_exclude(self) -> None:
1787         path = THIS_DIR / "data" / "include_exclude_tests"
1788         report = black.Report()
1789         gitignore = PathSpec.from_lines("gitwildmatch", [])
1790         empty = re.compile(r"")
1791         sources: List[Path] = []
1792         expected = [
1793             Path(path / "b/dont_exclude/a.py"),
1794             Path(path / "b/dont_exclude/a.pyi"),
1795             Path(path / "b/exclude/a.py"),
1796             Path(path / "b/exclude/a.pyi"),
1797             Path(path / "b/.definitely_exclude/a.py"),
1798             Path(path / "b/.definitely_exclude/a.pyi"),
1799         ]
1800         this_abs = THIS_DIR.resolve()
1801         sources.extend(
1802             black.gen_python_files(
1803                 path.iterdir(),
1804                 this_abs,
1805                 re.compile(black.DEFAULT_INCLUDES),
1806                 empty,
1807                 None,
1808                 report,
1809                 gitignore,
1810             )
1811         )
1812         self.assertEqual(sorted(expected), sorted(sources))
1813
1814     def test_invalid_include_exclude(self) -> None:
1815         for option in ["--include", "--exclude"]:
1816             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1817
1818     def test_preserves_line_endings(self) -> None:
1819         with TemporaryDirectory() as workspace:
1820             test_file = Path(workspace) / "test.py"
1821             for nl in ["\n", "\r\n"]:
1822                 contents = nl.join(["def f(  ):", "    pass"])
1823                 test_file.write_bytes(contents.encode())
1824                 ff(test_file, write_back=black.WriteBack.YES)
1825                 updated_contents: bytes = test_file.read_bytes()
1826                 self.assertIn(nl.encode(), updated_contents)
1827                 if nl == "\n":
1828                     self.assertNotIn(b"\r\n", updated_contents)
1829
1830     def test_preserves_line_endings_via_stdin(self) -> None:
1831         for nl in ["\n", "\r\n"]:
1832             contents = nl.join(["def f(  ):", "    pass"])
1833             runner = BlackRunner()
1834             result = runner.invoke(
1835                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1836             )
1837             self.assertEqual(result.exit_code, 0)
1838             output = runner.stdout_bytes
1839             self.assertIn(nl.encode("utf8"), output)
1840             if nl == "\n":
1841                 self.assertNotIn(b"\r\n", output)
1842
1843     def test_assert_equivalent_different_asts(self) -> None:
1844         with self.assertRaises(AssertionError):
1845             black.assert_equivalent("{}", "None")
1846
1847     def test_symlink_out_of_root_directory(self) -> None:
1848         path = MagicMock()
1849         root = THIS_DIR.resolve()
1850         child = MagicMock()
1851         include = re.compile(black.DEFAULT_INCLUDES)
1852         exclude = re.compile(black.DEFAULT_EXCLUDES)
1853         report = black.Report()
1854         gitignore = PathSpec.from_lines("gitwildmatch", [])
1855         # `child` should behave like a symlink which resolved path is clearly
1856         # outside of the `root` directory.
1857         path.iterdir.return_value = [child]
1858         child.resolve.return_value = Path("/a/b/c")
1859         child.as_posix.return_value = "/a/b/c"
1860         child.is_symlink.return_value = True
1861         try:
1862             list(
1863                 black.gen_python_files(
1864                     path.iterdir(), root, include, exclude, None, report, gitignore
1865                 )
1866             )
1867         except ValueError as ve:
1868             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1869         path.iterdir.assert_called_once()
1870         child.resolve.assert_called_once()
1871         child.is_symlink.assert_called_once()
1872         # `child` should behave like a strange file which resolved path is clearly
1873         # outside of the `root` directory.
1874         child.is_symlink.return_value = False
1875         with self.assertRaises(ValueError):
1876             list(
1877                 black.gen_python_files(
1878                     path.iterdir(), root, include, exclude, None, report, gitignore
1879                 )
1880             )
1881         path.iterdir.assert_called()
1882         self.assertEqual(path.iterdir.call_count, 2)
1883         child.resolve.assert_called()
1884         self.assertEqual(child.resolve.call_count, 2)
1885         child.is_symlink.assert_called()
1886         self.assertEqual(child.is_symlink.call_count, 2)
1887
1888     def test_shhh_click(self) -> None:
1889         try:
1890             from click import _unicodefun  # type: ignore
1891         except ModuleNotFoundError:
1892             self.skipTest("Incompatible Click version")
1893         if not hasattr(_unicodefun, "_verify_python3_env"):
1894             self.skipTest("Incompatible Click version")
1895         # First, let's see if Click is crashing with a preferred ASCII charset.
1896         with patch("locale.getpreferredencoding") as gpe:
1897             gpe.return_value = "ASCII"
1898             with self.assertRaises(RuntimeError):
1899                 _unicodefun._verify_python3_env()
1900         # Now, let's silence Click...
1901         black.patch_click()
1902         # ...and confirm it's silent.
1903         with patch("locale.getpreferredencoding") as gpe:
1904             gpe.return_value = "ASCII"
1905             try:
1906                 _unicodefun._verify_python3_env()
1907             except RuntimeError as re:
1908                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1909
1910     def test_root_logger_not_used_directly(self) -> None:
1911         def fail(*args: Any, **kwargs: Any) -> None:
1912             self.fail("Record created with root logger")
1913
1914         with patch.multiple(
1915             logging.root,
1916             debug=fail,
1917             info=fail,
1918             warning=fail,
1919             error=fail,
1920             critical=fail,
1921             log=fail,
1922         ):
1923             ff(THIS_FILE)
1924
1925     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1926     def test_blackd_main(self) -> None:
1927         with patch("blackd.web.run_app"):
1928             result = CliRunner().invoke(blackd.main, [])
1929             if result.exception is not None:
1930                 raise result.exception
1931             self.assertEqual(result.exit_code, 0)
1932
1933     def test_invalid_config_return_code(self) -> None:
1934         tmp_file = Path(black.dump_to_file())
1935         try:
1936             tmp_config = Path(black.dump_to_file())
1937             tmp_config.unlink()
1938             args = ["--config", str(tmp_config), str(tmp_file)]
1939             self.invokeBlack(args, exit_code=2, ignore_config=False)
1940         finally:
1941             tmp_file.unlink()
1942
1943     def test_parse_pyproject_toml(self) -> None:
1944         test_toml_file = THIS_DIR / "test.toml"
1945         config = black.parse_pyproject_toml(str(test_toml_file))
1946         self.assertEqual(config["verbose"], 1)
1947         self.assertEqual(config["check"], "no")
1948         self.assertEqual(config["diff"], "y")
1949         self.assertEqual(config["color"], True)
1950         self.assertEqual(config["line_length"], 79)
1951         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1952         self.assertEqual(config["exclude"], r"\.pyi?$")
1953         self.assertEqual(config["include"], r"\.py?$")
1954
1955     def test_read_pyproject_toml(self) -> None:
1956         test_toml_file = THIS_DIR / "test.toml"
1957         fake_ctx = FakeContext()
1958         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1959         config = fake_ctx.default_map
1960         self.assertEqual(config["verbose"], "1")
1961         self.assertEqual(config["check"], "no")
1962         self.assertEqual(config["diff"], "y")
1963         self.assertEqual(config["color"], "True")
1964         self.assertEqual(config["line_length"], "79")
1965         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1966         self.assertEqual(config["exclude"], r"\.pyi?$")
1967         self.assertEqual(config["include"], r"\.py?$")
1968
1969     def test_find_project_root(self) -> None:
1970         with TemporaryDirectory() as workspace:
1971             root = Path(workspace)
1972             test_dir = root / "test"
1973             test_dir.mkdir()
1974
1975             src_dir = root / "src"
1976             src_dir.mkdir()
1977
1978             root_pyproject = root / "pyproject.toml"
1979             root_pyproject.touch()
1980             src_pyproject = src_dir / "pyproject.toml"
1981             src_pyproject.touch()
1982             src_python = src_dir / "foo.py"
1983             src_python.touch()
1984
1985             self.assertEqual(
1986                 black.find_project_root((src_dir, test_dir)), root.resolve()
1987             )
1988             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
1989             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
1990
1991     def test_bpo_33660_workaround(self) -> None:
1992         if system() == "Windows":
1993             return
1994
1995         # https://bugs.python.org/issue33660
1996
1997         old_cwd = Path.cwd()
1998         try:
1999             root = Path("/")
2000             os.chdir(str(root))
2001             path = Path("workspace") / "project"
2002             report = black.Report(verbose=True)
2003             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
2004             self.assertEqual(normalized_path, "workspace/project")
2005         finally:
2006             os.chdir(str(old_cwd))
2007
2008
2009 class BlackDTestCase(AioHTTPTestCase):
2010     async def get_application(self) -> web.Application:
2011         return blackd.make_app()
2012
2013     # TODO: remove these decorators once the below is released
2014     # https://github.com/aio-libs/aiohttp/pull/3727
2015     @skip_if_exception("ClientOSError")
2016     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
2017     @unittest_run_loop
2018     async def test_blackd_request_needs_formatting(self) -> None:
2019         response = await self.client.post("/", data=b"print('hello world')")
2020         self.assertEqual(response.status, 200)
2021         self.assertEqual(response.charset, "utf8")
2022         self.assertEqual(await response.read(), b'print("hello world")\n')
2023
2024     @skip_if_exception("ClientOSError")
2025     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
2026     @unittest_run_loop
2027     async def test_blackd_request_no_change(self) -> None:
2028         response = await self.client.post("/", data=b'print("hello world")\n')
2029         self.assertEqual(response.status, 204)
2030         self.assertEqual(await response.read(), b"")
2031
2032     @skip_if_exception("ClientOSError")
2033     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
2034     @unittest_run_loop
2035     async def test_blackd_request_syntax_error(self) -> None:
2036         response = await self.client.post("/", data=b"what even ( is")
2037         self.assertEqual(response.status, 400)
2038         content = await response.text()
2039         self.assertTrue(
2040             content.startswith("Cannot parse"),
2041             msg=f"Expected error to start with 'Cannot parse', got {repr(content)}",
2042         )
2043
2044     @skip_if_exception("ClientOSError")
2045     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
2046     @unittest_run_loop
2047     async def test_blackd_unsupported_version(self) -> None:
2048         response = await self.client.post(
2049             "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "2"}
2050         )
2051         self.assertEqual(response.status, 501)
2052
2053     @skip_if_exception("ClientOSError")
2054     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
2055     @unittest_run_loop
2056     async def test_blackd_supported_version(self) -> None:
2057         response = await self.client.post(
2058             "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "1"}
2059         )
2060         self.assertEqual(response.status, 200)
2061
2062     @skip_if_exception("ClientOSError")
2063     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
2064     @unittest_run_loop
2065     async def test_blackd_invalid_python_variant(self) -> None:
2066         async def check(header_value: str, expected_status: int = 400) -> None:
2067             response = await self.client.post(
2068                 "/", data=b"what", headers={blackd.PYTHON_VARIANT_HEADER: header_value}
2069             )
2070             self.assertEqual(response.status, expected_status)
2071
2072         await check("lol")
2073         await check("ruby3.5")
2074         await check("pyi3.6")
2075         await check("py1.5")
2076         await check("2.8")
2077         await check("py2.8")
2078         await check("3.0")
2079         await check("pypy3.0")
2080         await check("jython3.4")
2081
2082     @skip_if_exception("ClientOSError")
2083     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
2084     @unittest_run_loop
2085     async def test_blackd_pyi(self) -> None:
2086         source, expected = read_data("stub.pyi")
2087         response = await self.client.post(
2088             "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
2089         )
2090         self.assertEqual(response.status, 200)
2091         self.assertEqual(await response.text(), expected)
2092
2093     @skip_if_exception("ClientOSError")
2094     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
2095     @unittest_run_loop
2096     async def test_blackd_diff(self) -> None:
2097         diff_header = re.compile(
2098             r"(In|Out)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
2099         )
2100
2101         source, _ = read_data("blackd_diff.py")
2102         expected, _ = read_data("blackd_diff.diff")
2103
2104         response = await self.client.post(
2105             "/", data=source, headers={blackd.DIFF_HEADER: "true"}
2106         )
2107         self.assertEqual(response.status, 200)
2108
2109         actual = await response.text()
2110         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
2111         self.assertEqual(actual, expected)
2112
2113     @skip_if_exception("ClientOSError")
2114     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
2115     @unittest_run_loop
2116     async def test_blackd_python_variant(self) -> None:
2117         code = (
2118             "def f(\n"
2119             "    and_has_a_bunch_of,\n"
2120             "    very_long_arguments_too,\n"
2121             "    and_lots_of_them_as_well_lol,\n"
2122             "    **and_very_long_keyword_arguments\n"
2123             "):\n"
2124             "    pass\n"
2125         )
2126
2127         async def check(header_value: str, expected_status: int) -> None:
2128             response = await self.client.post(
2129                 "/", data=code, headers={blackd.PYTHON_VARIANT_HEADER: header_value}
2130             )
2131             self.assertEqual(
2132                 response.status, expected_status, msg=await response.text()
2133             )
2134
2135         await check("3.6", 200)
2136         await check("py3.6", 200)
2137         await check("3.6,3.7", 200)
2138         await check("3.6,py3.7", 200)
2139         await check("py36,py37", 200)
2140         await check("36", 200)
2141         await check("3.6.4", 200)
2142
2143         await check("2", 204)
2144         await check("2.7", 204)
2145         await check("py2.7", 204)
2146         await check("3.4", 204)
2147         await check("py3.4", 204)
2148         await check("py34,py36", 204)
2149         await check("34", 204)
2150
2151     @skip_if_exception("ClientOSError")
2152     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
2153     @unittest_run_loop
2154     async def test_blackd_line_length(self) -> None:
2155         response = await self.client.post(
2156             "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "7"}
2157         )
2158         self.assertEqual(response.status, 200)
2159
2160     @skip_if_exception("ClientOSError")
2161     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
2162     @unittest_run_loop
2163     async def test_blackd_invalid_line_length(self) -> None:
2164         response = await self.client.post(
2165             "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "NaN"}
2166         )
2167         self.assertEqual(response.status, 400)
2168
2169     @skip_if_exception("ClientOSError")
2170     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
2171     @unittest_run_loop
2172     async def test_blackd_response_black_version_header(self) -> None:
2173         response = await self.client.post("/")
2174         self.assertIsNotNone(response.headers.get(blackd.BLACK_VERSION_HEADER))
2175
2176
# Source lines of black's main module, read once at import time; `tracefunc`
# below uses them to look up the signature line of each traced call.
with open(black.__file__, "r", encoding="utf-8") as _bf:
    black_source_lines = list(_bf)
2179
2180
def tracefunc(frame: types.FrameType, event: str, arg: Any) -> Callable:
    """Show function calls from `black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging. For each
    "call" event whose code lives in `black/__init__.py`, print the line
    number of the call and the matching source line from `black_source_lines`.
    Each line is indented according to the current stack depth.

    Frames from any other file are ignored before any lookup is done. Their
    line numbers do not index `black_source_lines`, and an IndexError raised
    here would uninstall the tracer and propagate into the traced code.

    Always returns itself, so it is also used as the local trace function
    for every entered scope.
    """
    if event != "call":
        return tracefunc

    filename = frame.f_code.co_filename
    # Normalize separators so the check also matches Windows paths.
    if "black/__init__.py" not in Path(filename).as_posix():
        return tracefunc

    # inspect.stack() is expensive, so it is only computed for black's frames.
    # The 19 is a fixed offset subtracted from the raw depth
    # (NOTE(review): presumably tuned to the test runner's own frames).
    stack = (len(inspect.stack()) - 19) * 2
    lineno = frame.f_lineno
    func_sig_lineno = lineno - 1
    funcname = black_source_lines[func_sig_lineno].strip()
    # The reported line may be a decorator; advance to the line after them.
    while funcname.startswith("@"):
        func_sig_lineno += 1
        funcname = black_source_lines[func_sig_lineno].strip()
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc
2201
2202
if __name__ == "__main__":
    # Allow running this file directly. unittest loads the tests from the
    # module imported as "test_black" rather than from `__main__`.
    # NOTE(review): that import must resolve, e.g. with tests/ on sys.path.
    unittest.main(module="test_black")