git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Add black version header to blackd responses (#1046)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 import logging
4 from concurrent.futures import ThreadPoolExecutor
5 from contextlib import contextmanager
6 from functools import partial
7 from io import BytesIO, TextIOWrapper
8 import os
9 from pathlib import Path
10 import regex as re
11 import sys
12 from tempfile import TemporaryDirectory
13 from typing import Any, BinaryIO, Generator, List, Tuple, Iterator, TypeVar
14 import unittest
15 from unittest.mock import patch, MagicMock
16
17 from click import unstyle
18 from click.testing import CliRunner
19
20 import black
21 from black import Feature, TargetVersion
22
23 try:
24     import blackd
25     from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop
26     from aiohttp import web
27 except ImportError:
28     has_blackd_deps = False
29 else:
30     has_blackd_deps = True
31
# Shortcut: black.format_file_in_place with the default FileMode and fast=True.
ff = partial(black.format_file_in_place, mode=black.FileMode(), fast=True)
# Shortcut: black.format_str with the default FileMode.
fs = partial(black.format_str, mode=black.FileMode())
# Location of this test module and of the directory that contains it.
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# Marker text that read_data() removes from every line it loads.
# NOTE(review): presumably built from two pieces so the full marker never
# appears verbatim in this file -- confirm.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
# One "--target-version=<name>" command-line flag per entry of black.PY36_VERSIONS.
PY36_ARGS = [
    f"--target-version={version.name.lower()}" for version in black.PY36_VERSIONS
]
# Generic type variables; not referenced in the visible part of this file.
T = TypeVar("T")
R = TypeVar("R")
42
43
def dump_to_stderr(*output: str) -> str:
    """Return the given strings joined by newlines, wrapped in one leading and one trailing newline.

    Used in this file as a replacement for ``black.dump_to_file`` via ``patch``.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
46
47
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    Load a fixture and split it at the ``# output`` marker line.  ``.py`` is
    appended to *name* unless it already ends in ``.py``, ``.pyi``, ``.out``
    or ``.diff``.  The file is looked up in ``THIS_DIR / "data"`` when *data*
    is true and directly in ``THIS_DIR`` otherwise.  Both returned strings
    are stripped and terminated with a single newline.
    """
    known_suffixes = (".py", ".pyi", ".out", ".diff")
    if not name.endswith(known_suffixes):
        name += ".py"
    base_dir = THIS_DIR / "data" if data else THIS_DIR
    with open(base_dir / name, "r", encoding="utf8") as test:
        raw_lines = test.readlines()

    before_marker: List[str] = []
    after_marker: List[str] = []
    bucket = before_marker
    for raw in raw_lines:
        cleaned = raw.replace(EMPTY_LINE, "")
        if cleaned.rstrip() == "# output":
            # The marker line itself is dropped; later lines go to the output part.
            bucket = after_marker
        else:
            bucket.append(cleaned)

    if before_marker and not after_marker:
        # If there's no output marker, treat the entire file as already pre-formatted.
        after_marker = list(before_marker)
    source_text = "".join(before_marker).strip() + "\n"
    expected_text = "".join(after_marker).strip() + "\n"
    return source_text, expected_text
69
70
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch ``black.CACHE_DIR`` to a location inside a fresh temporary directory.

    Yields the patched path.  With ``exists=False`` the path is a ``"new"``
    child of the temporary directory, which this helper does not create.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
79
80
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Install a brand-new asyncio event loop as the current loop for the ``with`` body.

    When *close* is true the loop is closed on exit, even if the body raised.
    The previously current loop is not restored.
    """
    fresh_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        if close:
            fresh_loop.close()
92
93
@contextmanager
def skip_if_exception(e: str) -> Iterator[None]:
    """Skip the running test when the wrapped code raises the named exception.

    Args:
        e: Class name (``exc.__class__.__name__``) of the exception that
            should turn into a skip.

    Raises:
        unittest.SkipTest: if the body raised an exception whose class name
            equals *e*.
        Exception: any other exception raised by the body is re-raised as is.
    """
    try:
        yield
    except Exception as exc:
        if exc.__class__.__name__ != e:
            raise
        # unittest.skip() merely builds a decorator and would silently swallow
        # the exception here; raising SkipTest is what marks the test skipped.
        raise unittest.SkipTest(
            f"Encountered expected exception {exc}, skipping"
        ) from exc
103
104
class BlackRunner(CliRunner):
    """CliRunner variant that keeps stderr separate from stdout.

    After an invocation, the raw bytes written to each stream are available
    as ``stdout_bytes`` and ``stderr_bytes``.

    This is a hack that can be removed once we depend on Click 7.x"""

    def __init__(self) -> None:
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        self.stdoutbuf = BytesIO()
        self.stderrbuf = BytesIO()
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        with super().isolation(*args, **kwargs) as output:
            # Remember the stream the parent installed so it can be put back on exit.
            saved_stderr = sys.stderr
            try:
                sys.stderr = TextIOWrapper(self.stderrbuf, encoding=self.charset)
                yield output
            finally:
                # Capture both streams before the original stderr is restored.
                self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                self.stderr_bytes = sys.stderr.buffer.getvalue()  # type: ignore
                sys.stderr = saved_stderr
128
129
130 class BlackTestCase(unittest.TestCase):
131     maxDiff = None
132
133     def assertFormatEqual(self, expected: str, actual: str) -> None:
134         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
135             bdv: black.DebugVisitor[Any]
136             black.out("Expected tree:", fg="green")
137             try:
138                 exp_node = black.lib2to3_parse(expected)
139                 bdv = black.DebugVisitor()
140                 list(bdv.visit(exp_node))
141             except Exception as ve:
142                 black.err(str(ve))
143             black.out("Actual tree:", fg="red")
144             try:
145                 exp_node = black.lib2to3_parse(actual)
146                 bdv = black.DebugVisitor()
147                 list(bdv.visit(exp_node))
148             except Exception as ve:
149                 black.err(str(ve))
150         self.assertEqual(expected, actual)
151
152     def invokeBlack(
153         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
154     ) -> None:
155         runner = BlackRunner()
156         if ignore_config:
157             args = ["--config", str(THIS_DIR / "empty.toml"), *args]
158         result = runner.invoke(black.main, args)
159         self.assertEqual(result.exit_code, exit_code, msg=runner.stderr_bytes.decode())
160
161     @patch("black.dump_to_file", dump_to_stderr)
162     def test_empty(self) -> None:
163         source = expected = ""
164         actual = fs(source)
165         self.assertFormatEqual(expected, actual)
166         black.assert_equivalent(source, actual)
167         black.assert_stable(source, actual, black.FileMode())
168
169     def test_empty_ff(self) -> None:
170         expected = ""
171         tmp_file = Path(black.dump_to_file())
172         try:
173             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
174             with open(tmp_file, encoding="utf8") as f:
175                 actual = f.read()
176         finally:
177             os.unlink(tmp_file)
178         self.assertFormatEqual(expected, actual)
179
180     @patch("black.dump_to_file", dump_to_stderr)
181     def test_self(self) -> None:
182         source, expected = read_data("test_black", data=False)
183         actual = fs(source)
184         self.assertFormatEqual(expected, actual)
185         black.assert_equivalent(source, actual)
186         black.assert_stable(source, actual, black.FileMode())
187         self.assertFalse(ff(THIS_FILE))
188
189     @patch("black.dump_to_file", dump_to_stderr)
190     def test_black(self) -> None:
191         source, expected = read_data("../black", data=False)
192         actual = fs(source)
193         self.assertFormatEqual(expected, actual)
194         black.assert_equivalent(source, actual)
195         black.assert_stable(source, actual, black.FileMode())
196         self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
197
198     def test_piping(self) -> None:
199         source, expected = read_data("../black", data=False)
200         result = BlackRunner().invoke(
201             black.main,
202             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
203             input=BytesIO(source.encode("utf8")),
204         )
205         self.assertEqual(result.exit_code, 0)
206         self.assertFormatEqual(expected, result.output)
207         black.assert_equivalent(source, result.output)
208         black.assert_stable(source, result.output, black.FileMode())
209
210     def test_piping_diff(self) -> None:
211         diff_header = re.compile(
212             rf"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d "
213             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
214         )
215         source, _ = read_data("expression.py")
216         expected, _ = read_data("expression.diff")
217         config = THIS_DIR / "data" / "empty_pyproject.toml"
218         args = [
219             "-",
220             "--fast",
221             f"--line-length={black.DEFAULT_LINE_LENGTH}",
222             "--diff",
223             f"--config={config}",
224         ]
225         result = BlackRunner().invoke(
226             black.main, args, input=BytesIO(source.encode("utf8"))
227         )
228         self.assertEqual(result.exit_code, 0)
229         actual = diff_header.sub("[Deterministic header]", result.output)
230         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
231         self.assertEqual(expected, actual)
232
233     @patch("black.dump_to_file", dump_to_stderr)
234     def test_setup(self) -> None:
235         source, expected = read_data("../setup", data=False)
236         actual = fs(source)
237         self.assertFormatEqual(expected, actual)
238         black.assert_equivalent(source, actual)
239         black.assert_stable(source, actual, black.FileMode())
240         self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
241
242     @patch("black.dump_to_file", dump_to_stderr)
243     def test_function(self) -> None:
244         source, expected = read_data("function")
245         actual = fs(source)
246         self.assertFormatEqual(expected, actual)
247         black.assert_equivalent(source, actual)
248         black.assert_stable(source, actual, black.FileMode())
249
250     @patch("black.dump_to_file", dump_to_stderr)
251     def test_function2(self) -> None:
252         source, expected = read_data("function2")
253         actual = fs(source)
254         self.assertFormatEqual(expected, actual)
255         black.assert_equivalent(source, actual)
256         black.assert_stable(source, actual, black.FileMode())
257
258     @patch("black.dump_to_file", dump_to_stderr)
259     def test_function_trailing_comma(self) -> None:
260         source, expected = read_data("function_trailing_comma")
261         actual = fs(source)
262         self.assertFormatEqual(expected, actual)
263         black.assert_equivalent(source, actual)
264         black.assert_stable(source, actual, black.FileMode())
265
266     @patch("black.dump_to_file", dump_to_stderr)
267     def test_expression(self) -> None:
268         source, expected = read_data("expression")
269         actual = fs(source)
270         self.assertFormatEqual(expected, actual)
271         black.assert_equivalent(source, actual)
272         black.assert_stable(source, actual, black.FileMode())
273
274     @patch("black.dump_to_file", dump_to_stderr)
275     def test_pep_572(self) -> None:
276         source, expected = read_data("pep_572")
277         actual = fs(source)
278         self.assertFormatEqual(expected, actual)
279         black.assert_stable(source, actual, black.FileMode())
280         if sys.version_info >= (3, 8):
281             black.assert_equivalent(source, actual)
282
283     def test_pep_572_version_detection(self) -> None:
284         source, _ = read_data("pep_572")
285         root = black.lib2to3_parse(source)
286         features = black.get_features_used(root)
287         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
288         versions = black.detect_target_versions(root)
289         self.assertIn(black.TargetVersion.PY38, versions)
290
291     def test_expression_ff(self) -> None:
292         source, expected = read_data("expression")
293         tmp_file = Path(black.dump_to_file(source))
294         try:
295             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
296             with open(tmp_file, encoding="utf8") as f:
297                 actual = f.read()
298         finally:
299             os.unlink(tmp_file)
300         self.assertFormatEqual(expected, actual)
301         with patch("black.dump_to_file", dump_to_stderr):
302             black.assert_equivalent(source, actual)
303             black.assert_stable(source, actual, black.FileMode())
304
305     def test_expression_diff(self) -> None:
306         source, _ = read_data("expression.py")
307         expected, _ = read_data("expression.diff")
308         tmp_file = Path(black.dump_to_file(source))
309         diff_header = re.compile(
310             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
311             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
312         )
313         try:
314             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
315             self.assertEqual(result.exit_code, 0)
316         finally:
317             os.unlink(tmp_file)
318         actual = result.output
319         actual = diff_header.sub("[Deterministic header]", actual)
320         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
321         if expected != actual:
322             dump = black.dump_to_file(actual)
323             msg = (
324                 f"Expected diff isn't equal to the actual. If you made changes "
325                 f"to expression.py and this is an anticipated difference, "
326                 f"overwrite tests/data/expression.diff with {dump}"
327             )
328             self.assertEqual(expected, actual, msg)
329
330     @patch("black.dump_to_file", dump_to_stderr)
331     def test_fstring(self) -> None:
332         source, expected = read_data("fstring")
333         actual = fs(source)
334         self.assertFormatEqual(expected, actual)
335         black.assert_equivalent(source, actual)
336         black.assert_stable(source, actual, black.FileMode())
337
338     @patch("black.dump_to_file", dump_to_stderr)
339     def test_pep_570(self) -> None:
340         source, expected = read_data("pep_570")
341         actual = fs(source)
342         self.assertFormatEqual(expected, actual)
343         black.assert_stable(source, actual, black.FileMode())
344         if sys.version_info >= (3, 8):
345             black.assert_equivalent(source, actual)
346
347     def test_detect_pos_only_arguments(self) -> None:
348         source, _ = read_data("pep_570")
349         root = black.lib2to3_parse(source)
350         features = black.get_features_used(root)
351         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
352         versions = black.detect_target_versions(root)
353         self.assertIn(black.TargetVersion.PY38, versions)
354
355     @patch("black.dump_to_file", dump_to_stderr)
356     def test_string_quotes(self) -> None:
357         source, expected = read_data("string_quotes")
358         actual = fs(source)
359         self.assertFormatEqual(expected, actual)
360         black.assert_equivalent(source, actual)
361         black.assert_stable(source, actual, black.FileMode())
362         mode = black.FileMode(string_normalization=False)
363         not_normalized = fs(source, mode=mode)
364         self.assertFormatEqual(source, not_normalized)
365         black.assert_equivalent(source, not_normalized)
366         black.assert_stable(source, not_normalized, mode=mode)
367
368     @patch("black.dump_to_file", dump_to_stderr)
369     def test_slices(self) -> None:
370         source, expected = read_data("slices")
371         actual = fs(source)
372         self.assertFormatEqual(expected, actual)
373         black.assert_equivalent(source, actual)
374         black.assert_stable(source, actual, black.FileMode())
375
376     @patch("black.dump_to_file", dump_to_stderr)
377     def test_comments(self) -> None:
378         source, expected = read_data("comments")
379         actual = fs(source)
380         self.assertFormatEqual(expected, actual)
381         black.assert_equivalent(source, actual)
382         black.assert_stable(source, actual, black.FileMode())
383
384     @patch("black.dump_to_file", dump_to_stderr)
385     def test_comments2(self) -> None:
386         source, expected = read_data("comments2")
387         actual = fs(source)
388         self.assertFormatEqual(expected, actual)
389         black.assert_equivalent(source, actual)
390         black.assert_stable(source, actual, black.FileMode())
391
392     @patch("black.dump_to_file", dump_to_stderr)
393     def test_comments3(self) -> None:
394         source, expected = read_data("comments3")
395         actual = fs(source)
396         self.assertFormatEqual(expected, actual)
397         black.assert_equivalent(source, actual)
398         black.assert_stable(source, actual, black.FileMode())
399
400     @patch("black.dump_to_file", dump_to_stderr)
401     def test_comments4(self) -> None:
402         source, expected = read_data("comments4")
403         actual = fs(source)
404         self.assertFormatEqual(expected, actual)
405         black.assert_equivalent(source, actual)
406         black.assert_stable(source, actual, black.FileMode())
407
408     @patch("black.dump_to_file", dump_to_stderr)
409     def test_comments5(self) -> None:
410         source, expected = read_data("comments5")
411         actual = fs(source)
412         self.assertFormatEqual(expected, actual)
413         black.assert_equivalent(source, actual)
414         black.assert_stable(source, actual, black.FileMode())
415
416     @patch("black.dump_to_file", dump_to_stderr)
417     def test_comments6(self) -> None:
418         source, expected = read_data("comments6")
419         actual = fs(source)
420         self.assertFormatEqual(expected, actual)
421         black.assert_equivalent(source, actual)
422         black.assert_stable(source, actual, black.FileMode())
423
424     @patch("black.dump_to_file", dump_to_stderr)
425     def test_comments7(self) -> None:
426         source, expected = read_data("comments7")
427         actual = fs(source)
428         self.assertFormatEqual(expected, actual)
429         black.assert_equivalent(source, actual)
430         black.assert_stable(source, actual, black.FileMode())
431
432     @patch("black.dump_to_file", dump_to_stderr)
433     def test_comment_after_escaped_newline(self) -> None:
434         source, expected = read_data("comment_after_escaped_newline")
435         actual = fs(source)
436         self.assertFormatEqual(expected, actual)
437         black.assert_equivalent(source, actual)
438         black.assert_stable(source, actual, black.FileMode())
439
440     @patch("black.dump_to_file", dump_to_stderr)
441     def test_cantfit(self) -> None:
442         source, expected = read_data("cantfit")
443         actual = fs(source)
444         self.assertFormatEqual(expected, actual)
445         black.assert_equivalent(source, actual)
446         black.assert_stable(source, actual, black.FileMode())
447
448     @patch("black.dump_to_file", dump_to_stderr)
449     def test_import_spacing(self) -> None:
450         source, expected = read_data("import_spacing")
451         actual = fs(source)
452         self.assertFormatEqual(expected, actual)
453         black.assert_equivalent(source, actual)
454         black.assert_stable(source, actual, black.FileMode())
455
456     @patch("black.dump_to_file", dump_to_stderr)
457     def test_composition(self) -> None:
458         source, expected = read_data("composition")
459         actual = fs(source)
460         self.assertFormatEqual(expected, actual)
461         black.assert_equivalent(source, actual)
462         black.assert_stable(source, actual, black.FileMode())
463
464     @patch("black.dump_to_file", dump_to_stderr)
465     def test_empty_lines(self) -> None:
466         source, expected = read_data("empty_lines")
467         actual = fs(source)
468         self.assertFormatEqual(expected, actual)
469         black.assert_equivalent(source, actual)
470         black.assert_stable(source, actual, black.FileMode())
471
472     @patch("black.dump_to_file", dump_to_stderr)
473     def test_remove_parens(self) -> None:
474         source, expected = read_data("remove_parens")
475         actual = fs(source)
476         self.assertFormatEqual(expected, actual)
477         black.assert_equivalent(source, actual)
478         black.assert_stable(source, actual, black.FileMode())
479
480     @patch("black.dump_to_file", dump_to_stderr)
481     def test_string_prefixes(self) -> None:
482         source, expected = read_data("string_prefixes")
483         actual = fs(source)
484         self.assertFormatEqual(expected, actual)
485         black.assert_equivalent(source, actual)
486         black.assert_stable(source, actual, black.FileMode())
487
488     @patch("black.dump_to_file", dump_to_stderr)
489     def test_numeric_literals(self) -> None:
490         source, expected = read_data("numeric_literals")
491         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
492         actual = fs(source, mode=mode)
493         self.assertFormatEqual(expected, actual)
494         black.assert_equivalent(source, actual)
495         black.assert_stable(source, actual, mode)
496
497     @patch("black.dump_to_file", dump_to_stderr)
498     def test_numeric_literals_ignoring_underscores(self) -> None:
499         source, expected = read_data("numeric_literals_skip_underscores")
500         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
501         actual = fs(source, mode=mode)
502         self.assertFormatEqual(expected, actual)
503         black.assert_equivalent(source, actual)
504         black.assert_stable(source, actual, mode)
505
506     @patch("black.dump_to_file", dump_to_stderr)
507     def test_numeric_literals_py2(self) -> None:
508         source, expected = read_data("numeric_literals_py2")
509         actual = fs(source)
510         self.assertFormatEqual(expected, actual)
511         black.assert_stable(source, actual, black.FileMode())
512
513     @patch("black.dump_to_file", dump_to_stderr)
514     def test_python2(self) -> None:
515         source, expected = read_data("python2")
516         actual = fs(source)
517         self.assertFormatEqual(expected, actual)
518         black.assert_equivalent(source, actual)
519         black.assert_stable(source, actual, black.FileMode())
520
521     @patch("black.dump_to_file", dump_to_stderr)
522     def test_python2_print_function(self) -> None:
523         source, expected = read_data("python2_print_function")
524         mode = black.FileMode(target_versions={TargetVersion.PY27})
525         actual = fs(source, mode=mode)
526         self.assertFormatEqual(expected, actual)
527         black.assert_equivalent(source, actual)
528         black.assert_stable(source, actual, mode)
529
530     @patch("black.dump_to_file", dump_to_stderr)
531     def test_python2_unicode_literals(self) -> None:
532         source, expected = read_data("python2_unicode_literals")
533         actual = fs(source)
534         self.assertFormatEqual(expected, actual)
535         black.assert_equivalent(source, actual)
536         black.assert_stable(source, actual, black.FileMode())
537
538     @patch("black.dump_to_file", dump_to_stderr)
539     def test_stub(self) -> None:
540         mode = black.FileMode(is_pyi=True)
541         source, expected = read_data("stub.pyi")
542         actual = fs(source, mode=mode)
543         self.assertFormatEqual(expected, actual)
544         black.assert_stable(source, actual, mode)
545
546     @patch("black.dump_to_file", dump_to_stderr)
547     def test_async_as_identifier(self) -> None:
548         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
549         source, expected = read_data("async_as_identifier")
550         actual = fs(source)
551         self.assertFormatEqual(expected, actual)
552         major, minor = sys.version_info[:2]
553         if major < 3 or (major <= 3 and minor < 7):
554             black.assert_equivalent(source, actual)
555         black.assert_stable(source, actual, black.FileMode())
556         # ensure black can parse this when the target is 3.6
557         self.invokeBlack([str(source_path), "--target-version", "py36"])
558         # but not on 3.7, because async/await is no longer an identifier
559         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
560
561     @patch("black.dump_to_file", dump_to_stderr)
562     def test_python37(self) -> None:
563         source_path = (THIS_DIR / "data" / "python37.py").resolve()
564         source, expected = read_data("python37")
565         actual = fs(source)
566         self.assertFormatEqual(expected, actual)
567         major, minor = sys.version_info[:2]
568         if major > 3 or (major == 3 and minor >= 7):
569             black.assert_equivalent(source, actual)
570         black.assert_stable(source, actual, black.FileMode())
571         # ensure black can parse this when the target is 3.7
572         self.invokeBlack([str(source_path), "--target-version", "py37"])
573         # but not on 3.6, because we use async as a reserved keyword
574         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
575
576     @patch("black.dump_to_file", dump_to_stderr)
577     def test_fmtonoff(self) -> None:
578         source, expected = read_data("fmtonoff")
579         actual = fs(source)
580         self.assertFormatEqual(expected, actual)
581         black.assert_equivalent(source, actual)
582         black.assert_stable(source, actual, black.FileMode())
583
584     @patch("black.dump_to_file", dump_to_stderr)
585     def test_fmtonoff2(self) -> None:
586         source, expected = read_data("fmtonoff2")
587         actual = fs(source)
588         self.assertFormatEqual(expected, actual)
589         black.assert_equivalent(source, actual)
590         black.assert_stable(source, actual, black.FileMode())
591
592     @patch("black.dump_to_file", dump_to_stderr)
593     def test_remove_empty_parentheses_after_class(self) -> None:
594         source, expected = read_data("class_blank_parentheses")
595         actual = fs(source)
596         self.assertFormatEqual(expected, actual)
597         black.assert_equivalent(source, actual)
598         black.assert_stable(source, actual, black.FileMode())
599
600     @patch("black.dump_to_file", dump_to_stderr)
601     def test_new_line_between_class_and_code(self) -> None:
602         source, expected = read_data("class_methods_new_line")
603         actual = fs(source)
604         self.assertFormatEqual(expected, actual)
605         black.assert_equivalent(source, actual)
606         black.assert_stable(source, actual, black.FileMode())
607
608     @patch("black.dump_to_file", dump_to_stderr)
609     def test_bracket_match(self) -> None:
610         source, expected = read_data("bracketmatch")
611         actual = fs(source)
612         self.assertFormatEqual(expected, actual)
613         black.assert_equivalent(source, actual)
614         black.assert_stable(source, actual, black.FileMode())
615
616     @patch("black.dump_to_file", dump_to_stderr)
617     def test_tuple_assign(self) -> None:
618         source, expected = read_data("tupleassign")
619         actual = fs(source)
620         self.assertFormatEqual(expected, actual)
621         black.assert_equivalent(source, actual)
622         black.assert_stable(source, actual, black.FileMode())
623
624     @patch("black.dump_to_file", dump_to_stderr)
625     def test_beginning_backslash(self) -> None:
626         source, expected = read_data("beginning_backslash")
627         actual = fs(source)
628         self.assertFormatEqual(expected, actual)
629         black.assert_equivalent(source, actual)
630         black.assert_stable(source, actual, black.FileMode())
631
632     def test_tab_comment_indentation(self) -> None:
633         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
634         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
635         self.assertFormatEqual(contents_spc, fs(contents_spc))
636         self.assertFormatEqual(contents_spc, fs(contents_tab))
637
638         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
639         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
640         self.assertFormatEqual(contents_spc, fs(contents_spc))
641         self.assertFormatEqual(contents_spc, fs(contents_tab))
642
643         # mixed tabs and spaces (valid Python 2 code)
644         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
645         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
646         self.assertFormatEqual(contents_spc, fs(contents_spc))
647         self.assertFormatEqual(contents_spc, fs(contents_tab))
648
649         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
650         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
651         self.assertFormatEqual(contents_spc, fs(contents_spc))
652         self.assertFormatEqual(contents_spc, fs(contents_tab))
653
654     def test_report_verbose(self) -> None:
655         report = black.Report(verbose=True)
656         out_lines = []
657         err_lines = []
658
659         def out(msg: str, **kwargs: Any) -> None:
660             out_lines.append(msg)
661
662         def err(msg: str, **kwargs: Any) -> None:
663             err_lines.append(msg)
664
665         with patch("black.out", out), patch("black.err", err):
666             report.done(Path("f1"), black.Changed.NO)
667             self.assertEqual(len(out_lines), 1)
668             self.assertEqual(len(err_lines), 0)
669             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
670             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
671             self.assertEqual(report.return_code, 0)
672             report.done(Path("f2"), black.Changed.YES)
673             self.assertEqual(len(out_lines), 2)
674             self.assertEqual(len(err_lines), 0)
675             self.assertEqual(out_lines[-1], "reformatted f2")
676             self.assertEqual(
677                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
678             )
679             report.done(Path("f3"), black.Changed.CACHED)
680             self.assertEqual(len(out_lines), 3)
681             self.assertEqual(len(err_lines), 0)
682             self.assertEqual(
683                 out_lines[-1], "f3 wasn't modified on disk since last run."
684             )
685             self.assertEqual(
686                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
687             )
688             self.assertEqual(report.return_code, 0)
689             report.check = True
690             self.assertEqual(report.return_code, 1)
691             report.check = False
692             report.failed(Path("e1"), "boom")
693             self.assertEqual(len(out_lines), 3)
694             self.assertEqual(len(err_lines), 1)
695             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
696             self.assertEqual(
697                 unstyle(str(report)),
698                 "1 file reformatted, 2 files left unchanged, "
699                 "1 file failed to reformat.",
700             )
701             self.assertEqual(report.return_code, 123)
702             report.done(Path("f3"), black.Changed.YES)
703             self.assertEqual(len(out_lines), 4)
704             self.assertEqual(len(err_lines), 1)
705             self.assertEqual(out_lines[-1], "reformatted f3")
706             self.assertEqual(
707                 unstyle(str(report)),
708                 "2 files reformatted, 2 files left unchanged, "
709                 "1 file failed to reformat.",
710             )
711             self.assertEqual(report.return_code, 123)
712             report.failed(Path("e2"), "boom")
713             self.assertEqual(len(out_lines), 4)
714             self.assertEqual(len(err_lines), 2)
715             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
716             self.assertEqual(
717                 unstyle(str(report)),
718                 "2 files reformatted, 2 files left unchanged, "
719                 "2 files failed to reformat.",
720             )
721             self.assertEqual(report.return_code, 123)
722             report.path_ignored(Path("wat"), "no match")
723             self.assertEqual(len(out_lines), 5)
724             self.assertEqual(len(err_lines), 2)
725             self.assertEqual(out_lines[-1], "wat ignored: no match")
726             self.assertEqual(
727                 unstyle(str(report)),
728                 "2 files reformatted, 2 files left unchanged, "
729                 "2 files failed to reformat.",
730             )
731             self.assertEqual(report.return_code, 123)
732             report.done(Path("f4"), black.Changed.NO)
733             self.assertEqual(len(out_lines), 6)
734             self.assertEqual(len(err_lines), 2)
735             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
736             self.assertEqual(
737                 unstyle(str(report)),
738                 "2 files reformatted, 3 files left unchanged, "
739                 "2 files failed to reformat.",
740             )
741             self.assertEqual(report.return_code, 123)
742             report.check = True
743             self.assertEqual(
744                 unstyle(str(report)),
745                 "2 files would be reformatted, 3 files would be left unchanged, "
746                 "2 files would fail to reformat.",
747             )
748
749     def test_report_quiet(self) -> None:
750         report = black.Report(quiet=True)
751         out_lines = []
752         err_lines = []
753
754         def out(msg: str, **kwargs: Any) -> None:
755             out_lines.append(msg)
756
757         def err(msg: str, **kwargs: Any) -> None:
758             err_lines.append(msg)
759
760         with patch("black.out", out), patch("black.err", err):
761             report.done(Path("f1"), black.Changed.NO)
762             self.assertEqual(len(out_lines), 0)
763             self.assertEqual(len(err_lines), 0)
764             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
765             self.assertEqual(report.return_code, 0)
766             report.done(Path("f2"), black.Changed.YES)
767             self.assertEqual(len(out_lines), 0)
768             self.assertEqual(len(err_lines), 0)
769             self.assertEqual(
770                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
771             )
772             report.done(Path("f3"), black.Changed.CACHED)
773             self.assertEqual(len(out_lines), 0)
774             self.assertEqual(len(err_lines), 0)
775             self.assertEqual(
776                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
777             )
778             self.assertEqual(report.return_code, 0)
779             report.check = True
780             self.assertEqual(report.return_code, 1)
781             report.check = False
782             report.failed(Path("e1"), "boom")
783             self.assertEqual(len(out_lines), 0)
784             self.assertEqual(len(err_lines), 1)
785             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
786             self.assertEqual(
787                 unstyle(str(report)),
788                 "1 file reformatted, 2 files left unchanged, "
789                 "1 file failed to reformat.",
790             )
791             self.assertEqual(report.return_code, 123)
792             report.done(Path("f3"), black.Changed.YES)
793             self.assertEqual(len(out_lines), 0)
794             self.assertEqual(len(err_lines), 1)
795             self.assertEqual(
796                 unstyle(str(report)),
797                 "2 files reformatted, 2 files left unchanged, "
798                 "1 file failed to reformat.",
799             )
800             self.assertEqual(report.return_code, 123)
801             report.failed(Path("e2"), "boom")
802             self.assertEqual(len(out_lines), 0)
803             self.assertEqual(len(err_lines), 2)
804             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
805             self.assertEqual(
806                 unstyle(str(report)),
807                 "2 files reformatted, 2 files left unchanged, "
808                 "2 files failed to reformat.",
809             )
810             self.assertEqual(report.return_code, 123)
811             report.path_ignored(Path("wat"), "no match")
812             self.assertEqual(len(out_lines), 0)
813             self.assertEqual(len(err_lines), 2)
814             self.assertEqual(
815                 unstyle(str(report)),
816                 "2 files reformatted, 2 files left unchanged, "
817                 "2 files failed to reformat.",
818             )
819             self.assertEqual(report.return_code, 123)
820             report.done(Path("f4"), black.Changed.NO)
821             self.assertEqual(len(out_lines), 0)
822             self.assertEqual(len(err_lines), 2)
823             self.assertEqual(
824                 unstyle(str(report)),
825                 "2 files reformatted, 3 files left unchanged, "
826                 "2 files failed to reformat.",
827             )
828             self.assertEqual(report.return_code, 123)
829             report.check = True
830             self.assertEqual(
831                 unstyle(str(report)),
832                 "2 files would be reformatted, 3 files would be left unchanged, "
833                 "2 files would fail to reformat.",
834             )
835
836     def test_report_normal(self) -> None:
837         report = black.Report()
838         out_lines = []
839         err_lines = []
840
841         def out(msg: str, **kwargs: Any) -> None:
842             out_lines.append(msg)
843
844         def err(msg: str, **kwargs: Any) -> None:
845             err_lines.append(msg)
846
847         with patch("black.out", out), patch("black.err", err):
848             report.done(Path("f1"), black.Changed.NO)
849             self.assertEqual(len(out_lines), 0)
850             self.assertEqual(len(err_lines), 0)
851             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
852             self.assertEqual(report.return_code, 0)
853             report.done(Path("f2"), black.Changed.YES)
854             self.assertEqual(len(out_lines), 1)
855             self.assertEqual(len(err_lines), 0)
856             self.assertEqual(out_lines[-1], "reformatted f2")
857             self.assertEqual(
858                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
859             )
860             report.done(Path("f3"), black.Changed.CACHED)
861             self.assertEqual(len(out_lines), 1)
862             self.assertEqual(len(err_lines), 0)
863             self.assertEqual(out_lines[-1], "reformatted f2")
864             self.assertEqual(
865                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
866             )
867             self.assertEqual(report.return_code, 0)
868             report.check = True
869             self.assertEqual(report.return_code, 1)
870             report.check = False
871             report.failed(Path("e1"), "boom")
872             self.assertEqual(len(out_lines), 1)
873             self.assertEqual(len(err_lines), 1)
874             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
875             self.assertEqual(
876                 unstyle(str(report)),
877                 "1 file reformatted, 2 files left unchanged, "
878                 "1 file failed to reformat.",
879             )
880             self.assertEqual(report.return_code, 123)
881             report.done(Path("f3"), black.Changed.YES)
882             self.assertEqual(len(out_lines), 2)
883             self.assertEqual(len(err_lines), 1)
884             self.assertEqual(out_lines[-1], "reformatted f3")
885             self.assertEqual(
886                 unstyle(str(report)),
887                 "2 files reformatted, 2 files left unchanged, "
888                 "1 file failed to reformat.",
889             )
890             self.assertEqual(report.return_code, 123)
891             report.failed(Path("e2"), "boom")
892             self.assertEqual(len(out_lines), 2)
893             self.assertEqual(len(err_lines), 2)
894             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
895             self.assertEqual(
896                 unstyle(str(report)),
897                 "2 files reformatted, 2 files left unchanged, "
898                 "2 files failed to reformat.",
899             )
900             self.assertEqual(report.return_code, 123)
901             report.path_ignored(Path("wat"), "no match")
902             self.assertEqual(len(out_lines), 2)
903             self.assertEqual(len(err_lines), 2)
904             self.assertEqual(
905                 unstyle(str(report)),
906                 "2 files reformatted, 2 files left unchanged, "
907                 "2 files failed to reformat.",
908             )
909             self.assertEqual(report.return_code, 123)
910             report.done(Path("f4"), black.Changed.NO)
911             self.assertEqual(len(out_lines), 2)
912             self.assertEqual(len(err_lines), 2)
913             self.assertEqual(
914                 unstyle(str(report)),
915                 "2 files reformatted, 3 files left unchanged, "
916                 "2 files failed to reformat.",
917             )
918             self.assertEqual(report.return_code, 123)
919             report.check = True
920             self.assertEqual(
921                 unstyle(str(report)),
922                 "2 files would be reformatted, 3 files would be left unchanged, "
923                 "2 files would fail to reformat.",
924             )
925
926     def test_lib2to3_parse(self) -> None:
927         with self.assertRaises(black.InvalidInput):
928             black.lib2to3_parse("invalid syntax")
929
930         straddling = "x + y"
931         black.lib2to3_parse(straddling)
932         black.lib2to3_parse(straddling, {TargetVersion.PY27})
933         black.lib2to3_parse(straddling, {TargetVersion.PY36})
934         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
935
936         py2_only = "print x"
937         black.lib2to3_parse(py2_only)
938         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
939         with self.assertRaises(black.InvalidInput):
940             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
941         with self.assertRaises(black.InvalidInput):
942             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
943
944         py3_only = "exec(x, end=y)"
945         black.lib2to3_parse(py3_only)
946         with self.assertRaises(black.InvalidInput):
947             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
948         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
949         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
950
951     def test_get_features_used(self) -> None:
952         node = black.lib2to3_parse("def f(*, arg): ...\n")
953         self.assertEqual(black.get_features_used(node), set())
954         node = black.lib2to3_parse("def f(*, arg,): ...\n")
955         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
956         node = black.lib2to3_parse("f(*arg,)\n")
957         self.assertEqual(
958             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
959         )
960         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
961         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
962         node = black.lib2to3_parse("123_456\n")
963         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
964         node = black.lib2to3_parse("123456\n")
965         self.assertEqual(black.get_features_used(node), set())
966         source, expected = read_data("function")
967         node = black.lib2to3_parse(source)
968         expected_features = {
969             Feature.TRAILING_COMMA_IN_CALL,
970             Feature.TRAILING_COMMA_IN_DEF,
971             Feature.F_STRINGS,
972         }
973         self.assertEqual(black.get_features_used(node), expected_features)
974         node = black.lib2to3_parse(expected)
975         self.assertEqual(black.get_features_used(node), expected_features)
976         source, expected = read_data("expression")
977         node = black.lib2to3_parse(source)
978         self.assertEqual(black.get_features_used(node), set())
979         node = black.lib2to3_parse(expected)
980         self.assertEqual(black.get_features_used(node), set())
981
982     def test_get_future_imports(self) -> None:
983         node = black.lib2to3_parse("\n")
984         self.assertEqual(set(), black.get_future_imports(node))
985         node = black.lib2to3_parse("from __future__ import black\n")
986         self.assertEqual({"black"}, black.get_future_imports(node))
987         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
988         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
989         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
990         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
991         node = black.lib2to3_parse(
992             "from __future__ import multiple\nfrom __future__ import imports\n"
993         )
994         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
995         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
996         self.assertEqual({"black"}, black.get_future_imports(node))
997         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
998         self.assertEqual({"black"}, black.get_future_imports(node))
999         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
1000         self.assertEqual(set(), black.get_future_imports(node))
1001         node = black.lib2to3_parse("from some.module import black\n")
1002         self.assertEqual(set(), black.get_future_imports(node))
1003         node = black.lib2to3_parse(
1004             "from __future__ import unicode_literals as _unicode_literals"
1005         )
1006         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
1007         node = black.lib2to3_parse(
1008             "from __future__ import unicode_literals as _lol, print"
1009         )
1010         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
1011
1012     def test_debug_visitor(self) -> None:
1013         source, _ = read_data("debug_visitor.py")
1014         expected, _ = read_data("debug_visitor.out")
1015         out_lines = []
1016         err_lines = []
1017
1018         def out(msg: str, **kwargs: Any) -> None:
1019             out_lines.append(msg)
1020
1021         def err(msg: str, **kwargs: Any) -> None:
1022             err_lines.append(msg)
1023
1024         with patch("black.out", out), patch("black.err", err):
1025             black.DebugVisitor.show(source)
1026         actual = "\n".join(out_lines) + "\n"
1027         log_name = ""
1028         if expected != actual:
1029             log_name = black.dump_to_file(*out_lines)
1030         self.assertEqual(
1031             expected,
1032             actual,
1033             f"AST print out is different. Actual version dumped to {log_name}",
1034         )
1035
1036     def test_format_file_contents(self) -> None:
1037         empty = ""
1038         mode = black.FileMode()
1039         with self.assertRaises(black.NothingChanged):
1040             black.format_file_contents(empty, mode=mode, fast=False)
1041         just_nl = "\n"
1042         with self.assertRaises(black.NothingChanged):
1043             black.format_file_contents(just_nl, mode=mode, fast=False)
1044         same = "j = [1, 2, 3]\n"
1045         with self.assertRaises(black.NothingChanged):
1046             black.format_file_contents(same, mode=mode, fast=False)
1047         different = "j = [1,2,3]"
1048         expected = same
1049         actual = black.format_file_contents(different, mode=mode, fast=False)
1050         self.assertEqual(expected, actual)
1051         invalid = "return if you can"
1052         with self.assertRaises(black.InvalidInput) as e:
1053             black.format_file_contents(invalid, mode=mode, fast=False)
1054         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1055
1056     def test_endmarker(self) -> None:
1057         n = black.lib2to3_parse("\n")
1058         self.assertEqual(n.type, black.syms.file_input)
1059         self.assertEqual(len(n.children), 1)
1060         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1061
1062     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1063     def test_assertFormatEqual(self) -> None:
1064         out_lines = []
1065         err_lines = []
1066
1067         def out(msg: str, **kwargs: Any) -> None:
1068             out_lines.append(msg)
1069
1070         def err(msg: str, **kwargs: Any) -> None:
1071             err_lines.append(msg)
1072
1073         with patch("black.out", out), patch("black.err", err):
1074             with self.assertRaises(AssertionError):
1075                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1076
1077         out_str = "".join(out_lines)
1078         self.assertTrue("Expected tree:" in out_str)
1079         self.assertTrue("Actual tree:" in out_str)
1080         self.assertEqual("".join(err_lines), "")
1081
1082     def test_cache_broken_file(self) -> None:
1083         mode = black.FileMode()
1084         with cache_dir() as workspace:
1085             cache_file = black.get_cache_file(mode)
1086             with cache_file.open("w") as fobj:
1087                 fobj.write("this is not a pickle")
1088             self.assertEqual(black.read_cache(mode), {})
1089             src = (workspace / "test.py").resolve()
1090             with src.open("w") as fobj:
1091                 fobj.write("print('hello')")
1092             self.invokeBlack([str(src)])
1093             cache = black.read_cache(mode)
1094             self.assertIn(src, cache)
1095
1096     def test_cache_single_file_already_cached(self) -> None:
1097         mode = black.FileMode()
1098         with cache_dir() as workspace:
1099             src = (workspace / "test.py").resolve()
1100             with src.open("w") as fobj:
1101                 fobj.write("print('hello')")
1102             black.write_cache({}, [src], mode)
1103             self.invokeBlack([str(src)])
1104             with src.open("r") as fobj:
1105                 self.assertEqual(fobj.read(), "print('hello')")
1106
1107     @event_loop(close=False)
1108     def test_cache_multiple_files(self) -> None:
1109         mode = black.FileMode()
1110         with cache_dir() as workspace, patch(
1111             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1112         ):
1113             one = (workspace / "one.py").resolve()
1114             with one.open("w") as fobj:
1115                 fobj.write("print('hello')")
1116             two = (workspace / "two.py").resolve()
1117             with two.open("w") as fobj:
1118                 fobj.write("print('hello')")
1119             black.write_cache({}, [one], mode)
1120             self.invokeBlack([str(workspace)])
1121             with one.open("r") as fobj:
1122                 self.assertEqual(fobj.read(), "print('hello')")
1123             with two.open("r") as fobj:
1124                 self.assertEqual(fobj.read(), 'print("hello")\n')
1125             cache = black.read_cache(mode)
1126             self.assertIn(one, cache)
1127             self.assertIn(two, cache)
1128
1129     def test_no_cache_when_writeback_diff(self) -> None:
1130         mode = black.FileMode()
1131         with cache_dir() as workspace:
1132             src = (workspace / "test.py").resolve()
1133             with src.open("w") as fobj:
1134                 fobj.write("print('hello')")
1135             self.invokeBlack([str(src), "--diff"])
1136             cache_file = black.get_cache_file(mode)
1137             self.assertFalse(cache_file.exists())
1138
1139     def test_no_cache_when_stdin(self) -> None:
1140         mode = black.FileMode()
1141         with cache_dir():
1142             result = CliRunner().invoke(
1143                 black.main, ["-"], input=BytesIO(b"print('hello')")
1144             )
1145             self.assertEqual(result.exit_code, 0)
1146             cache_file = black.get_cache_file(mode)
1147             self.assertFalse(cache_file.exists())
1148
1149     def test_read_cache_no_cachefile(self) -> None:
1150         mode = black.FileMode()
1151         with cache_dir():
1152             self.assertEqual(black.read_cache(mode), {})
1153
1154     def test_write_cache_read_cache(self) -> None:
1155         mode = black.FileMode()
1156         with cache_dir() as workspace:
1157             src = (workspace / "test.py").resolve()
1158             src.touch()
1159             black.write_cache({}, [src], mode)
1160             cache = black.read_cache(mode)
1161             self.assertIn(src, cache)
1162             self.assertEqual(cache[src], black.get_cache_info(src))
1163
1164     def test_filter_cached(self) -> None:
1165         with TemporaryDirectory() as workspace:
1166             path = Path(workspace)
1167             uncached = (path / "uncached").resolve()
1168             cached = (path / "cached").resolve()
1169             cached_but_changed = (path / "changed").resolve()
1170             uncached.touch()
1171             cached.touch()
1172             cached_but_changed.touch()
1173             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
1174             todo, done = black.filter_cached(
1175                 cache, {uncached, cached, cached_but_changed}
1176             )
1177             self.assertEqual(todo, {uncached, cached_but_changed})
1178             self.assertEqual(done, {cached})
1179
1180     def test_write_cache_creates_directory_if_needed(self) -> None:
1181         mode = black.FileMode()
1182         with cache_dir(exists=False) as workspace:
1183             self.assertFalse(workspace.exists())
1184             black.write_cache({}, [], mode)
1185             self.assertTrue(workspace.exists())
1186
1187     @event_loop(close=False)
1188     def test_failed_formatting_does_not_get_cached(self) -> None:
1189         mode = black.FileMode()
1190         with cache_dir() as workspace, patch(
1191             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1192         ):
1193             failing = (workspace / "failing.py").resolve()
1194             with failing.open("w") as fobj:
1195                 fobj.write("not actually python")
1196             clean = (workspace / "clean.py").resolve()
1197             with clean.open("w") as fobj:
1198                 fobj.write('print("hello")\n')
1199             self.invokeBlack([str(workspace)], exit_code=123)
1200             cache = black.read_cache(mode)
1201             self.assertNotIn(failing, cache)
1202             self.assertIn(clean, cache)
1203
1204     def test_write_cache_write_fail(self) -> None:
1205         mode = black.FileMode()
1206         with cache_dir(), patch.object(Path, "open") as mock:
1207             mock.side_effect = OSError
1208             black.write_cache({}, [], mode)
1209
1210     @event_loop(close=False)
1211     def test_check_diff_use_together(self) -> None:
1212         with cache_dir():
1213             # Files which will be reformatted.
1214             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1215             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1216             # Files which will not be reformatted.
1217             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1218             self.invokeBlack([str(src2), "--diff", "--check"])
1219             # Multi file command.
1220             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1221
1222     def test_no_files(self) -> None:
1223         with cache_dir():
1224             # Without an argument, black exits with error code 0.
1225             self.invokeBlack([])
1226
1227     def test_broken_symlink(self) -> None:
1228         with cache_dir() as workspace:
1229             symlink = workspace / "broken_link.py"
1230             try:
1231                 symlink.symlink_to("nonexistent.py")
1232             except OSError as e:
1233                 self.skipTest(f"Can't create symlinks: {e}")
1234             self.invokeBlack([str(workspace.resolve())])
1235
1236     def test_read_cache_line_lengths(self) -> None:
1237         mode = black.FileMode()
1238         short_mode = black.FileMode(line_length=1)
1239         with cache_dir() as workspace:
1240             path = (workspace / "file.py").resolve()
1241             path.touch()
1242             black.write_cache({}, [path], mode)
1243             one = black.read_cache(mode)
1244             self.assertIn(path, one)
1245             two = black.read_cache(short_mode)
1246             self.assertNotIn(path, two)
1247
1248     def test_tricky_unicode_symbols(self) -> None:
1249         source, expected = read_data("tricky_unicode_symbols")
1250         actual = fs(source)
1251         self.assertFormatEqual(expected, actual)
1252         black.assert_equivalent(source, actual)
1253         black.assert_stable(source, actual, black.FileMode())
1254
1255     def test_single_file_force_pyi(self) -> None:
1256         reg_mode = black.FileMode()
1257         pyi_mode = black.FileMode(is_pyi=True)
1258         contents, expected = read_data("force_pyi")
1259         with cache_dir() as workspace:
1260             path = (workspace / "file.py").resolve()
1261             with open(path, "w") as fh:
1262                 fh.write(contents)
1263             self.invokeBlack([str(path), "--pyi"])
1264             with open(path, "r") as fh:
1265                 actual = fh.read()
1266             # verify cache with --pyi is separate
1267             pyi_cache = black.read_cache(pyi_mode)
1268             self.assertIn(path, pyi_cache)
1269             normal_cache = black.read_cache(reg_mode)
1270             self.assertNotIn(path, normal_cache)
1271         self.assertEqual(actual, expected)
1272
1273     @event_loop(close=False)
1274     def test_multi_file_force_pyi(self) -> None:
1275         reg_mode = black.FileMode()
1276         pyi_mode = black.FileMode(is_pyi=True)
1277         contents, expected = read_data("force_pyi")
1278         with cache_dir() as workspace:
1279             paths = [
1280                 (workspace / "file1.py").resolve(),
1281                 (workspace / "file2.py").resolve(),
1282             ]
1283             for path in paths:
1284                 with open(path, "w") as fh:
1285                     fh.write(contents)
1286             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1287             for path in paths:
1288                 with open(path, "r") as fh:
1289                     actual = fh.read()
1290                 self.assertEqual(actual, expected)
1291             # verify cache with --pyi is separate
1292             pyi_cache = black.read_cache(pyi_mode)
1293             normal_cache = black.read_cache(reg_mode)
1294             for path in paths:
1295                 self.assertIn(path, pyi_cache)
1296                 self.assertNotIn(path, normal_cache)
1297
1298     def test_pipe_force_pyi(self) -> None:
1299         source, expected = read_data("force_pyi")
1300         result = CliRunner().invoke(
1301             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1302         )
1303         self.assertEqual(result.exit_code, 0)
1304         actual = result.output
1305         self.assertFormatEqual(actual, expected)
1306
1307     def test_single_file_force_py36(self) -> None:
1308         reg_mode = black.FileMode()
1309         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1310         source, expected = read_data("force_py36")
1311         with cache_dir() as workspace:
1312             path = (workspace / "file.py").resolve()
1313             with open(path, "w") as fh:
1314                 fh.write(source)
1315             self.invokeBlack([str(path), *PY36_ARGS])
1316             with open(path, "r") as fh:
1317                 actual = fh.read()
1318             # verify cache with --target-version is separate
1319             py36_cache = black.read_cache(py36_mode)
1320             self.assertIn(path, py36_cache)
1321             normal_cache = black.read_cache(reg_mode)
1322             self.assertNotIn(path, normal_cache)
1323         self.assertEqual(actual, expected)
1324
1325     @event_loop(close=False)
1326     def test_multi_file_force_py36(self) -> None:
1327         reg_mode = black.FileMode()
1328         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1329         source, expected = read_data("force_py36")
1330         with cache_dir() as workspace:
1331             paths = [
1332                 (workspace / "file1.py").resolve(),
1333                 (workspace / "file2.py").resolve(),
1334             ]
1335             for path in paths:
1336                 with open(path, "w") as fh:
1337                     fh.write(source)
1338             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1339             for path in paths:
1340                 with open(path, "r") as fh:
1341                     actual = fh.read()
1342                 self.assertEqual(actual, expected)
1343             # verify cache with --target-version is separate
1344             pyi_cache = black.read_cache(py36_mode)
1345             normal_cache = black.read_cache(reg_mode)
1346             for path in paths:
1347                 self.assertIn(path, pyi_cache)
1348                 self.assertNotIn(path, normal_cache)
1349
1350     def test_pipe_force_py36(self) -> None:
1351         source, expected = read_data("force_py36")
1352         result = CliRunner().invoke(
1353             black.main,
1354             ["-", "-q", "--target-version=py36"],
1355             input=BytesIO(source.encode("utf8")),
1356         )
1357         self.assertEqual(result.exit_code, 0)
1358         actual = result.output
1359         self.assertFormatEqual(actual, expected)
1360
1361     def test_include_exclude(self) -> None:
1362         path = THIS_DIR / "data" / "include_exclude_tests"
1363         include = re.compile(r"\.pyi?$")
1364         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1365         report = black.Report()
1366         sources: List[Path] = []
1367         expected = [
1368             Path(path / "b/dont_exclude/a.py"),
1369             Path(path / "b/dont_exclude/a.pyi"),
1370         ]
1371         this_abs = THIS_DIR.resolve()
1372         sources.extend(
1373             black.gen_python_files_in_dir(path, this_abs, include, exclude, report)
1374         )
1375         self.assertEqual(sorted(expected), sorted(sources))
1376
1377     def test_empty_include(self) -> None:
1378         path = THIS_DIR / "data" / "include_exclude_tests"
1379         report = black.Report()
1380         empty = re.compile(r"")
1381         sources: List[Path] = []
1382         expected = [
1383             Path(path / "b/exclude/a.pie"),
1384             Path(path / "b/exclude/a.py"),
1385             Path(path / "b/exclude/a.pyi"),
1386             Path(path / "b/dont_exclude/a.pie"),
1387             Path(path / "b/dont_exclude/a.py"),
1388             Path(path / "b/dont_exclude/a.pyi"),
1389             Path(path / "b/.definitely_exclude/a.pie"),
1390             Path(path / "b/.definitely_exclude/a.py"),
1391             Path(path / "b/.definitely_exclude/a.pyi"),
1392         ]
1393         this_abs = THIS_DIR.resolve()
1394         sources.extend(
1395             black.gen_python_files_in_dir(
1396                 path, this_abs, empty, re.compile(black.DEFAULT_EXCLUDES), report
1397             )
1398         )
1399         self.assertEqual(sorted(expected), sorted(sources))
1400
1401     def test_empty_exclude(self) -> None:
1402         path = THIS_DIR / "data" / "include_exclude_tests"
1403         report = black.Report()
1404         empty = re.compile(r"")
1405         sources: List[Path] = []
1406         expected = [
1407             Path(path / "b/dont_exclude/a.py"),
1408             Path(path / "b/dont_exclude/a.pyi"),
1409             Path(path / "b/exclude/a.py"),
1410             Path(path / "b/exclude/a.pyi"),
1411             Path(path / "b/.definitely_exclude/a.py"),
1412             Path(path / "b/.definitely_exclude/a.pyi"),
1413         ]
1414         this_abs = THIS_DIR.resolve()
1415         sources.extend(
1416             black.gen_python_files_in_dir(
1417                 path, this_abs, re.compile(black.DEFAULT_INCLUDES), empty, report
1418             )
1419         )
1420         self.assertEqual(sorted(expected), sorted(sources))
1421
1422     def test_invalid_include_exclude(self) -> None:
1423         for option in ["--include", "--exclude"]:
1424             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1425
1426     def test_preserves_line_endings(self) -> None:
1427         with TemporaryDirectory() as workspace:
1428             test_file = Path(workspace) / "test.py"
1429             for nl in ["\n", "\r\n"]:
1430                 contents = nl.join(["def f(  ):", "    pass"])
1431                 test_file.write_bytes(contents.encode())
1432                 ff(test_file, write_back=black.WriteBack.YES)
1433                 updated_contents: bytes = test_file.read_bytes()
1434                 self.assertIn(nl.encode(), updated_contents)
1435                 if nl == "\n":
1436                     self.assertNotIn(b"\r\n", updated_contents)
1437
1438     def test_preserves_line_endings_via_stdin(self) -> None:
1439         for nl in ["\n", "\r\n"]:
1440             contents = nl.join(["def f(  ):", "    pass"])
1441             runner = BlackRunner()
1442             result = runner.invoke(
1443                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1444             )
1445             self.assertEqual(result.exit_code, 0)
1446             output = runner.stdout_bytes
1447             self.assertIn(nl.encode("utf8"), output)
1448             if nl == "\n":
1449                 self.assertNotIn(b"\r\n", output)
1450
1451     def test_assert_equivalent_different_asts(self) -> None:
1452         with self.assertRaises(AssertionError):
1453             black.assert_equivalent("{}", "None")
1454
1455     def test_symlink_out_of_root_directory(self) -> None:
1456         path = MagicMock()
1457         root = THIS_DIR
1458         child = MagicMock()
1459         include = re.compile(black.DEFAULT_INCLUDES)
1460         exclude = re.compile(black.DEFAULT_EXCLUDES)
1461         report = black.Report()
1462         # `child` should behave like a symlink which resolved path is clearly
1463         # outside of the `root` directory.
1464         path.iterdir.return_value = [child]
1465         child.resolve.return_value = Path("/a/b/c")
1466         child.is_symlink.return_value = True
1467         try:
1468             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1469         except ValueError as ve:
1470             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1471         path.iterdir.assert_called_once()
1472         child.resolve.assert_called_once()
1473         child.is_symlink.assert_called_once()
1474         # `child` should behave like a strange file which resolved path is clearly
1475         # outside of the `root` directory.
1476         child.is_symlink.return_value = False
1477         with self.assertRaises(ValueError):
1478             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1479         path.iterdir.assert_called()
1480         self.assertEqual(path.iterdir.call_count, 2)
1481         child.resolve.assert_called()
1482         self.assertEqual(child.resolve.call_count, 2)
1483         child.is_symlink.assert_called()
1484         self.assertEqual(child.is_symlink.call_count, 2)
1485
1486     def test_shhh_click(self) -> None:
1487         try:
1488             from click import _unicodefun  # type: ignore
1489         except ModuleNotFoundError:
1490             self.skipTest("Incompatible Click version")
1491         if not hasattr(_unicodefun, "_verify_python3_env"):
1492             self.skipTest("Incompatible Click version")
1493         # First, let's see if Click is crashing with a preferred ASCII charset.
1494         with patch("locale.getpreferredencoding") as gpe:
1495             gpe.return_value = "ASCII"
1496             with self.assertRaises(RuntimeError):
1497                 _unicodefun._verify_python3_env()
1498         # Now, let's silence Click...
1499         black.patch_click()
1500         # ...and confirm it's silent.
1501         with patch("locale.getpreferredencoding") as gpe:
1502             gpe.return_value = "ASCII"
1503             try:
1504                 _unicodefun._verify_python3_env()
1505             except RuntimeError as re:
1506                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1507
1508     def test_root_logger_not_used_directly(self) -> None:
1509         def fail(*args: Any, **kwargs: Any) -> None:
1510             self.fail("Record created with root logger")
1511
1512         with patch.multiple(
1513             logging.root,
1514             debug=fail,
1515             info=fail,
1516             warning=fail,
1517             error=fail,
1518             critical=fail,
1519             log=fail,
1520         ):
1521             ff(THIS_FILE)
1522
1523     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1524     def test_blackd_main(self) -> None:
1525         with patch("blackd.web.run_app"):
1526             result = CliRunner().invoke(blackd.main, [])
1527             if result.exception is not None:
1528                 raise result.exception
1529             self.assertEqual(result.exit_code, 0)
1530
1531
if has_blackd_deps:

    class BlackDTestCase(AioHTTPTestCase):
        """HTTP-level tests that POST source code to a blackd application.

        The class is only defined when the optional blackd/aiohttp imports at the
        top of this module succeeded.  Its base class, the `unittest_run_loop`
        decorator and the `web.Application` annotation all come from those
        imports, so an unconditional definition raises NameError at import time
        when they are missing and takes every other test in this module with it.
        """

        async def get_application(self) -> web.Application:
            # Application served to `self.client` for each test.
            return blackd.make_app()

        # TODO: remove these decorators once the below is released
        # https://github.com/aio-libs/aiohttp/pull/3727
        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_request_needs_formatting(self) -> None:
            # Unformatted input: expect 200 with the reformatted source as body.
            response = await self.client.post("/", data=b"print('hello world')")
            self.assertEqual(response.status, 200)
            self.assertEqual(response.charset, "utf8")
            self.assertEqual(await response.read(), b'print("hello world")\n')

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_request_no_change(self) -> None:
            # Already formatted input: expect 204 with an empty body.
            response = await self.client.post("/", data=b'print("hello world")\n')
            self.assertEqual(response.status, 204)
            self.assertEqual(await response.read(), b"")

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_request_syntax_error(self) -> None:
            # Unparsable input: expect 400 and a "Cannot parse" message.
            response = await self.client.post("/", data=b"what even ( is")
            self.assertEqual(response.status, 400)
            content = await response.text()
            self.assertTrue(
                content.startswith("Cannot parse"),
                msg=f"Expected error to start with 'Cannot parse', got {repr(content)}",
            )

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_unsupported_version(self) -> None:
            # Protocol version "2" is expected to be rejected with 501.
            response = await self.client.post(
                "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "2"}
            )
            self.assertEqual(response.status, 501)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_supported_version(self) -> None:
            # Protocol version "1" is expected to be accepted.
            response = await self.client.post(
                "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "1"}
            )
            self.assertEqual(response.status, 200)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_invalid_python_variant(self) -> None:
            # Every listed Python-variant header value must be rejected with 400.
            async def check(header_value: str, expected_status: int = 400) -> None:
                response = await self.client.post(
                    "/",
                    data=b"what",
                    headers={blackd.PYTHON_VARIANT_HEADER: header_value},
                )
                self.assertEqual(response.status, expected_status)

            invalid_variants = (
                "lol",
                "ruby3.5",
                "pyi3.6",
                "py1.5",
                "2.8",
                "py2.8",
                "3.0",
                "pypy3.0",
                "jython3.4",
            )
            for variant in invalid_variants:
                await check(variant)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_pyi(self) -> None:
            # The "pyi" variant should produce the expected stub formatting.
            source, expected = read_data("stub.pyi")
            response = await self.client.post(
                "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
            )
            self.assertEqual(response.status, 200)
            self.assertEqual(await response.text(), expected)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_python_variant(self) -> None:
            # The same source is posted under different Python-variant headers;
            # 200 means blackd returned reformatted code, 204 means it left the
            # source unchanged.
            code = (
                "def f(\n"
                "    and_has_a_bunch_of,\n"
                "    very_long_arguments_too,\n"
                "    and_lots_of_them_as_well_lol,\n"
                "    **and_very_long_keyword_arguments\n"
                "):\n"
                "    pass\n"
            )

            async def check(header_value: str, expected_status: int) -> None:
                response = await self.client.post(
                    "/",
                    data=code,
                    headers={blackd.PYTHON_VARIANT_HEADER: header_value},
                )
                self.assertEqual(
                    response.status, expected_status, msg=await response.text()
                )

            reformatted = (
                "3.6",
                "py3.6",
                "3.6,3.7",
                "3.6,py3.7",
                "py36,py37",
                "36",
                "3.6.4",
            )
            for variant in reformatted:
                await check(variant, 200)

            unchanged = ("2", "2.7", "py2.7", "3.4", "py3.4", "py34,py36", "34")
            for variant in unchanged:
                await check(variant, 204)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_line_length(self) -> None:
            # A line-length header of "7" is expected to trigger reformatting (200).
            response = await self.client.post(
                "/",
                data=b'print("hello")\n',
                headers={blackd.LINE_LENGTH_HEADER: "7"},
            )
            self.assertEqual(response.status, 200)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_invalid_line_length(self) -> None:
            # A non-numeric line length must be rejected with 400.
            response = await self.client.post(
                "/",
                data=b'print("hello")\n',
                headers={blackd.LINE_LENGTH_HEADER: "NaN"},
            )
            self.assertEqual(response.status, 400)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_response_black_version_header(self) -> None:
            # Even a body-less request must get the Black version header back.
            response = await self.client.post("/")
            self.assertIsNotNone(response.headers.get(blackd.BLACK_VERSION_HEADER))
1678
1679
if __name__ == "__main__":
    # Executed as a script: hand control to unittest's command-line runner,
    # loading the tests from the module named "test_black".
    unittest.main(module="test_black")