git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

fix 1631 and add test (#1641)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 import logging
4 from concurrent.futures import ThreadPoolExecutor
5 from contextlib import contextmanager
6 from dataclasses import replace
7 from functools import partial
8 import inspect
9 from io import BytesIO, TextIOWrapper
10 import os
11 from pathlib import Path
12 from platform import system
13 import regex as re
14 import sys
15 from tempfile import TemporaryDirectory
16 import types
17 from typing import (
18     Any,
19     BinaryIO,
20     Callable,
21     Dict,
22     Generator,
23     List,
24     Tuple,
25     Iterator,
26     TypeVar,
27 )
28 import unittest
29 from unittest.mock import patch, MagicMock
30
31 import click
32 from click import unstyle
33 from click.testing import CliRunner
34
35 import black
36 from black import Feature, TargetVersion
37
38 try:
39     import blackd
40     from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop
41     from aiohttp import web
42 except ImportError:
43     has_blackd_deps = False
44 else:
45     has_blackd_deps = True
46
47 from pathspec import PathSpec
48
49 # Import other test classes
50 from .test_primer import PrimerCLITests  # noqa: F401
51
52
# Mode shared by most tests: experimental string processing is switched on.
DEFAULT_MODE = black.FileMode(experimental_string_processing=True)
# Shorthands bound to DEFAULT_MODE: ``ff`` formats a file in place (``fast=True``)
# and ``fs`` formats a string.
ff = partial(black.format_file_in_place, fast=True, mode=DEFAULT_MODE)
fs = partial(black.format_str, mode=DEFAULT_MODE)
# Filesystem anchors: this file, its directory, and that directory's parent.
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
PROJECT_ROOT = THIS_FILE.parent.parent
# Placeholder substituted for the timestamped headers of ``--diff`` output.
DETERMINISTIC_HEADER = "[Deterministic header]"
# Marker stripped from every line of test data by ``read_data``. It is built from
# two pieces, presumably so the full marker never appears literally in this file
# (which ``test_self`` also reads through ``read_data``).
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
# ``--target-version`` CLI flags for every version in ``black.PY36_VERSIONS``.
PY36_ARGS = [f"--target-version={v.name.lower()}" for v in black.PY36_VERSIONS]
# Generic type variables.
T = TypeVar("T")
R = TypeVar("R")
66
67
def dump_to_stderr(*output: str) -> str:
    """Join *output* with newlines and wrap the result in one newline on each side.

    Tests patch ``black.dump_to_file`` with this function. Despite the name,
    nothing is written to stderr here; the text is simply returned.
    """
    newline = "\n"
    return f"{newline}{newline.join(output)}{newline}"
70
71
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """Load a test case and return its ``(input, output)`` source pair.

    A ``.py`` suffix is appended unless *name* already ends in ``.py``, ``.pyi``,
    ``.out`` or ``.diff``. With *data* true the file is looked up in the ``data``
    directory next to this file; otherwise it is resolved against PROJECT_ROOT.

    Every ``EMPTY_LINE`` marker is removed from each line first. Lines before a
    ``# output`` marker line form the input and lines after it the expected
    output; marker lines themselves are dropped. A file with input but no output
    lines is treated as already formatted, so both halves are then equal. Each
    half is stripped and ends with exactly one newline.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    root = THIS_DIR / "data" if data else PROJECT_ROOT
    with open(root / name, "r", encoding="utf8") as handle:
        raw_lines = handle.readlines()

    src_lines: List[str] = []
    dst_lines: List[str] = []
    seen_marker = False
    for raw in raw_lines:
        cleaned = raw.replace(EMPTY_LINE, "")
        if cleaned.rstrip() == "# output":
            seen_marker = True
            continue
        (dst_lines if seen_marker else src_lines).append(cleaned)

    if src_lines and not dst_lines:
        dst_lines = src_lines[:]
    return "".join(src_lines).strip() + "\n", "".join(dst_lines).strip() + "\n"
93
94
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point ``black.CACHE_DIR`` at a throwaway directory and yield that path.

    With *exists* false, the yielded path is a ``new`` subdirectory of the
    temporary directory that nothing here creates. The temporary directory and
    the patch are both undone on exit.
    """
    with TemporaryDirectory() as workspace:
        base = Path(workspace)
        target = base if exists else base / "new"
        with patch("black.CACHE_DIR", target):
            yield target
103
104
@contextmanager
def event_loop() -> Iterator[None]:
    """Install a fresh asyncio event loop as the current loop for the block.

    The loop comes from the current event loop policy and is closed on exit, even
    if the body raises. It is not unset afterwards, so the closed loop stays
    registered as the current one.
    """
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        fresh_loop.close()
115
116
@contextmanager
def skip_if_exception(e: str) -> Iterator[None]:
    """Skip the current test if the block raises an exception of class name *e*.

    The match is on ``exc.__class__.__name__`` only, so the exception type does not
    need to be importable here. Any other exception is re-raised unchanged.
    """
    try:
        yield
    except Exception as exc:
        if exc.__class__.__name__ != e:
            raise
        # ``unittest.skip`` merely builds a decorator, so calling it here would
        # silently swallow the exception; raising ``SkipTest`` marks the test skipped.
        raise unittest.SkipTest(
            f"Encountered expected exception {exc}, skipping"
        ) from exc
126
127
class FakeContext(click.Context):
    """Stand-in ``click.Context`` for calling functions that require one.

    Click's own constructor is never run; the only attribute set up is an empty
    ``default_map``.
    """

    def __init__(self) -> None:
        # Deliberately skip super().__init__(): no real command is available.
        self.default_map: Dict[str, Any] = dict()
133
134
class FakeParameter(click.Parameter):
    """Stand-in ``click.Parameter`` for calling functions that require one."""

    def __init__(self) -> None:
        """Do nothing, so that ``click.Parameter.__init__`` is bypassed entirely."""
140
141
class BlackRunner(CliRunner):
    """A ``CliRunner`` whose stderr is captured separately instead of merged.

    After each invocation the raw bytes written to stdout and stderr are stored in
    ``stdout_bytes`` and ``stderr_bytes``. This is a hack that can be removed once
    we depend on Click 7.x.
    """

    def __init__(self) -> None:
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        with super().isolation(*args, **kwargs) as output:
            # Swap stderr before entering ``try`` so ``finally`` only runs once our
            # wrapper is installed and the original stream has been saved.
            hold_stderr = sys.stderr
            sys.stderr = TextIOWrapper(self.stderrbuf, encoding=self.charset)
            try:
                yield output
            finally:
                try:
                    # Text wrappers buffer internally; flush them so everything
                    # written so far reaches the byte buffers before reading them.
                    sys.stdout.flush()
                    sys.stderr.flush()
                    self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                    self.stderr_bytes = sys.stderr.buffer.getvalue()  # type: ignore
                finally:
                    # Always restore the real stderr, even if capturing failed.
                    sys.stderr = hold_stderr
165
166
167 class BlackTestCase(unittest.TestCase):
168     maxDiff = None
169     _diffThreshold = 2 ** 20
170
171     def assertFormatEqual(self, expected: str, actual: str) -> None:
172         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
173             bdv: black.DebugVisitor[Any]
174             black.out("Expected tree:", fg="green")
175             try:
176                 exp_node = black.lib2to3_parse(expected)
177                 bdv = black.DebugVisitor()
178                 list(bdv.visit(exp_node))
179             except Exception as ve:
180                 black.err(str(ve))
181             black.out("Actual tree:", fg="red")
182             try:
183                 exp_node = black.lib2to3_parse(actual)
184                 bdv = black.DebugVisitor()
185                 list(bdv.visit(exp_node))
186             except Exception as ve:
187                 black.err(str(ve))
188         self.assertMultiLineEqual(expected, actual)
189
190     def invokeBlack(
191         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
192     ) -> None:
193         runner = BlackRunner()
194         if ignore_config:
195             args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
196         result = runner.invoke(black.main, args)
197         self.assertEqual(
198             result.exit_code,
199             exit_code,
200             msg=(
201                 f"Failed with args: {args}\n"
202                 f"stdout: {runner.stdout_bytes.decode()!r}\n"
203                 f"stderr: {runner.stderr_bytes.decode()!r}\n"
204                 f"exception: {result.exception}"
205             ),
206         )
207
208     @patch("black.dump_to_file", dump_to_stderr)
209     def checkSourceFile(self, name: str, mode: black.FileMode = DEFAULT_MODE) -> None:
210         path = THIS_DIR.parent / name
211         source, expected = read_data(str(path), data=False)
212         actual = fs(source, mode=mode)
213         self.assertFormatEqual(expected, actual)
214         black.assert_equivalent(source, actual)
215         black.assert_stable(source, actual, mode)
216         self.assertFalse(ff(path))
217
218     @patch("black.dump_to_file", dump_to_stderr)
219     def test_empty(self) -> None:
220         source = expected = ""
221         actual = fs(source)
222         self.assertFormatEqual(expected, actual)
223         black.assert_equivalent(source, actual)
224         black.assert_stable(source, actual, DEFAULT_MODE)
225
226     def test_empty_ff(self) -> None:
227         expected = ""
228         tmp_file = Path(black.dump_to_file())
229         try:
230             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
231             with open(tmp_file, encoding="utf8") as f:
232                 actual = f.read()
233         finally:
234             os.unlink(tmp_file)
235         self.assertFormatEqual(expected, actual)
236
237     def test_self(self) -> None:
238         self.checkSourceFile("tests/test_black.py")
239
240     def test_black(self) -> None:
241         self.checkSourceFile("src/black/__init__.py")
242
243     def test_pygram(self) -> None:
244         self.checkSourceFile("src/blib2to3/pygram.py")
245
246     def test_pytree(self) -> None:
247         self.checkSourceFile("src/blib2to3/pytree.py")
248
249     def test_conv(self) -> None:
250         self.checkSourceFile("src/blib2to3/pgen2/conv.py")
251
252     def test_driver(self) -> None:
253         self.checkSourceFile("src/blib2to3/pgen2/driver.py")
254
255     def test_grammar(self) -> None:
256         self.checkSourceFile("src/blib2to3/pgen2/grammar.py")
257
258     def test_literals(self) -> None:
259         self.checkSourceFile("src/blib2to3/pgen2/literals.py")
260
261     def test_parse(self) -> None:
262         self.checkSourceFile("src/blib2to3/pgen2/parse.py")
263
264     def test_pgen(self) -> None:
265         self.checkSourceFile("src/blib2to3/pgen2/pgen.py")
266
267     def test_tokenize(self) -> None:
268         self.checkSourceFile("src/blib2to3/pgen2/tokenize.py")
269
270     def test_token(self) -> None:
271         self.checkSourceFile("src/blib2to3/pgen2/token.py")
272
273     def test_setup(self) -> None:
274         self.checkSourceFile("setup.py")
275
276     def test_piping(self) -> None:
277         source, expected = read_data("src/black/__init__", data=False)
278         result = BlackRunner().invoke(
279             black.main,
280             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
281             input=BytesIO(source.encode("utf8")),
282         )
283         self.assertEqual(result.exit_code, 0)
284         self.assertFormatEqual(expected, result.output)
285         black.assert_equivalent(source, result.output)
286         black.assert_stable(source, result.output, DEFAULT_MODE)
287
288     def test_piping_diff(self) -> None:
289         diff_header = re.compile(
290             r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
291             r"\+\d\d\d\d"
292         )
293         source, _ = read_data("expression.py")
294         expected, _ = read_data("expression.diff")
295         config = THIS_DIR / "data" / "empty_pyproject.toml"
296         args = [
297             "-",
298             "--fast",
299             f"--line-length={black.DEFAULT_LINE_LENGTH}",
300             "--diff",
301             f"--config={config}",
302         ]
303         result = BlackRunner().invoke(
304             black.main, args, input=BytesIO(source.encode("utf8"))
305         )
306         self.assertEqual(result.exit_code, 0)
307         actual = diff_header.sub(DETERMINISTIC_HEADER, result.output)
308         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
309         self.assertEqual(expected, actual)
310
311     def test_piping_diff_with_color(self) -> None:
312         source, _ = read_data("expression.py")
313         config = THIS_DIR / "data" / "empty_pyproject.toml"
314         args = [
315             "-",
316             "--fast",
317             f"--line-length={black.DEFAULT_LINE_LENGTH}",
318             "--diff",
319             "--color",
320             f"--config={config}",
321         ]
322         result = BlackRunner().invoke(
323             black.main, args, input=BytesIO(source.encode("utf8"))
324         )
325         actual = result.output
326         # Again, the contents are checked in a different test, so only look for colors.
327         self.assertIn("\033[1;37m", actual)
328         self.assertIn("\033[36m", actual)
329         self.assertIn("\033[32m", actual)
330         self.assertIn("\033[31m", actual)
331         self.assertIn("\033[0m", actual)
332
333     @patch("black.dump_to_file", dump_to_stderr)
334     def test_function(self) -> None:
335         source, expected = read_data("function")
336         actual = fs(source)
337         self.assertFormatEqual(expected, actual)
338         black.assert_equivalent(source, actual)
339         black.assert_stable(source, actual, DEFAULT_MODE)
340
341     @patch("black.dump_to_file", dump_to_stderr)
342     def test_function2(self) -> None:
343         source, expected = read_data("function2")
344         actual = fs(source)
345         self.assertFormatEqual(expected, actual)
346         black.assert_equivalent(source, actual)
347         black.assert_stable(source, actual, DEFAULT_MODE)
348
349     @patch("black.dump_to_file", dump_to_stderr)
350     def _test_wip(self) -> None:
351         source, expected = read_data("wip")
352         sys.settrace(tracefunc)
353         mode = replace(
354             DEFAULT_MODE,
355             experimental_string_processing=False,
356             target_versions={black.TargetVersion.PY38},
357         )
358         actual = fs(source, mode=mode)
359         sys.settrace(None)
360         self.assertFormatEqual(expected, actual)
361         black.assert_equivalent(source, actual)
362         black.assert_stable(source, actual, black.FileMode())
363
364     @patch("black.dump_to_file", dump_to_stderr)
365     def test_function_trailing_comma(self) -> None:
366         source, expected = read_data("function_trailing_comma")
367         actual = fs(source)
368         self.assertFormatEqual(expected, actual)
369         black.assert_equivalent(source, actual)
370         black.assert_stable(source, actual, DEFAULT_MODE)
371
372     @unittest.expectedFailure
373     @patch("black.dump_to_file", dump_to_stderr)
374     def test_trailing_comma_optional_parens_stability1(self) -> None:
375         source, _expected = read_data("trailing_comma_optional_parens1")
376         actual = fs(source)
377         black.assert_stable(source, actual, DEFAULT_MODE)
378
379     @unittest.expectedFailure
380     @patch("black.dump_to_file", dump_to_stderr)
381     def test_trailing_comma_optional_parens_stability2(self) -> None:
382         source, _expected = read_data("trailing_comma_optional_parens2")
383         actual = fs(source)
384         black.assert_stable(source, actual, DEFAULT_MODE)
385
386     @unittest.expectedFailure
387     @patch("black.dump_to_file", dump_to_stderr)
388     def test_trailing_comma_optional_parens_stability3(self) -> None:
389         source, _expected = read_data("trailing_comma_optional_parens3")
390         actual = fs(source)
391         black.assert_stable(source, actual, DEFAULT_MODE)
392
393     @patch("black.dump_to_file", dump_to_stderr)
394     def test_expression(self) -> None:
395         source, expected = read_data("expression")
396         actual = fs(source)
397         self.assertFormatEqual(expected, actual)
398         black.assert_equivalent(source, actual)
399         black.assert_stable(source, actual, DEFAULT_MODE)
400
401     @patch("black.dump_to_file", dump_to_stderr)
402     def test_pep_572(self) -> None:
403         source, expected = read_data("pep_572")
404         actual = fs(source)
405         self.assertFormatEqual(expected, actual)
406         black.assert_stable(source, actual, DEFAULT_MODE)
407         if sys.version_info >= (3, 8):
408             black.assert_equivalent(source, actual)
409
410     def test_pep_572_version_detection(self) -> None:
411         source, _ = read_data("pep_572")
412         root = black.lib2to3_parse(source)
413         features = black.get_features_used(root)
414         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
415         versions = black.detect_target_versions(root)
416         self.assertIn(black.TargetVersion.PY38, versions)
417
418     def test_expression_ff(self) -> None:
419         source, expected = read_data("expression")
420         tmp_file = Path(black.dump_to_file(source))
421         try:
422             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
423             with open(tmp_file, encoding="utf8") as f:
424                 actual = f.read()
425         finally:
426             os.unlink(tmp_file)
427         self.assertFormatEqual(expected, actual)
428         with patch("black.dump_to_file", dump_to_stderr):
429             black.assert_equivalent(source, actual)
430             black.assert_stable(source, actual, DEFAULT_MODE)
431
432     def test_expression_diff(self) -> None:
433         source, _ = read_data("expression.py")
434         expected, _ = read_data("expression.diff")
435         tmp_file = Path(black.dump_to_file(source))
436         diff_header = re.compile(
437             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
438             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
439         )
440         try:
441             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
442             self.assertEqual(result.exit_code, 0)
443         finally:
444             os.unlink(tmp_file)
445         actual = result.output
446         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
447         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
448         if expected != actual:
449             dump = black.dump_to_file(actual)
450             msg = (
451                 "Expected diff isn't equal to the actual. If you made changes to"
452                 " expression.py and this is an anticipated difference, overwrite"
453                 f" tests/data/expression.diff with {dump}"
454             )
455             self.assertEqual(expected, actual, msg)
456
457     def test_expression_diff_with_color(self) -> None:
458         source, _ = read_data("expression.py")
459         expected, _ = read_data("expression.diff")
460         tmp_file = Path(black.dump_to_file(source))
461         try:
462             result = BlackRunner().invoke(
463                 black.main, ["--diff", "--color", str(tmp_file)]
464             )
465         finally:
466             os.unlink(tmp_file)
467         actual = result.output
468         # We check the contents of the diff in `test_expression_diff`. All
469         # we need to check here is that color codes exist in the result.
470         self.assertIn("\033[1;37m", actual)
471         self.assertIn("\033[36m", actual)
472         self.assertIn("\033[32m", actual)
473         self.assertIn("\033[31m", actual)
474         self.assertIn("\033[0m", actual)
475
476     @patch("black.dump_to_file", dump_to_stderr)
477     def test_fstring(self) -> None:
478         source, expected = read_data("fstring")
479         actual = fs(source)
480         self.assertFormatEqual(expected, actual)
481         black.assert_equivalent(source, actual)
482         black.assert_stable(source, actual, DEFAULT_MODE)
483
484     @patch("black.dump_to_file", dump_to_stderr)
485     def test_pep_570(self) -> None:
486         source, expected = read_data("pep_570")
487         actual = fs(source)
488         self.assertFormatEqual(expected, actual)
489         black.assert_stable(source, actual, DEFAULT_MODE)
490         if sys.version_info >= (3, 8):
491             black.assert_equivalent(source, actual)
492
493     def test_detect_pos_only_arguments(self) -> None:
494         source, _ = read_data("pep_570")
495         root = black.lib2to3_parse(source)
496         features = black.get_features_used(root)
497         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
498         versions = black.detect_target_versions(root)
499         self.assertIn(black.TargetVersion.PY38, versions)
500
501     @patch("black.dump_to_file", dump_to_stderr)
502     def test_string_quotes(self) -> None:
503         source, expected = read_data("string_quotes")
504         actual = fs(source)
505         self.assertFormatEqual(expected, actual)
506         black.assert_equivalent(source, actual)
507         black.assert_stable(source, actual, DEFAULT_MODE)
508         mode = replace(DEFAULT_MODE, string_normalization=False)
509         not_normalized = fs(source, mode=mode)
510         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
511         black.assert_equivalent(source, not_normalized)
512         black.assert_stable(source, not_normalized, mode=mode)
513
514     @patch("black.dump_to_file", dump_to_stderr)
515     def test_docstring(self) -> None:
516         source, expected = read_data("docstring")
517         actual = fs(source)
518         self.assertFormatEqual(expected, actual)
519         black.assert_equivalent(source, actual)
520         black.assert_stable(source, actual, DEFAULT_MODE)
521
522     @patch("black.dump_to_file", dump_to_stderr)
523     def test_docstring_no_string_normalization(self) -> None:
524         """Like test_docstring but with string normalization off."""
525         source, expected = read_data("docstring_no_string_normalization")
526         mode = replace(DEFAULT_MODE, string_normalization=False)
527         actual = fs(source, mode=mode)
528         self.assertFormatEqual(expected, actual)
529         black.assert_equivalent(source, actual)
530         black.assert_stable(source, actual, mode)
531
532     def test_long_strings(self) -> None:
533         """Tests for splitting long strings."""
534         source, expected = read_data("long_strings")
535         actual = fs(source)
536         self.assertFormatEqual(expected, actual)
537         black.assert_equivalent(source, actual)
538         black.assert_stable(source, actual, DEFAULT_MODE)
539
540     def test_long_strings_flag_disabled(self) -> None:
541         """Tests for turning off the string processing logic."""
542         source, expected = read_data("long_strings_flag_disabled")
543         mode = replace(DEFAULT_MODE, experimental_string_processing=False)
544         actual = fs(source, mode=mode)
545         self.assertFormatEqual(expected, actual)
546         black.assert_stable(expected, actual, mode)
547
548     @patch("black.dump_to_file", dump_to_stderr)
549     def test_long_strings__edge_case(self) -> None:
550         """Edge-case tests for splitting long strings."""
551         source, expected = read_data("long_strings__edge_case")
552         actual = fs(source)
553         self.assertFormatEqual(expected, actual)
554         black.assert_equivalent(source, actual)
555         black.assert_stable(source, actual, DEFAULT_MODE)
556
557     @patch("black.dump_to_file", dump_to_stderr)
558     def test_long_strings__regression(self) -> None:
559         """Regression tests for splitting long strings."""
560         source, expected = read_data("long_strings__regression")
561         actual = fs(source)
562         self.assertFormatEqual(expected, actual)
563         black.assert_equivalent(source, actual)
564         black.assert_stable(source, actual, DEFAULT_MODE)
565
566     @patch("black.dump_to_file", dump_to_stderr)
567     def test_slices(self) -> None:
568         source, expected = read_data("slices")
569         actual = fs(source)
570         self.assertFormatEqual(expected, actual)
571         black.assert_equivalent(source, actual)
572         black.assert_stable(source, actual, DEFAULT_MODE)
573
574     @patch("black.dump_to_file", dump_to_stderr)
575     def test_percent_precedence(self) -> None:
576         source, expected = read_data("percent_precedence")
577         actual = fs(source)
578         self.assertFormatEqual(expected, actual)
579         black.assert_equivalent(source, actual)
580         black.assert_stable(source, actual, DEFAULT_MODE)
581
582     @patch("black.dump_to_file", dump_to_stderr)
583     def test_comments(self) -> None:
584         source, expected = read_data("comments")
585         actual = fs(source)
586         self.assertFormatEqual(expected, actual)
587         black.assert_equivalent(source, actual)
588         black.assert_stable(source, actual, DEFAULT_MODE)
589
590     @patch("black.dump_to_file", dump_to_stderr)
591     def test_comments2(self) -> None:
592         source, expected = read_data("comments2")
593         actual = fs(source)
594         self.assertFormatEqual(expected, actual)
595         black.assert_equivalent(source, actual)
596         black.assert_stable(source, actual, DEFAULT_MODE)
597
598     @patch("black.dump_to_file", dump_to_stderr)
599     def test_comments3(self) -> None:
600         source, expected = read_data("comments3")
601         actual = fs(source)
602         self.assertFormatEqual(expected, actual)
603         black.assert_equivalent(source, actual)
604         black.assert_stable(source, actual, DEFAULT_MODE)
605
606     @patch("black.dump_to_file", dump_to_stderr)
607     def test_comments4(self) -> None:
608         source, expected = read_data("comments4")
609         actual = fs(source)
610         self.assertFormatEqual(expected, actual)
611         black.assert_equivalent(source, actual)
612         black.assert_stable(source, actual, DEFAULT_MODE)
613
614     @patch("black.dump_to_file", dump_to_stderr)
615     def test_comments5(self) -> None:
616         source, expected = read_data("comments5")
617         actual = fs(source)
618         self.assertFormatEqual(expected, actual)
619         black.assert_equivalent(source, actual)
620         black.assert_stable(source, actual, DEFAULT_MODE)
621
622     @patch("black.dump_to_file", dump_to_stderr)
623     def test_comments6(self) -> None:
624         source, expected = read_data("comments6")
625         actual = fs(source)
626         self.assertFormatEqual(expected, actual)
627         black.assert_equivalent(source, actual)
628         black.assert_stable(source, actual, DEFAULT_MODE)
629
630     @patch("black.dump_to_file", dump_to_stderr)
631     def test_comments7(self) -> None:
632         source, expected = read_data("comments7")
633         mode = replace(DEFAULT_MODE, target_versions={black.TargetVersion.PY38})
634         actual = fs(source, mode=mode)
635         self.assertFormatEqual(expected, actual)
636         black.assert_equivalent(source, actual)
637         black.assert_stable(source, actual, DEFAULT_MODE)
638
639     @patch("black.dump_to_file", dump_to_stderr)
640     def test_comment_after_escaped_newline(self) -> None:
641         source, expected = read_data("comment_after_escaped_newline")
642         actual = fs(source)
643         self.assertFormatEqual(expected, actual)
644         black.assert_equivalent(source, actual)
645         black.assert_stable(source, actual, DEFAULT_MODE)
646
647     @patch("black.dump_to_file", dump_to_stderr)
648     def test_cantfit(self) -> None:
649         source, expected = read_data("cantfit")
650         actual = fs(source)
651         self.assertFormatEqual(expected, actual)
652         black.assert_equivalent(source, actual)
653         black.assert_stable(source, actual, DEFAULT_MODE)
654
655     @patch("black.dump_to_file", dump_to_stderr)
656     def test_import_spacing(self) -> None:
657         source, expected = read_data("import_spacing")
658         actual = fs(source)
659         self.assertFormatEqual(expected, actual)
660         black.assert_equivalent(source, actual)
661         black.assert_stable(source, actual, DEFAULT_MODE)
662
663     @patch("black.dump_to_file", dump_to_stderr)
664     def test_composition(self) -> None:
665         source, expected = read_data("composition")
666         actual = fs(source)
667         self.assertFormatEqual(expected, actual)
668         black.assert_equivalent(source, actual)
669         black.assert_stable(source, actual, DEFAULT_MODE)
670
671     @patch("black.dump_to_file", dump_to_stderr)
672     def test_composition_no_trailing_comma(self) -> None:
673         source, expected = read_data("composition_no_trailing_comma")
674         mode = replace(DEFAULT_MODE, target_versions={black.TargetVersion.PY38})
675         actual = fs(source, mode=mode)
676         self.assertFormatEqual(expected, actual)
677         black.assert_equivalent(source, actual)
678         black.assert_stable(source, actual, DEFAULT_MODE)
679
680     @patch("black.dump_to_file", dump_to_stderr)
681     def test_empty_lines(self) -> None:
682         source, expected = read_data("empty_lines")
683         actual = fs(source)
684         self.assertFormatEqual(expected, actual)
685         black.assert_equivalent(source, actual)
686         black.assert_stable(source, actual, DEFAULT_MODE)
687
688     @patch("black.dump_to_file", dump_to_stderr)
689     def test_remove_parens(self) -> None:
690         source, expected = read_data("remove_parens")
691         actual = fs(source)
692         self.assertFormatEqual(expected, actual)
693         black.assert_equivalent(source, actual)
694         black.assert_stable(source, actual, DEFAULT_MODE)
695
696     @patch("black.dump_to_file", dump_to_stderr)
697     def test_string_prefixes(self) -> None:
698         source, expected = read_data("string_prefixes")
699         actual = fs(source)
700         self.assertFormatEqual(expected, actual)
701         black.assert_equivalent(source, actual)
702         black.assert_stable(source, actual, DEFAULT_MODE)
703
704     @patch("black.dump_to_file", dump_to_stderr)
705     def test_numeric_literals(self) -> None:
706         source, expected = read_data("numeric_literals")
707         mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
708         actual = fs(source, mode=mode)
709         self.assertFormatEqual(expected, actual)
710         black.assert_equivalent(source, actual)
711         black.assert_stable(source, actual, mode)
712
713     @patch("black.dump_to_file", dump_to_stderr)
714     def test_numeric_literals_ignoring_underscores(self) -> None:
715         source, expected = read_data("numeric_literals_skip_underscores")
716         mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
717         actual = fs(source, mode=mode)
718         self.assertFormatEqual(expected, actual)
719         black.assert_equivalent(source, actual)
720         black.assert_stable(source, actual, mode)
721
722     @patch("black.dump_to_file", dump_to_stderr)
723     def test_numeric_literals_py2(self) -> None:
724         source, expected = read_data("numeric_literals_py2")
725         actual = fs(source)
726         self.assertFormatEqual(expected, actual)
727         black.assert_stable(source, actual, DEFAULT_MODE)
728
729     @patch("black.dump_to_file", dump_to_stderr)
730     def test_python2(self) -> None:
731         source, expected = read_data("python2")
732         actual = fs(source)
733         self.assertFormatEqual(expected, actual)
734         black.assert_equivalent(source, actual)
735         black.assert_stable(source, actual, DEFAULT_MODE)
736
737     @patch("black.dump_to_file", dump_to_stderr)
738     def test_python2_print_function(self) -> None:
739         source, expected = read_data("python2_print_function")
740         mode = replace(DEFAULT_MODE, target_versions={TargetVersion.PY27})
741         actual = fs(source, mode=mode)
742         self.assertFormatEqual(expected, actual)
743         black.assert_equivalent(source, actual)
744         black.assert_stable(source, actual, mode)
745
746     @patch("black.dump_to_file", dump_to_stderr)
747     def test_python2_unicode_literals(self) -> None:
748         source, expected = read_data("python2_unicode_literals")
749         actual = fs(source)
750         self.assertFormatEqual(expected, actual)
751         black.assert_equivalent(source, actual)
752         black.assert_stable(source, actual, DEFAULT_MODE)
753
754     @patch("black.dump_to_file", dump_to_stderr)
755     def test_stub(self) -> None:
756         mode = replace(DEFAULT_MODE, is_pyi=True)
757         source, expected = read_data("stub.pyi")
758         actual = fs(source, mode=mode)
759         self.assertFormatEqual(expected, actual)
760         black.assert_stable(source, actual, mode)
761
762     @patch("black.dump_to_file", dump_to_stderr)
763     def test_async_as_identifier(self) -> None:
764         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
765         source, expected = read_data("async_as_identifier")
766         actual = fs(source)
767         self.assertFormatEqual(expected, actual)
768         major, minor = sys.version_info[:2]
769         if major < 3 or (major <= 3 and minor < 7):
770             black.assert_equivalent(source, actual)
771         black.assert_stable(source, actual, DEFAULT_MODE)
772         # ensure black can parse this when the target is 3.6
773         self.invokeBlack([str(source_path), "--target-version", "py36"])
774         # but not on 3.7, because async/await is no longer an identifier
775         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
776
777     @patch("black.dump_to_file", dump_to_stderr)
778     def test_python37(self) -> None:
779         source_path = (THIS_DIR / "data" / "python37.py").resolve()
780         source, expected = read_data("python37")
781         actual = fs(source)
782         self.assertFormatEqual(expected, actual)
783         major, minor = sys.version_info[:2]
784         if major > 3 or (major == 3 and minor >= 7):
785             black.assert_equivalent(source, actual)
786         black.assert_stable(source, actual, DEFAULT_MODE)
787         # ensure black can parse this when the target is 3.7
788         self.invokeBlack([str(source_path), "--target-version", "py37"])
789         # but not on 3.6, because we use async as a reserved keyword
790         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
791
792     @patch("black.dump_to_file", dump_to_stderr)
793     def test_python38(self) -> None:
794         source, expected = read_data("python38")
795         actual = fs(source)
796         self.assertFormatEqual(expected, actual)
797         major, minor = sys.version_info[:2]
798         if major > 3 or (major == 3 and minor >= 8):
799             black.assert_equivalent(source, actual)
800         black.assert_stable(source, actual, DEFAULT_MODE)
801
802     @patch("black.dump_to_file", dump_to_stderr)
803     def test_fmtonoff(self) -> None:
804         source, expected = read_data("fmtonoff")
805         actual = fs(source)
806         self.assertFormatEqual(expected, actual)
807         black.assert_equivalent(source, actual)
808         black.assert_stable(source, actual, DEFAULT_MODE)
809
810     @patch("black.dump_to_file", dump_to_stderr)
811     def test_fmtonoff2(self) -> None:
812         source, expected = read_data("fmtonoff2")
813         actual = fs(source)
814         self.assertFormatEqual(expected, actual)
815         black.assert_equivalent(source, actual)
816         black.assert_stable(source, actual, DEFAULT_MODE)
817
818     @patch("black.dump_to_file", dump_to_stderr)
819     def test_fmtonoff3(self) -> None:
820         source, expected = read_data("fmtonoff3")
821         actual = fs(source)
822         self.assertFormatEqual(expected, actual)
823         black.assert_equivalent(source, actual)
824         black.assert_stable(source, actual, DEFAULT_MODE)
825
826     @patch("black.dump_to_file", dump_to_stderr)
827     def test_fmtonoff4(self) -> None:
828         source, expected = read_data("fmtonoff4")
829         actual = fs(source)
830         self.assertFormatEqual(expected, actual)
831         black.assert_equivalent(source, actual)
832         black.assert_stable(source, actual, DEFAULT_MODE)
833
834     @patch("black.dump_to_file", dump_to_stderr)
835     def test_remove_empty_parentheses_after_class(self) -> None:
836         source, expected = read_data("class_blank_parentheses")
837         actual = fs(source)
838         self.assertFormatEqual(expected, actual)
839         black.assert_equivalent(source, actual)
840         black.assert_stable(source, actual, DEFAULT_MODE)
841
842     @patch("black.dump_to_file", dump_to_stderr)
843     def test_new_line_between_class_and_code(self) -> None:
844         source, expected = read_data("class_methods_new_line")
845         actual = fs(source)
846         self.assertFormatEqual(expected, actual)
847         black.assert_equivalent(source, actual)
848         black.assert_stable(source, actual, DEFAULT_MODE)
849
850     @patch("black.dump_to_file", dump_to_stderr)
851     def test_bracket_match(self) -> None:
852         source, expected = read_data("bracketmatch")
853         actual = fs(source)
854         self.assertFormatEqual(expected, actual)
855         black.assert_equivalent(source, actual)
856         black.assert_stable(source, actual, DEFAULT_MODE)
857
858     @patch("black.dump_to_file", dump_to_stderr)
859     def test_tuple_assign(self) -> None:
860         source, expected = read_data("tupleassign")
861         actual = fs(source)
862         self.assertFormatEqual(expected, actual)
863         black.assert_equivalent(source, actual)
864         black.assert_stable(source, actual, DEFAULT_MODE)
865
866     @patch("black.dump_to_file", dump_to_stderr)
867     def test_beginning_backslash(self) -> None:
868         source, expected = read_data("beginning_backslash")
869         actual = fs(source)
870         self.assertFormatEqual(expected, actual)
871         black.assert_equivalent(source, actual)
872         black.assert_stable(source, actual, DEFAULT_MODE)
873
874     def test_tab_comment_indentation(self) -> None:
875         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
876         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
877         self.assertFormatEqual(contents_spc, fs(contents_spc))
878         self.assertFormatEqual(contents_spc, fs(contents_tab))
879
880         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
881         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
882         self.assertFormatEqual(contents_spc, fs(contents_spc))
883         self.assertFormatEqual(contents_spc, fs(contents_tab))
884
885         # mixed tabs and spaces (valid Python 2 code)
886         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
887         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
888         self.assertFormatEqual(contents_spc, fs(contents_spc))
889         self.assertFormatEqual(contents_spc, fs(contents_tab))
890
891         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
892         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
893         self.assertFormatEqual(contents_spc, fs(contents_spc))
894         self.assertFormatEqual(contents_spc, fs(contents_tab))
895
896     def test_report_verbose(self) -> None:
897         report = black.Report(verbose=True)
898         out_lines = []
899         err_lines = []
900
901         def out(msg: str, **kwargs: Any) -> None:
902             out_lines.append(msg)
903
904         def err(msg: str, **kwargs: Any) -> None:
905             err_lines.append(msg)
906
907         with patch("black.out", out), patch("black.err", err):
908             report.done(Path("f1"), black.Changed.NO)
909             self.assertEqual(len(out_lines), 1)
910             self.assertEqual(len(err_lines), 0)
911             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
912             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
913             self.assertEqual(report.return_code, 0)
914             report.done(Path("f2"), black.Changed.YES)
915             self.assertEqual(len(out_lines), 2)
916             self.assertEqual(len(err_lines), 0)
917             self.assertEqual(out_lines[-1], "reformatted f2")
918             self.assertEqual(
919                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
920             )
921             report.done(Path("f3"), black.Changed.CACHED)
922             self.assertEqual(len(out_lines), 3)
923             self.assertEqual(len(err_lines), 0)
924             self.assertEqual(
925                 out_lines[-1], "f3 wasn't modified on disk since last run."
926             )
927             self.assertEqual(
928                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
929             )
930             self.assertEqual(report.return_code, 0)
931             report.check = True
932             self.assertEqual(report.return_code, 1)
933             report.check = False
934             report.failed(Path("e1"), "boom")
935             self.assertEqual(len(out_lines), 3)
936             self.assertEqual(len(err_lines), 1)
937             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
938             self.assertEqual(
939                 unstyle(str(report)),
940                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
941                 " reformat.",
942             )
943             self.assertEqual(report.return_code, 123)
944             report.done(Path("f3"), black.Changed.YES)
945             self.assertEqual(len(out_lines), 4)
946             self.assertEqual(len(err_lines), 1)
947             self.assertEqual(out_lines[-1], "reformatted f3")
948             self.assertEqual(
949                 unstyle(str(report)),
950                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
951                 " reformat.",
952             )
953             self.assertEqual(report.return_code, 123)
954             report.failed(Path("e2"), "boom")
955             self.assertEqual(len(out_lines), 4)
956             self.assertEqual(len(err_lines), 2)
957             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
958             self.assertEqual(
959                 unstyle(str(report)),
960                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
961                 " reformat.",
962             )
963             self.assertEqual(report.return_code, 123)
964             report.path_ignored(Path("wat"), "no match")
965             self.assertEqual(len(out_lines), 5)
966             self.assertEqual(len(err_lines), 2)
967             self.assertEqual(out_lines[-1], "wat ignored: no match")
968             self.assertEqual(
969                 unstyle(str(report)),
970                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
971                 " reformat.",
972             )
973             self.assertEqual(report.return_code, 123)
974             report.done(Path("f4"), black.Changed.NO)
975             self.assertEqual(len(out_lines), 6)
976             self.assertEqual(len(err_lines), 2)
977             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
978             self.assertEqual(
979                 unstyle(str(report)),
980                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
981                 " reformat.",
982             )
983             self.assertEqual(report.return_code, 123)
984             report.check = True
985             self.assertEqual(
986                 unstyle(str(report)),
987                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
988                 " would fail to reformat.",
989             )
990             report.check = False
991             report.diff = True
992             self.assertEqual(
993                 unstyle(str(report)),
994                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
995                 " would fail to reformat.",
996             )
997
998     def test_report_quiet(self) -> None:
999         report = black.Report(quiet=True)
1000         out_lines = []
1001         err_lines = []
1002
1003         def out(msg: str, **kwargs: Any) -> None:
1004             out_lines.append(msg)
1005
1006         def err(msg: str, **kwargs: Any) -> None:
1007             err_lines.append(msg)
1008
1009         with patch("black.out", out), patch("black.err", err):
1010             report.done(Path("f1"), black.Changed.NO)
1011             self.assertEqual(len(out_lines), 0)
1012             self.assertEqual(len(err_lines), 0)
1013             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
1014             self.assertEqual(report.return_code, 0)
1015             report.done(Path("f2"), black.Changed.YES)
1016             self.assertEqual(len(out_lines), 0)
1017             self.assertEqual(len(err_lines), 0)
1018             self.assertEqual(
1019                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
1020             )
1021             report.done(Path("f3"), black.Changed.CACHED)
1022             self.assertEqual(len(out_lines), 0)
1023             self.assertEqual(len(err_lines), 0)
1024             self.assertEqual(
1025                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
1026             )
1027             self.assertEqual(report.return_code, 0)
1028             report.check = True
1029             self.assertEqual(report.return_code, 1)
1030             report.check = False
1031             report.failed(Path("e1"), "boom")
1032             self.assertEqual(len(out_lines), 0)
1033             self.assertEqual(len(err_lines), 1)
1034             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
1035             self.assertEqual(
1036                 unstyle(str(report)),
1037                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
1038                 " reformat.",
1039             )
1040             self.assertEqual(report.return_code, 123)
1041             report.done(Path("f3"), black.Changed.YES)
1042             self.assertEqual(len(out_lines), 0)
1043             self.assertEqual(len(err_lines), 1)
1044             self.assertEqual(
1045                 unstyle(str(report)),
1046                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
1047                 " reformat.",
1048             )
1049             self.assertEqual(report.return_code, 123)
1050             report.failed(Path("e2"), "boom")
1051             self.assertEqual(len(out_lines), 0)
1052             self.assertEqual(len(err_lines), 2)
1053             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
1054             self.assertEqual(
1055                 unstyle(str(report)),
1056                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
1057                 " reformat.",
1058             )
1059             self.assertEqual(report.return_code, 123)
1060             report.path_ignored(Path("wat"), "no match")
1061             self.assertEqual(len(out_lines), 0)
1062             self.assertEqual(len(err_lines), 2)
1063             self.assertEqual(
1064                 unstyle(str(report)),
1065                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
1066                 " reformat.",
1067             )
1068             self.assertEqual(report.return_code, 123)
1069             report.done(Path("f4"), black.Changed.NO)
1070             self.assertEqual(len(out_lines), 0)
1071             self.assertEqual(len(err_lines), 2)
1072             self.assertEqual(
1073                 unstyle(str(report)),
1074                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
1075                 " reformat.",
1076             )
1077             self.assertEqual(report.return_code, 123)
1078             report.check = True
1079             self.assertEqual(
1080                 unstyle(str(report)),
1081                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
1082                 " would fail to reformat.",
1083             )
1084             report.check = False
1085             report.diff = True
1086             self.assertEqual(
1087                 unstyle(str(report)),
1088                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
1089                 " would fail to reformat.",
1090             )
1091
1092     def test_report_normal(self) -> None:
1093         report = black.Report()
1094         out_lines = []
1095         err_lines = []
1096
1097         def out(msg: str, **kwargs: Any) -> None:
1098             out_lines.append(msg)
1099
1100         def err(msg: str, **kwargs: Any) -> None:
1101             err_lines.append(msg)
1102
1103         with patch("black.out", out), patch("black.err", err):
1104             report.done(Path("f1"), black.Changed.NO)
1105             self.assertEqual(len(out_lines), 0)
1106             self.assertEqual(len(err_lines), 0)
1107             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
1108             self.assertEqual(report.return_code, 0)
1109             report.done(Path("f2"), black.Changed.YES)
1110             self.assertEqual(len(out_lines), 1)
1111             self.assertEqual(len(err_lines), 0)
1112             self.assertEqual(out_lines[-1], "reformatted f2")
1113             self.assertEqual(
1114                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
1115             )
1116             report.done(Path("f3"), black.Changed.CACHED)
1117             self.assertEqual(len(out_lines), 1)
1118             self.assertEqual(len(err_lines), 0)
1119             self.assertEqual(out_lines[-1], "reformatted f2")
1120             self.assertEqual(
1121                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
1122             )
1123             self.assertEqual(report.return_code, 0)
1124             report.check = True
1125             self.assertEqual(report.return_code, 1)
1126             report.check = False
1127             report.failed(Path("e1"), "boom")
1128             self.assertEqual(len(out_lines), 1)
1129             self.assertEqual(len(err_lines), 1)
1130             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
1131             self.assertEqual(
1132                 unstyle(str(report)),
1133                 "1 file reformatted, 2 files left unchanged, 1 file failed to"
1134                 " reformat.",
1135             )
1136             self.assertEqual(report.return_code, 123)
1137             report.done(Path("f3"), black.Changed.YES)
1138             self.assertEqual(len(out_lines), 2)
1139             self.assertEqual(len(err_lines), 1)
1140             self.assertEqual(out_lines[-1], "reformatted f3")
1141             self.assertEqual(
1142                 unstyle(str(report)),
1143                 "2 files reformatted, 2 files left unchanged, 1 file failed to"
1144                 " reformat.",
1145             )
1146             self.assertEqual(report.return_code, 123)
1147             report.failed(Path("e2"), "boom")
1148             self.assertEqual(len(out_lines), 2)
1149             self.assertEqual(len(err_lines), 2)
1150             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
1151             self.assertEqual(
1152                 unstyle(str(report)),
1153                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
1154                 " reformat.",
1155             )
1156             self.assertEqual(report.return_code, 123)
1157             report.path_ignored(Path("wat"), "no match")
1158             self.assertEqual(len(out_lines), 2)
1159             self.assertEqual(len(err_lines), 2)
1160             self.assertEqual(
1161                 unstyle(str(report)),
1162                 "2 files reformatted, 2 files left unchanged, 2 files failed to"
1163                 " reformat.",
1164             )
1165             self.assertEqual(report.return_code, 123)
1166             report.done(Path("f4"), black.Changed.NO)
1167             self.assertEqual(len(out_lines), 2)
1168             self.assertEqual(len(err_lines), 2)
1169             self.assertEqual(
1170                 unstyle(str(report)),
1171                 "2 files reformatted, 3 files left unchanged, 2 files failed to"
1172                 " reformat.",
1173             )
1174             self.assertEqual(report.return_code, 123)
1175             report.check = True
1176             self.assertEqual(
1177                 unstyle(str(report)),
1178                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
1179                 " would fail to reformat.",
1180             )
1181             report.check = False
1182             report.diff = True
1183             self.assertEqual(
1184                 unstyle(str(report)),
1185                 "2 files would be reformatted, 3 files would be left unchanged, 2 files"
1186                 " would fail to reformat.",
1187             )
1188
1189     def test_lib2to3_parse(self) -> None:
1190         with self.assertRaises(black.InvalidInput):
1191             black.lib2to3_parse("invalid syntax")
1192
1193         straddling = "x + y"
1194         black.lib2to3_parse(straddling)
1195         black.lib2to3_parse(straddling, {TargetVersion.PY27})
1196         black.lib2to3_parse(straddling, {TargetVersion.PY36})
1197         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
1198
1199         py2_only = "print x"
1200         black.lib2to3_parse(py2_only)
1201         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
1202         with self.assertRaises(black.InvalidInput):
1203             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
1204         with self.assertRaises(black.InvalidInput):
1205             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
1206
1207         py3_only = "exec(x, end=y)"
1208         black.lib2to3_parse(py3_only)
1209         with self.assertRaises(black.InvalidInput):
1210             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
1211         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
1212         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
1213
1214     def test_get_features_used(self) -> None:
1215         node = black.lib2to3_parse("def f(*, arg): ...\n")
1216         self.assertEqual(black.get_features_used(node), set())
1217         node = black.lib2to3_parse("def f(*, arg,): ...\n")
1218         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
1219         node = black.lib2to3_parse("f(*arg,)\n")
1220         self.assertEqual(
1221             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
1222         )
1223         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
1224         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
1225         node = black.lib2to3_parse("123_456\n")
1226         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
1227         node = black.lib2to3_parse("123456\n")
1228         self.assertEqual(black.get_features_used(node), set())
1229         source, expected = read_data("function")
1230         node = black.lib2to3_parse(source)
1231         expected_features = {
1232             Feature.TRAILING_COMMA_IN_CALL,
1233             Feature.TRAILING_COMMA_IN_DEF,
1234             Feature.F_STRINGS,
1235         }
1236         self.assertEqual(black.get_features_used(node), expected_features)
1237         node = black.lib2to3_parse(expected)
1238         self.assertEqual(black.get_features_used(node), expected_features)
1239         source, expected = read_data("expression")
1240         node = black.lib2to3_parse(source)
1241         self.assertEqual(black.get_features_used(node), set())
1242         node = black.lib2to3_parse(expected)
1243         self.assertEqual(black.get_features_used(node), set())
1244
1245     def test_get_future_imports(self) -> None:
1246         node = black.lib2to3_parse("\n")
1247         self.assertEqual(set(), black.get_future_imports(node))
1248         node = black.lib2to3_parse("from __future__ import black\n")
1249         self.assertEqual({"black"}, black.get_future_imports(node))
1250         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
1251         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
1252         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
1253         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
1254         node = black.lib2to3_parse(
1255             "from __future__ import multiple\nfrom __future__ import imports\n"
1256         )
1257         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
1258         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
1259         self.assertEqual({"black"}, black.get_future_imports(node))
1260         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
1261         self.assertEqual({"black"}, black.get_future_imports(node))
1262         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
1263         self.assertEqual(set(), black.get_future_imports(node))
1264         node = black.lib2to3_parse("from some.module import black\n")
1265         self.assertEqual(set(), black.get_future_imports(node))
1266         node = black.lib2to3_parse(
1267             "from __future__ import unicode_literals as _unicode_literals"
1268         )
1269         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
1270         node = black.lib2to3_parse(
1271             "from __future__ import unicode_literals as _lol, print"
1272         )
1273         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
1274
1275     def test_debug_visitor(self) -> None:
1276         source, _ = read_data("debug_visitor.py")
1277         expected, _ = read_data("debug_visitor.out")
1278         out_lines = []
1279         err_lines = []
1280
1281         def out(msg: str, **kwargs: Any) -> None:
1282             out_lines.append(msg)
1283
1284         def err(msg: str, **kwargs: Any) -> None:
1285             err_lines.append(msg)
1286
1287         with patch("black.out", out), patch("black.err", err):
1288             black.DebugVisitor.show(source)
1289         actual = "\n".join(out_lines) + "\n"
1290         log_name = ""
1291         if expected != actual:
1292             log_name = black.dump_to_file(*out_lines)
1293         self.assertEqual(
1294             expected,
1295             actual,
1296             f"AST print out is different. Actual version dumped to {log_name}",
1297         )
1298
1299     def test_format_file_contents(self) -> None:
1300         empty = ""
1301         mode = DEFAULT_MODE
1302         with self.assertRaises(black.NothingChanged):
1303             black.format_file_contents(empty, mode=mode, fast=False)
1304         just_nl = "\n"
1305         with self.assertRaises(black.NothingChanged):
1306             black.format_file_contents(just_nl, mode=mode, fast=False)
1307         same = "j = [1, 2, 3]\n"
1308         with self.assertRaises(black.NothingChanged):
1309             black.format_file_contents(same, mode=mode, fast=False)
1310         different = "j = [1,2,3]"
1311         expected = same
1312         actual = black.format_file_contents(different, mode=mode, fast=False)
1313         self.assertEqual(expected, actual)
1314         invalid = "return if you can"
1315         with self.assertRaises(black.InvalidInput) as e:
1316             black.format_file_contents(invalid, mode=mode, fast=False)
1317         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1318
1319     def test_endmarker(self) -> None:
1320         n = black.lib2to3_parse("\n")
1321         self.assertEqual(n.type, black.syms.file_input)
1322         self.assertEqual(len(n.children), 1)
1323         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1324
1325     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1326     def test_assertFormatEqual(self) -> None:
1327         out_lines = []
1328         err_lines = []
1329
1330         def out(msg: str, **kwargs: Any) -> None:
1331             out_lines.append(msg)
1332
1333         def err(msg: str, **kwargs: Any) -> None:
1334             err_lines.append(msg)
1335
1336         with patch("black.out", out), patch("black.err", err):
1337             with self.assertRaises(AssertionError):
1338                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1339
1340         out_str = "".join(out_lines)
1341         self.assertTrue("Expected tree:" in out_str)
1342         self.assertTrue("Actual tree:" in out_str)
1343         self.assertEqual("".join(err_lines), "")
1344
1345     def test_cache_broken_file(self) -> None:
1346         mode = DEFAULT_MODE
1347         with cache_dir() as workspace:
1348             cache_file = black.get_cache_file(mode)
1349             with cache_file.open("w") as fobj:
1350                 fobj.write("this is not a pickle")
1351             self.assertEqual(black.read_cache(mode), {})
1352             src = (workspace / "test.py").resolve()
1353             with src.open("w") as fobj:
1354                 fobj.write("print('hello')")
1355             self.invokeBlack([str(src)])
1356             cache = black.read_cache(mode)
1357             self.assertIn(src, cache)
1358
1359     def test_cache_single_file_already_cached(self) -> None:
1360         mode = DEFAULT_MODE
1361         with cache_dir() as workspace:
1362             src = (workspace / "test.py").resolve()
1363             with src.open("w") as fobj:
1364                 fobj.write("print('hello')")
1365             black.write_cache({}, [src], mode)
1366             self.invokeBlack([str(src)])
1367             with src.open("r") as fobj:
1368                 self.assertEqual(fobj.read(), "print('hello')")
1369
1370     @event_loop()
1371     def test_cache_multiple_files(self) -> None:
1372         mode = DEFAULT_MODE
1373         with cache_dir() as workspace, patch(
1374             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1375         ):
1376             one = (workspace / "one.py").resolve()
1377             with one.open("w") as fobj:
1378                 fobj.write("print('hello')")
1379             two = (workspace / "two.py").resolve()
1380             with two.open("w") as fobj:
1381                 fobj.write("print('hello')")
1382             black.write_cache({}, [one], mode)
1383             self.invokeBlack([str(workspace)])
1384             with one.open("r") as fobj:
1385                 self.assertEqual(fobj.read(), "print('hello')")
1386             with two.open("r") as fobj:
1387                 self.assertEqual(fobj.read(), 'print("hello")\n')
1388             cache = black.read_cache(mode)
1389             self.assertIn(one, cache)
1390             self.assertIn(two, cache)
1391
1392     def test_no_cache_when_writeback_diff(self) -> None:
1393         mode = DEFAULT_MODE
1394         with cache_dir() as workspace:
1395             src = (workspace / "test.py").resolve()
1396             with src.open("w") as fobj:
1397                 fobj.write("print('hello')")
1398             self.invokeBlack([str(src), "--diff"])
1399             cache_file = black.get_cache_file(mode)
1400             self.assertFalse(cache_file.exists())
1401
1402     def test_no_cache_when_stdin(self) -> None:
1403         mode = DEFAULT_MODE
1404         with cache_dir():
1405             result = CliRunner().invoke(
1406                 black.main, ["-"], input=BytesIO(b"print('hello')")
1407             )
1408             self.assertEqual(result.exit_code, 0)
1409             cache_file = black.get_cache_file(mode)
1410             self.assertFalse(cache_file.exists())
1411
1412     def test_read_cache_no_cachefile(self) -> None:
1413         mode = DEFAULT_MODE
1414         with cache_dir():
1415             self.assertEqual(black.read_cache(mode), {})
1416
1417     def test_write_cache_read_cache(self) -> None:
1418         mode = DEFAULT_MODE
1419         with cache_dir() as workspace:
1420             src = (workspace / "test.py").resolve()
1421             src.touch()
1422             black.write_cache({}, [src], mode)
1423             cache = black.read_cache(mode)
1424             self.assertIn(src, cache)
1425             self.assertEqual(cache[src], black.get_cache_info(src))
1426
1427     def test_filter_cached(self) -> None:
1428         with TemporaryDirectory() as workspace:
1429             path = Path(workspace)
1430             uncached = (path / "uncached").resolve()
1431             cached = (path / "cached").resolve()
1432             cached_but_changed = (path / "changed").resolve()
1433             uncached.touch()
1434             cached.touch()
1435             cached_but_changed.touch()
1436             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
1437             todo, done = black.filter_cached(
1438                 cache, {uncached, cached, cached_but_changed}
1439             )
1440             self.assertEqual(todo, {uncached, cached_but_changed})
1441             self.assertEqual(done, {cached})
1442
1443     def test_write_cache_creates_directory_if_needed(self) -> None:
1444         mode = DEFAULT_MODE
1445         with cache_dir(exists=False) as workspace:
1446             self.assertFalse(workspace.exists())
1447             black.write_cache({}, [], mode)
1448             self.assertTrue(workspace.exists())
1449
1450     @event_loop()
1451     def test_failed_formatting_does_not_get_cached(self) -> None:
1452         mode = DEFAULT_MODE
1453         with cache_dir() as workspace, patch(
1454             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1455         ):
1456             failing = (workspace / "failing.py").resolve()
1457             with failing.open("w") as fobj:
1458                 fobj.write("not actually python")
1459             clean = (workspace / "clean.py").resolve()
1460             with clean.open("w") as fobj:
1461                 fobj.write('print("hello")\n')
1462             self.invokeBlack([str(workspace)], exit_code=123)
1463             cache = black.read_cache(mode)
1464             self.assertNotIn(failing, cache)
1465             self.assertIn(clean, cache)
1466
1467     def test_write_cache_write_fail(self) -> None:
1468         mode = DEFAULT_MODE
1469         with cache_dir(), patch.object(Path, "open") as mock:
1470             mock.side_effect = OSError
1471             black.write_cache({}, [], mode)
1472
1473     @event_loop()
1474     @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
1475     def test_works_in_mono_process_only_environment(self) -> None:
1476         with cache_dir() as workspace:
1477             for f in [
1478                 (workspace / "one.py").resolve(),
1479                 (workspace / "two.py").resolve(),
1480             ]:
1481                 f.write_text('print("hello")\n')
1482             self.invokeBlack([str(workspace)])
1483
1484     @event_loop()
1485     def test_check_diff_use_together(self) -> None:
1486         with cache_dir():
1487             # Files which will be reformatted.
1488             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1489             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1490             # Files which will not be reformatted.
1491             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1492             self.invokeBlack([str(src2), "--diff", "--check"])
1493             # Multi file command.
1494             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1495
1496     def test_no_files(self) -> None:
1497         with cache_dir():
1498             # Without an argument, black exits with error code 0.
1499             self.invokeBlack([])
1500
1501     def test_broken_symlink(self) -> None:
1502         with cache_dir() as workspace:
1503             symlink = workspace / "broken_link.py"
1504             try:
1505                 symlink.symlink_to("nonexistent.py")
1506             except OSError as e:
1507                 self.skipTest(f"Can't create symlinks: {e}")
1508             self.invokeBlack([str(workspace.resolve())])
1509
1510     def test_read_cache_line_lengths(self) -> None:
1511         mode = DEFAULT_MODE
1512         short_mode = replace(DEFAULT_MODE, line_length=1)
1513         with cache_dir() as workspace:
1514             path = (workspace / "file.py").resolve()
1515             path.touch()
1516             black.write_cache({}, [path], mode)
1517             one = black.read_cache(mode)
1518             self.assertIn(path, one)
1519             two = black.read_cache(short_mode)
1520             self.assertNotIn(path, two)
1521
1522     def test_tricky_unicode_symbols(self) -> None:
1523         source, expected = read_data("tricky_unicode_symbols")
1524         actual = fs(source)
1525         self.assertFormatEqual(expected, actual)
1526         black.assert_equivalent(source, actual)
1527         black.assert_stable(source, actual, DEFAULT_MODE)
1528
1529     def test_single_file_force_pyi(self) -> None:
1530         reg_mode = DEFAULT_MODE
1531         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1532         contents, expected = read_data("force_pyi")
1533         with cache_dir() as workspace:
1534             path = (workspace / "file.py").resolve()
1535             with open(path, "w") as fh:
1536                 fh.write(contents)
1537             self.invokeBlack([str(path), "--pyi"])
1538             with open(path, "r") as fh:
1539                 actual = fh.read()
1540             # verify cache with --pyi is separate
1541             pyi_cache = black.read_cache(pyi_mode)
1542             self.assertIn(path, pyi_cache)
1543             normal_cache = black.read_cache(reg_mode)
1544             self.assertNotIn(path, normal_cache)
1545         self.assertEqual(actual, expected)
1546
1547     @event_loop()
1548     def test_multi_file_force_pyi(self) -> None:
1549         reg_mode = DEFAULT_MODE
1550         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
1551         contents, expected = read_data("force_pyi")
1552         with cache_dir() as workspace:
1553             paths = [
1554                 (workspace / "file1.py").resolve(),
1555                 (workspace / "file2.py").resolve(),
1556             ]
1557             for path in paths:
1558                 with open(path, "w") as fh:
1559                     fh.write(contents)
1560             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1561             for path in paths:
1562                 with open(path, "r") as fh:
1563                     actual = fh.read()
1564                 self.assertEqual(actual, expected)
1565             # verify cache with --pyi is separate
1566             pyi_cache = black.read_cache(pyi_mode)
1567             normal_cache = black.read_cache(reg_mode)
1568             for path in paths:
1569                 self.assertIn(path, pyi_cache)
1570                 self.assertNotIn(path, normal_cache)
1571
1572     def test_pipe_force_pyi(self) -> None:
1573         source, expected = read_data("force_pyi")
1574         result = CliRunner().invoke(
1575             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1576         )
1577         self.assertEqual(result.exit_code, 0)
1578         actual = result.output
1579         self.assertFormatEqual(actual, expected)
1580
1581     def test_single_file_force_py36(self) -> None:
1582         reg_mode = DEFAULT_MODE
1583         py36_mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
1584         source, expected = read_data("force_py36")
1585         with cache_dir() as workspace:
1586             path = (workspace / "file.py").resolve()
1587             with open(path, "w") as fh:
1588                 fh.write(source)
1589             self.invokeBlack([str(path), *PY36_ARGS])
1590             with open(path, "r") as fh:
1591                 actual = fh.read()
1592             # verify cache with --target-version is separate
1593             py36_cache = black.read_cache(py36_mode)
1594             self.assertIn(path, py36_cache)
1595             normal_cache = black.read_cache(reg_mode)
1596             self.assertNotIn(path, normal_cache)
1597         self.assertEqual(actual, expected)
1598
1599     @event_loop()
1600     def test_multi_file_force_py36(self) -> None:
1601         reg_mode = DEFAULT_MODE
1602         py36_mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
1603         source, expected = read_data("force_py36")
1604         with cache_dir() as workspace:
1605             paths = [
1606                 (workspace / "file1.py").resolve(),
1607                 (workspace / "file2.py").resolve(),
1608             ]
1609             for path in paths:
1610                 with open(path, "w") as fh:
1611                     fh.write(source)
1612             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1613             for path in paths:
1614                 with open(path, "r") as fh:
1615                     actual = fh.read()
1616                 self.assertEqual(actual, expected)
1617             # verify cache with --target-version is separate
1618             pyi_cache = black.read_cache(py36_mode)
1619             normal_cache = black.read_cache(reg_mode)
1620             for path in paths:
1621                 self.assertIn(path, pyi_cache)
1622                 self.assertNotIn(path, normal_cache)
1623
1624     def test_collections(self) -> None:
1625         source, expected = read_data("collections")
1626         actual = fs(source)
1627         self.assertFormatEqual(expected, actual)
1628         black.assert_equivalent(source, actual)
1629         black.assert_stable(source, actual, DEFAULT_MODE)
1630
1631     def test_pipe_force_py36(self) -> None:
1632         source, expected = read_data("force_py36")
1633         result = CliRunner().invoke(
1634             black.main,
1635             ["-", "-q", "--target-version=py36"],
1636             input=BytesIO(source.encode("utf8")),
1637         )
1638         self.assertEqual(result.exit_code, 0)
1639         actual = result.output
1640         self.assertFormatEqual(actual, expected)
1641
1642     def test_include_exclude(self) -> None:
1643         path = THIS_DIR / "data" / "include_exclude_tests"
1644         include = re.compile(r"\.pyi?$")
1645         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1646         report = black.Report()
1647         gitignore = PathSpec.from_lines("gitwildmatch", [])
1648         sources: List[Path] = []
1649         expected = [
1650             Path(path / "b/dont_exclude/a.py"),
1651             Path(path / "b/dont_exclude/a.pyi"),
1652         ]
1653         this_abs = THIS_DIR.resolve()
1654         sources.extend(
1655             black.gen_python_files(
1656                 path.iterdir(), this_abs, include, exclude, None, report, gitignore
1657             )
1658         )
1659         self.assertEqual(sorted(expected), sorted(sources))
1660
1661     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
1662     def test_exclude_for_issue_1572(self) -> None:
1663         # Exclude shouldn't touch files that were explicitly given to Black through the
1664         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
1665         # https://github.com/psf/black/issues/1572
1666         path = THIS_DIR / "data" / "include_exclude_tests"
1667         include = ""
1668         exclude = r"/exclude/|a\.py"
1669         src = str(path / "b/exclude/a.py")
1670         report = black.Report()
1671         expected = [Path(path / "b/exclude/a.py")]
1672         sources = list(
1673             black.get_sources(
1674                 ctx=FakeContext(),
1675                 src=(src,),
1676                 quiet=True,
1677                 verbose=False,
1678                 include=include,
1679                 exclude=exclude,
1680                 force_exclude=None,
1681                 report=report,
1682             )
1683         )
1684         self.assertEqual(sorted(expected), sorted(sources))
1685
1686     def test_gitignore_exclude(self) -> None:
1687         path = THIS_DIR / "data" / "include_exclude_tests"
1688         include = re.compile(r"\.pyi?$")
1689         exclude = re.compile(r"")
1690         report = black.Report()
1691         gitignore = PathSpec.from_lines(
1692             "gitwildmatch", ["exclude/", ".definitely_exclude"]
1693         )
1694         sources: List[Path] = []
1695         expected = [
1696             Path(path / "b/dont_exclude/a.py"),
1697             Path(path / "b/dont_exclude/a.pyi"),
1698         ]
1699         this_abs = THIS_DIR.resolve()
1700         sources.extend(
1701             black.gen_python_files(
1702                 path.iterdir(), this_abs, include, exclude, None, report, gitignore
1703             )
1704         )
1705         self.assertEqual(sorted(expected), sorted(sources))
1706
1707     def test_empty_include(self) -> None:
1708         path = THIS_DIR / "data" / "include_exclude_tests"
1709         report = black.Report()
1710         gitignore = PathSpec.from_lines("gitwildmatch", [])
1711         empty = re.compile(r"")
1712         sources: List[Path] = []
1713         expected = [
1714             Path(path / "b/exclude/a.pie"),
1715             Path(path / "b/exclude/a.py"),
1716             Path(path / "b/exclude/a.pyi"),
1717             Path(path / "b/dont_exclude/a.pie"),
1718             Path(path / "b/dont_exclude/a.py"),
1719             Path(path / "b/dont_exclude/a.pyi"),
1720             Path(path / "b/.definitely_exclude/a.pie"),
1721             Path(path / "b/.definitely_exclude/a.py"),
1722             Path(path / "b/.definitely_exclude/a.pyi"),
1723         ]
1724         this_abs = THIS_DIR.resolve()
1725         sources.extend(
1726             black.gen_python_files(
1727                 path.iterdir(),
1728                 this_abs,
1729                 empty,
1730                 re.compile(black.DEFAULT_EXCLUDES),
1731                 None,
1732                 report,
1733                 gitignore,
1734             )
1735         )
1736         self.assertEqual(sorted(expected), sorted(sources))
1737
1738     def test_empty_exclude(self) -> None:
1739         path = THIS_DIR / "data" / "include_exclude_tests"
1740         report = black.Report()
1741         gitignore = PathSpec.from_lines("gitwildmatch", [])
1742         empty = re.compile(r"")
1743         sources: List[Path] = []
1744         expected = [
1745             Path(path / "b/dont_exclude/a.py"),
1746             Path(path / "b/dont_exclude/a.pyi"),
1747             Path(path / "b/exclude/a.py"),
1748             Path(path / "b/exclude/a.pyi"),
1749             Path(path / "b/.definitely_exclude/a.py"),
1750             Path(path / "b/.definitely_exclude/a.pyi"),
1751         ]
1752         this_abs = THIS_DIR.resolve()
1753         sources.extend(
1754             black.gen_python_files(
1755                 path.iterdir(),
1756                 this_abs,
1757                 re.compile(black.DEFAULT_INCLUDES),
1758                 empty,
1759                 None,
1760                 report,
1761                 gitignore,
1762             )
1763         )
1764         self.assertEqual(sorted(expected), sorted(sources))
1765
1766     def test_invalid_include_exclude(self) -> None:
1767         for option in ["--include", "--exclude"]:
1768             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1769
1770     def test_preserves_line_endings(self) -> None:
1771         with TemporaryDirectory() as workspace:
1772             test_file = Path(workspace) / "test.py"
1773             for nl in ["\n", "\r\n"]:
1774                 contents = nl.join(["def f(  ):", "    pass"])
1775                 test_file.write_bytes(contents.encode())
1776                 ff(test_file, write_back=black.WriteBack.YES)
1777                 updated_contents: bytes = test_file.read_bytes()
1778                 self.assertIn(nl.encode(), updated_contents)
1779                 if nl == "\n":
1780                     self.assertNotIn(b"\r\n", updated_contents)
1781
1782     def test_preserves_line_endings_via_stdin(self) -> None:
1783         for nl in ["\n", "\r\n"]:
1784             contents = nl.join(["def f(  ):", "    pass"])
1785             runner = BlackRunner()
1786             result = runner.invoke(
1787                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1788             )
1789             self.assertEqual(result.exit_code, 0)
1790             output = runner.stdout_bytes
1791             self.assertIn(nl.encode("utf8"), output)
1792             if nl == "\n":
1793                 self.assertNotIn(b"\r\n", output)
1794
1795     def test_assert_equivalent_different_asts(self) -> None:
1796         with self.assertRaises(AssertionError):
1797             black.assert_equivalent("{}", "None")
1798
1799     def test_symlink_out_of_root_directory(self) -> None:
1800         path = MagicMock()
1801         root = THIS_DIR.resolve()
1802         child = MagicMock()
1803         include = re.compile(black.DEFAULT_INCLUDES)
1804         exclude = re.compile(black.DEFAULT_EXCLUDES)
1805         report = black.Report()
1806         gitignore = PathSpec.from_lines("gitwildmatch", [])
1807         # `child` should behave like a symlink which resolved path is clearly
1808         # outside of the `root` directory.
1809         path.iterdir.return_value = [child]
1810         child.resolve.return_value = Path("/a/b/c")
1811         child.as_posix.return_value = "/a/b/c"
1812         child.is_symlink.return_value = True
1813         try:
1814             list(
1815                 black.gen_python_files(
1816                     path.iterdir(), root, include, exclude, None, report, gitignore
1817                 )
1818             )
1819         except ValueError as ve:
1820             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1821         path.iterdir.assert_called_once()
1822         child.resolve.assert_called_once()
1823         child.is_symlink.assert_called_once()
1824         # `child` should behave like a strange file which resolved path is clearly
1825         # outside of the `root` directory.
1826         child.is_symlink.return_value = False
1827         with self.assertRaises(ValueError):
1828             list(
1829                 black.gen_python_files(
1830                     path.iterdir(), root, include, exclude, None, report, gitignore
1831                 )
1832             )
1833         path.iterdir.assert_called()
1834         self.assertEqual(path.iterdir.call_count, 2)
1835         child.resolve.assert_called()
1836         self.assertEqual(child.resolve.call_count, 2)
1837         child.is_symlink.assert_called()
1838         self.assertEqual(child.is_symlink.call_count, 2)
1839
1840     def test_shhh_click(self) -> None:
1841         try:
1842             from click import _unicodefun  # type: ignore
1843         except ModuleNotFoundError:
1844             self.skipTest("Incompatible Click version")
1845         if not hasattr(_unicodefun, "_verify_python3_env"):
1846             self.skipTest("Incompatible Click version")
1847         # First, let's see if Click is crashing with a preferred ASCII charset.
1848         with patch("locale.getpreferredencoding") as gpe:
1849             gpe.return_value = "ASCII"
1850             with self.assertRaises(RuntimeError):
1851                 _unicodefun._verify_python3_env()
1852         # Now, let's silence Click...
1853         black.patch_click()
1854         # ...and confirm it's silent.
1855         with patch("locale.getpreferredencoding") as gpe:
1856             gpe.return_value = "ASCII"
1857             try:
1858                 _unicodefun._verify_python3_env()
1859             except RuntimeError as re:
1860                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1861
1862     def test_root_logger_not_used_directly(self) -> None:
1863         def fail(*args: Any, **kwargs: Any) -> None:
1864             self.fail("Record created with root logger")
1865
1866         with patch.multiple(
1867             logging.root,
1868             debug=fail,
1869             info=fail,
1870             warning=fail,
1871             error=fail,
1872             critical=fail,
1873             log=fail,
1874         ):
1875             ff(THIS_FILE)
1876
1877     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1878     def test_blackd_main(self) -> None:
1879         with patch("blackd.web.run_app"):
1880             result = CliRunner().invoke(blackd.main, [])
1881             if result.exception is not None:
1882                 raise result.exception
1883             self.assertEqual(result.exit_code, 0)
1884
1885     def test_invalid_config_return_code(self) -> None:
1886         tmp_file = Path(black.dump_to_file())
1887         try:
1888             tmp_config = Path(black.dump_to_file())
1889             tmp_config.unlink()
1890             args = ["--config", str(tmp_config), str(tmp_file)]
1891             self.invokeBlack(args, exit_code=2, ignore_config=False)
1892         finally:
1893             tmp_file.unlink()
1894
1895     def test_parse_pyproject_toml(self) -> None:
1896         test_toml_file = THIS_DIR / "test.toml"
1897         config = black.parse_pyproject_toml(str(test_toml_file))
1898         self.assertEqual(config["verbose"], 1)
1899         self.assertEqual(config["check"], "no")
1900         self.assertEqual(config["diff"], "y")
1901         self.assertEqual(config["color"], True)
1902         self.assertEqual(config["line_length"], 79)
1903         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1904         self.assertEqual(config["exclude"], r"\.pyi?$")
1905         self.assertEqual(config["include"], r"\.py?$")
1906
1907     def test_read_pyproject_toml(self) -> None:
1908         test_toml_file = THIS_DIR / "test.toml"
1909         fake_ctx = FakeContext()
1910         black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file))
1911         config = fake_ctx.default_map
1912         self.assertEqual(config["verbose"], "1")
1913         self.assertEqual(config["check"], "no")
1914         self.assertEqual(config["diff"], "y")
1915         self.assertEqual(config["color"], "True")
1916         self.assertEqual(config["line_length"], "79")
1917         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
1918         self.assertEqual(config["exclude"], r"\.pyi?$")
1919         self.assertEqual(config["include"], r"\.py?$")
1920
1921     def test_find_project_root(self) -> None:
1922         with TemporaryDirectory() as workspace:
1923             root = Path(workspace)
1924             test_dir = root / "test"
1925             test_dir.mkdir()
1926
1927             src_dir = root / "src"
1928             src_dir.mkdir()
1929
1930             root_pyproject = root / "pyproject.toml"
1931             root_pyproject.touch()
1932             src_pyproject = src_dir / "pyproject.toml"
1933             src_pyproject.touch()
1934             src_python = src_dir / "foo.py"
1935             src_python.touch()
1936
1937             self.assertEqual(
1938                 black.find_project_root((src_dir, test_dir)), root.resolve()
1939             )
1940             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
1941             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
1942
1943     def test_bpo_33660_workaround(self) -> None:
1944         if system() == "Windows":
1945             return
1946
1947         # https://bugs.python.org/issue33660
1948
1949         old_cwd = Path.cwd()
1950         try:
1951             root = Path("/")
1952             os.chdir(str(root))
1953             path = Path("workspace") / "project"
1954             report = black.Report(verbose=True)
1955             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
1956             self.assertEqual(normalized_path, "workspace/project")
1957         finally:
1958             os.chdir(str(old_cwd))
1959
1960
if has_blackd_deps:
    # AioHTTPTestCase, unittest_run_loop, web and blackd are only bound when the
    # optional blackd dependencies import successfully (see the try/except at
    # the top of this file). Defining this class unconditionally made the whole
    # module fail with NameError without them, so it is only defined when they
    # are available; per-test skipUnless guards are therefore unnecessary.

    class BlackDTestCase(AioHTTPTestCase):
        """Exercise the blackd HTTP handler through aiohttp's test client."""

        async def get_application(self) -> web.Application:
            return blackd.make_app()

        # TODO: remove these decorators once the below is released
        # https://github.com/aio-libs/aiohttp/pull/3727
        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_request_needs_formatting(self) -> None:
            response = await self.client.post("/", data=b"print('hello world')")
            self.assertEqual(response.status, 200)
            self.assertEqual(response.charset, "utf8")
            self.assertEqual(await response.read(), b'print("hello world")\n')

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_request_no_change(self) -> None:
            response = await self.client.post("/", data=b'print("hello world")\n')
            self.assertEqual(response.status, 204)
            self.assertEqual(await response.read(), b"")

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_request_syntax_error(self) -> None:
            response = await self.client.post("/", data=b"what even ( is")
            self.assertEqual(response.status, 400)
            content = await response.text()
            self.assertTrue(
                content.startswith("Cannot parse"),
                msg="Expected error to start with 'Cannot parse', "
                f"got {repr(content)}",
            )

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_unsupported_version(self) -> None:
            response = await self.client.post(
                "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "2"}
            )
            self.assertEqual(response.status, 501)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_supported_version(self) -> None:
            response = await self.client.post(
                "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "1"}
            )
            self.assertEqual(response.status, 200)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_invalid_python_variant(self) -> None:
            async def check(header_value: str, expected_status: int = 400) -> None:
                response = await self.client.post(
                    "/",
                    data=b"what",
                    headers={blackd.PYTHON_VARIANT_HEADER: header_value},
                )
                self.assertEqual(response.status, expected_status)

            await check("lol")
            await check("ruby3.5")
            await check("pyi3.6")
            await check("py1.5")
            await check("2.8")
            await check("py2.8")
            await check("3.0")
            await check("pypy3.0")
            await check("jython3.4")

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_pyi(self) -> None:
            source, expected = read_data("stub.pyi")
            response = await self.client.post(
                "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
            )
            self.assertEqual(response.status, 200)
            self.assertEqual(await response.text(), expected)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_diff(self) -> None:
            # Diff headers carry timestamps; normalise them before comparing.
            diff_header = re.compile(
                r"(In|Out)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
            )

            source, _ = read_data("blackd_diff.py")
            expected, _ = read_data("blackd_diff.diff")

            response = await self.client.post(
                "/", data=source, headers={blackd.DIFF_HEADER: "true"}
            )
            self.assertEqual(response.status, 200)

            actual = await response.text()
            actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
            self.assertEqual(actual, expected)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_python_variant(self) -> None:
            code = (
                "def f(\n"
                "    and_has_a_bunch_of,\n"
                "    very_long_arguments_too,\n"
                "    and_lots_of_them_as_well_lol,\n"
                "    **and_very_long_keyword_arguments\n"
                "):\n"
                "    pass\n"
            )

            async def check(header_value: str, expected_status: int) -> None:
                response = await self.client.post(
                    "/",
                    data=code,
                    headers={blackd.PYTHON_VARIANT_HEADER: header_value},
                )
                self.assertEqual(
                    response.status, expected_status, msg=await response.text()
                )

            await check("3.6", 200)
            await check("py3.6", 200)
            await check("3.6,3.7", 200)
            await check("3.6,py3.7", 200)
            await check("py36,py37", 200)
            await check("36", 200)
            await check("3.6.4", 200)

            await check("2", 204)
            await check("2.7", 204)
            await check("py2.7", 204)
            await check("3.4", 204)
            await check("py3.4", 204)
            await check("py34,py36", 204)
            await check("34", 204)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_line_length(self) -> None:
            response = await self.client.post(
                "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "7"}
            )
            self.assertEqual(response.status, 200)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_invalid_line_length(self) -> None:
            response = await self.client.post(
                "/",
                data=b'print("hello")\n',
                headers={blackd.LINE_LENGTH_HEADER: "NaN"},
            )
            self.assertEqual(response.status, 400)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_response_black_version_header(self) -> None:
            response = await self.client.post("/")
            self.assertIsNotNone(response.headers.get(blackd.BLACK_VERSION_HEADER))
2128
# Lines of black's own source file, read once at import time; `tracefunc`
# looks up the text of a traced function's definition line here.
with Path(black.__file__).open(encoding="utf-8") as _bf:
    black_source_lines = list(_bf)
2131
2132
def tracefunc(frame: types.FrameType, event: str, arg: Any) -> Callable:
    """Show function calls from `black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    For every "call" event whose code lives in `black/__init__.py`, print one
    line: indentation proportional to the current stack depth, the line
    number the call starts at, and the stripped text of that line taken from
    `black_source_lines` (leading `@decorator` lines are skipped over).
    Non-"call" events and calls into other files print nothing.

    Always returns itself so the interpreter keeps tracing nested frames.
    """
    if event != "call":
        return tracefunc

    # Filter on the file *first*: line numbers from other files are unrelated
    # to `black_source_lines` and may point past its end (IndexError, which
    # would silently disable tracing), and there is no point paying for the
    # stack inspection below for them. Backslashes are normalized so the
    # check also matches Windows paths.
    filename = frame.f_code.co_filename
    if "black/__init__.py" not in filename.replace("\\", "/"):
        return tracefunc

    # Only the stack depth is needed, so request no source context
    # (context=0) instead of having `inspect` read source for every frame.
    # 19 is a fixed baseline offset; a negative result just means no indent.
    depth = len(inspect.stack(0)) - 19
    indent = " " * (depth * 2)

    lineno = frame.f_lineno
    # f_lineno is 1-based; black_source_lines is a 0-based list.
    sig_index = lineno - 1
    if 0 <= sig_index < len(black_source_lines):
        funcname = black_source_lines[sig_index].strip()
        # The reported line may be a decorator; walk forward to the `def`
        # line without running off the end of the file.
        while funcname.startswith("@") and sig_index + 1 < len(black_source_lines):
            sig_index += 1
            funcname = black_source_lines[sig_index].strip()
    else:
        # black_source_lines does not cover this line (e.g. the file changed
        # since it was read): report the line number alone.
        funcname = "?"
    print(f"{indent}{lineno}:{funcname}")
    return tracefunc
2153
2154
if __name__ == "__main__":
    # `unittest.main` is an alias of `unittest.TestProgram`; `module=` names
    # the module it imports and collects tests from (this file, test_black).
    unittest.TestProgram(module="test_black")