git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

restore cursor to same line of code, not same line of buffer (#989)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 import logging
4 from concurrent.futures import ThreadPoolExecutor
5 from contextlib import contextmanager
6 from functools import partial
7 from io import BytesIO, TextIOWrapper
8 import os
9 from pathlib import Path
10 import regex as re
11 import sys
12 from tempfile import TemporaryDirectory
13 from typing import Any, BinaryIO, Generator, List, Tuple, Iterator, TypeVar
14 import unittest
15 from unittest.mock import patch, MagicMock
16
17 from click import unstyle
18 from click.testing import CliRunner
19
20 import black
21 from black import Feature, TargetVersion
22
23 try:
24     import blackd
25     from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop
26     from aiohttp import web
27 except ImportError:
28     has_blackd_deps = False
29 else:
30     has_blackd_deps = True
31
# Partially-applied formatters using the default mode: ``ff`` formats a file
# in place (with fast=True), ``fs`` formats a source string.
ff = partial(black.format_file_in_place, mode=black.FileMode(), fast=True)
fs = partial(black.format_str, mode=black.FileMode())
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# read_data() strips this marker from every line it reads. It is presumably
# assembled from two literals so the complete marker never appears verbatim in
# this file, which test_self also reads through read_data().
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
# One --target-version flag per version listed in black.PY36_VERSIONS.
PY36_ARGS = [f"--target-version={tv.name.lower()}" for tv in black.PY36_VERSIONS]
T = TypeVar("T")
R = TypeVar("R")
42
43
def dump_to_stderr(*output: str) -> str:
    """Join ``output`` with newlines and frame it with a leading/trailing newline.

    Tests patch this in for ``black.dump_to_file`` so that debug dumps come
    back as a string.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
46
47
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """Load a test case file and split it into (input, expected output).

    read_data('test_name') -> 'input', 'output'

    ``name`` gets a ``.py`` suffix unless it already ends in .py/.pyi/.out/.diff.
    The file is read from ``THIS_DIR / "data"`` when ``data`` is true, otherwise
    from ``THIS_DIR``. The EMPTY_LINE marker is removed from every line. Lines
    before a line consisting of the output marker comment form the input; lines
    after it form the expected output. If there is input but no output section,
    the input doubles as the expected output. Both halves are stripped and end
    with exactly one newline.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    base_dir = THIS_DIR / "data" if data else THIS_DIR
    with open(base_dir / name, "r", encoding="utf8") as test:
        lines = test.readlines()
    sections: Tuple[List[str], List[str]] = ([], [])
    current = 0
    for raw in lines:
        cleaned = raw.replace(EMPTY_LINE, "")
        if cleaned.rstrip() == "# output":
            # Everything after the marker (the marker itself excluded) is output.
            current = 1
        else:
            sections[current].append(cleaned)
    src_lines, out_lines = sections
    if src_lines and not out_lines:
        # No output section: the input is expected to already be formatted.
        out_lines = list(src_lines)
    return "".join(src_lines).strip() + "\n", "".join(out_lines).strip() + "\n"
69
70
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch ``black.CACHE_DIR`` to a throwaway directory while active.

    The temporary workspace itself is used when ``exists`` is true; otherwise
    its not-yet-created ``new`` subdirectory is used. Yields the patched path.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
79
80
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Install a brand-new asyncio event loop as the current loop.

    On exit the loop is closed only when ``close`` is true; in either case it
    remains set as the current event loop.
    """
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        if close:
            fresh_loop.close()
92
93
@contextmanager
def skip_if_exception(e: str) -> Iterator[None]:
    """Skip the running test if the body raises an exception class named ``e``.

    The exception's class ``__name__`` is compared with ``e``. On a match,
    ``unittest.SkipTest`` is raised (chained to the original exception); any
    other exception propagates unchanged.
    """
    try:
        yield
    except Exception as exc:
        if exc.__class__.__name__ == e:
            # unittest.skip() only builds a decorator and would silently
            # swallow the exception; raising SkipTest actually skips the test.
            raise unittest.SkipTest(
                f"Encountered expected exception {exc}, skipping"
            ) from exc
        raise
103
104
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    This is a hack that can be removed once we depend on Click 7.x

    After each invocation, ``stdout_bytes`` and ``stderr_bytes`` hold what the
    command wrote to stdout and stderr respectively.
    """

    def __init__(self) -> None:
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Run Click's isolation, but route ``sys.stderr`` into ``stderrbuf``."""
        with super().isolation(*args, **kwargs) as output:
            hold_stderr = sys.stderr
            err_stream = TextIOWrapper(self.stderrbuf, encoding=self.charset)
            sys.stderr = err_stream
            try:
                yield output
            finally:
                sys.stderr = hold_stderr
                # TextIOWrapper buffers text internally: flush so writes that
                # were never flushed explicitly still reach stderrbuf, then
                # detach so garbage-collecting the wrapper cannot close it.
                err_stream.flush()
                err_stream.detach()
                self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                self.stderr_bytes = self.stderrbuf.getvalue()
128
129
130 class BlackTestCase(unittest.TestCase):
131     maxDiff = None
132
133     def assertFormatEqual(self, expected: str, actual: str) -> None:
134         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
135             bdv: black.DebugVisitor[Any]
136             black.out("Expected tree:", fg="green")
137             try:
138                 exp_node = black.lib2to3_parse(expected)
139                 bdv = black.DebugVisitor()
140                 list(bdv.visit(exp_node))
141             except Exception as ve:
142                 black.err(str(ve))
143             black.out("Actual tree:", fg="red")
144             try:
145                 exp_node = black.lib2to3_parse(actual)
146                 bdv = black.DebugVisitor()
147                 list(bdv.visit(exp_node))
148             except Exception as ve:
149                 black.err(str(ve))
150         self.assertEqual(expected, actual)
151
152     def invokeBlack(
153         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
154     ) -> None:
155         runner = BlackRunner()
156         if ignore_config:
157             args = ["--config", str(THIS_DIR / "empty.toml"), *args]
158         result = runner.invoke(black.main, args)
159         self.assertEqual(result.exit_code, exit_code, msg=runner.stderr_bytes.decode())
160
161     @patch("black.dump_to_file", dump_to_stderr)
162     def checkSourceFile(self, name: str) -> None:
163         path = THIS_DIR.parent / name
164         source, expected = read_data(str(path), data=False)
165         actual = fs(source)
166         self.assertFormatEqual(expected, actual)
167         black.assert_equivalent(source, actual)
168         black.assert_stable(source, actual, black.FileMode())
169         self.assertFalse(ff(path))
170
171     @patch("black.dump_to_file", dump_to_stderr)
172     def test_empty(self) -> None:
173         source = expected = ""
174         actual = fs(source)
175         self.assertFormatEqual(expected, actual)
176         black.assert_equivalent(source, actual)
177         black.assert_stable(source, actual, black.FileMode())
178
179     def test_empty_ff(self) -> None:
180         expected = ""
181         tmp_file = Path(black.dump_to_file())
182         try:
183             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
184             with open(tmp_file, encoding="utf8") as f:
185                 actual = f.read()
186         finally:
187             os.unlink(tmp_file)
188         self.assertFormatEqual(expected, actual)
189
190     def test_self(self) -> None:
191         self.checkSourceFile("tests/test_black.py")
192
193     def test_black(self) -> None:
194         self.checkSourceFile("black.py")
195
196     def test_pygram(self) -> None:
197         self.checkSourceFile("blib2to3/pygram.py")
198
199     def test_pytree(self) -> None:
200         self.checkSourceFile("blib2to3/pytree.py")
201
202     def test_conv(self) -> None:
203         self.checkSourceFile("blib2to3/pgen2/conv.py")
204
205     def test_driver(self) -> None:
206         self.checkSourceFile("blib2to3/pgen2/driver.py")
207
208     def test_grammar(self) -> None:
209         self.checkSourceFile("blib2to3/pgen2/grammar.py")
210
211     def test_literals(self) -> None:
212         self.checkSourceFile("blib2to3/pgen2/literals.py")
213
214     def test_parse(self) -> None:
215         self.checkSourceFile("blib2to3/pgen2/parse.py")
216
217     def test_pgen(self) -> None:
218         self.checkSourceFile("blib2to3/pgen2/pgen.py")
219
220     def test_tokenize(self) -> None:
221         self.checkSourceFile("blib2to3/pgen2/tokenize.py")
222
223     def test_token(self) -> None:
224         self.checkSourceFile("blib2to3/pgen2/token.py")
225
226     def test_setup(self) -> None:
227         self.checkSourceFile("setup.py")
228
229     def test_piping(self) -> None:
230         source, expected = read_data("../black", data=False)
231         result = BlackRunner().invoke(
232             black.main,
233             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
234             input=BytesIO(source.encode("utf8")),
235         )
236         self.assertEqual(result.exit_code, 0)
237         self.assertFormatEqual(expected, result.output)
238         black.assert_equivalent(source, result.output)
239         black.assert_stable(source, result.output, black.FileMode())
240
241     def test_piping_diff(self) -> None:
242         diff_header = re.compile(
243             rf"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d "
244             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
245         )
246         source, _ = read_data("expression.py")
247         expected, _ = read_data("expression.diff")
248         config = THIS_DIR / "data" / "empty_pyproject.toml"
249         args = [
250             "-",
251             "--fast",
252             f"--line-length={black.DEFAULT_LINE_LENGTH}",
253             "--diff",
254             f"--config={config}",
255         ]
256         result = BlackRunner().invoke(
257             black.main, args, input=BytesIO(source.encode("utf8"))
258         )
259         self.assertEqual(result.exit_code, 0)
260         actual = diff_header.sub("[Deterministic header]", result.output)
261         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
262         self.assertEqual(expected, actual)
263
264     @patch("black.dump_to_file", dump_to_stderr)
265     def test_function(self) -> None:
266         source, expected = read_data("function")
267         actual = fs(source)
268         self.assertFormatEqual(expected, actual)
269         black.assert_equivalent(source, actual)
270         black.assert_stable(source, actual, black.FileMode())
271
272     @patch("black.dump_to_file", dump_to_stderr)
273     def test_function2(self) -> None:
274         source, expected = read_data("function2")
275         actual = fs(source)
276         self.assertFormatEqual(expected, actual)
277         black.assert_equivalent(source, actual)
278         black.assert_stable(source, actual, black.FileMode())
279
280     @patch("black.dump_to_file", dump_to_stderr)
281     def test_function_trailing_comma(self) -> None:
282         source, expected = read_data("function_trailing_comma")
283         actual = fs(source)
284         self.assertFormatEqual(expected, actual)
285         black.assert_equivalent(source, actual)
286         black.assert_stable(source, actual, black.FileMode())
287
288     @patch("black.dump_to_file", dump_to_stderr)
289     def test_expression(self) -> None:
290         source, expected = read_data("expression")
291         actual = fs(source)
292         self.assertFormatEqual(expected, actual)
293         black.assert_equivalent(source, actual)
294         black.assert_stable(source, actual, black.FileMode())
295
296     @patch("black.dump_to_file", dump_to_stderr)
297     def test_pep_572(self) -> None:
298         source, expected = read_data("pep_572")
299         actual = fs(source)
300         self.assertFormatEqual(expected, actual)
301         black.assert_stable(source, actual, black.FileMode())
302         if sys.version_info >= (3, 8):
303             black.assert_equivalent(source, actual)
304
305     def test_pep_572_version_detection(self) -> None:
306         source, _ = read_data("pep_572")
307         root = black.lib2to3_parse(source)
308         features = black.get_features_used(root)
309         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
310         versions = black.detect_target_versions(root)
311         self.assertIn(black.TargetVersion.PY38, versions)
312
313     def test_expression_ff(self) -> None:
314         source, expected = read_data("expression")
315         tmp_file = Path(black.dump_to_file(source))
316         try:
317             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
318             with open(tmp_file, encoding="utf8") as f:
319                 actual = f.read()
320         finally:
321             os.unlink(tmp_file)
322         self.assertFormatEqual(expected, actual)
323         with patch("black.dump_to_file", dump_to_stderr):
324             black.assert_equivalent(source, actual)
325             black.assert_stable(source, actual, black.FileMode())
326
327     def test_expression_diff(self) -> None:
328         source, _ = read_data("expression.py")
329         expected, _ = read_data("expression.diff")
330         tmp_file = Path(black.dump_to_file(source))
331         diff_header = re.compile(
332             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
333             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
334         )
335         try:
336             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
337             self.assertEqual(result.exit_code, 0)
338         finally:
339             os.unlink(tmp_file)
340         actual = result.output
341         actual = diff_header.sub("[Deterministic header]", actual)
342         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
343         if expected != actual:
344             dump = black.dump_to_file(actual)
345             msg = (
346                 f"Expected diff isn't equal to the actual. If you made changes "
347                 f"to expression.py and this is an anticipated difference, "
348                 f"overwrite tests/data/expression.diff with {dump}"
349             )
350             self.assertEqual(expected, actual, msg)
351
352     @patch("black.dump_to_file", dump_to_stderr)
353     def test_fstring(self) -> None:
354         source, expected = read_data("fstring")
355         actual = fs(source)
356         self.assertFormatEqual(expected, actual)
357         black.assert_equivalent(source, actual)
358         black.assert_stable(source, actual, black.FileMode())
359
360     @patch("black.dump_to_file", dump_to_stderr)
361     def test_pep_570(self) -> None:
362         source, expected = read_data("pep_570")
363         actual = fs(source)
364         self.assertFormatEqual(expected, actual)
365         black.assert_stable(source, actual, black.FileMode())
366         if sys.version_info >= (3, 8):
367             black.assert_equivalent(source, actual)
368
369     def test_detect_pos_only_arguments(self) -> None:
370         source, _ = read_data("pep_570")
371         root = black.lib2to3_parse(source)
372         features = black.get_features_used(root)
373         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
374         versions = black.detect_target_versions(root)
375         self.assertIn(black.TargetVersion.PY38, versions)
376
377     @patch("black.dump_to_file", dump_to_stderr)
378     def test_string_quotes(self) -> None:
379         source, expected = read_data("string_quotes")
380         actual = fs(source)
381         self.assertFormatEqual(expected, actual)
382         black.assert_equivalent(source, actual)
383         black.assert_stable(source, actual, black.FileMode())
384         mode = black.FileMode(string_normalization=False)
385         not_normalized = fs(source, mode=mode)
386         self.assertFormatEqual(source, not_normalized)
387         black.assert_equivalent(source, not_normalized)
388         black.assert_stable(source, not_normalized, mode=mode)
389
390     @patch("black.dump_to_file", dump_to_stderr)
391     def test_slices(self) -> None:
392         source, expected = read_data("slices")
393         actual = fs(source)
394         self.assertFormatEqual(expected, actual)
395         black.assert_equivalent(source, actual)
396         black.assert_stable(source, actual, black.FileMode())
397
398     @patch("black.dump_to_file", dump_to_stderr)
399     def test_comments(self) -> None:
400         source, expected = read_data("comments")
401         actual = fs(source)
402         self.assertFormatEqual(expected, actual)
403         black.assert_equivalent(source, actual)
404         black.assert_stable(source, actual, black.FileMode())
405
406     @patch("black.dump_to_file", dump_to_stderr)
407     def test_comments2(self) -> None:
408         source, expected = read_data("comments2")
409         actual = fs(source)
410         self.assertFormatEqual(expected, actual)
411         black.assert_equivalent(source, actual)
412         black.assert_stable(source, actual, black.FileMode())
413
414     @patch("black.dump_to_file", dump_to_stderr)
415     def test_comments3(self) -> None:
416         source, expected = read_data("comments3")
417         actual = fs(source)
418         self.assertFormatEqual(expected, actual)
419         black.assert_equivalent(source, actual)
420         black.assert_stable(source, actual, black.FileMode())
421
422     @patch("black.dump_to_file", dump_to_stderr)
423     def test_comments4(self) -> None:
424         source, expected = read_data("comments4")
425         actual = fs(source)
426         self.assertFormatEqual(expected, actual)
427         black.assert_equivalent(source, actual)
428         black.assert_stable(source, actual, black.FileMode())
429
430     @patch("black.dump_to_file", dump_to_stderr)
431     def test_comments5(self) -> None:
432         source, expected = read_data("comments5")
433         actual = fs(source)
434         self.assertFormatEqual(expected, actual)
435         black.assert_equivalent(source, actual)
436         black.assert_stable(source, actual, black.FileMode())
437
438     @patch("black.dump_to_file", dump_to_stderr)
439     def test_comments6(self) -> None:
440         source, expected = read_data("comments6")
441         actual = fs(source)
442         self.assertFormatEqual(expected, actual)
443         black.assert_equivalent(source, actual)
444         black.assert_stable(source, actual, black.FileMode())
445
446     @patch("black.dump_to_file", dump_to_stderr)
447     def test_comments7(self) -> None:
448         source, expected = read_data("comments7")
449         actual = fs(source)
450         self.assertFormatEqual(expected, actual)
451         black.assert_equivalent(source, actual)
452         black.assert_stable(source, actual, black.FileMode())
453
454     @patch("black.dump_to_file", dump_to_stderr)
455     def test_comment_after_escaped_newline(self) -> None:
456         source, expected = read_data("comment_after_escaped_newline")
457         actual = fs(source)
458         self.assertFormatEqual(expected, actual)
459         black.assert_equivalent(source, actual)
460         black.assert_stable(source, actual, black.FileMode())
461
462     @patch("black.dump_to_file", dump_to_stderr)
463     def test_cantfit(self) -> None:
464         source, expected = read_data("cantfit")
465         actual = fs(source)
466         self.assertFormatEqual(expected, actual)
467         black.assert_equivalent(source, actual)
468         black.assert_stable(source, actual, black.FileMode())
469
470     @patch("black.dump_to_file", dump_to_stderr)
471     def test_import_spacing(self) -> None:
472         source, expected = read_data("import_spacing")
473         actual = fs(source)
474         self.assertFormatEqual(expected, actual)
475         black.assert_equivalent(source, actual)
476         black.assert_stable(source, actual, black.FileMode())
477
478     @patch("black.dump_to_file", dump_to_stderr)
479     def test_composition(self) -> None:
480         source, expected = read_data("composition")
481         actual = fs(source)
482         self.assertFormatEqual(expected, actual)
483         black.assert_equivalent(source, actual)
484         black.assert_stable(source, actual, black.FileMode())
485
486     @patch("black.dump_to_file", dump_to_stderr)
487     def test_empty_lines(self) -> None:
488         source, expected = read_data("empty_lines")
489         actual = fs(source)
490         self.assertFormatEqual(expected, actual)
491         black.assert_equivalent(source, actual)
492         black.assert_stable(source, actual, black.FileMode())
493
494     @patch("black.dump_to_file", dump_to_stderr)
495     def test_remove_parens(self) -> None:
496         source, expected = read_data("remove_parens")
497         actual = fs(source)
498         self.assertFormatEqual(expected, actual)
499         black.assert_equivalent(source, actual)
500         black.assert_stable(source, actual, black.FileMode())
501
502     @patch("black.dump_to_file", dump_to_stderr)
503     def test_string_prefixes(self) -> None:
504         source, expected = read_data("string_prefixes")
505         actual = fs(source)
506         self.assertFormatEqual(expected, actual)
507         black.assert_equivalent(source, actual)
508         black.assert_stable(source, actual, black.FileMode())
509
510     @patch("black.dump_to_file", dump_to_stderr)
511     def test_numeric_literals(self) -> None:
512         source, expected = read_data("numeric_literals")
513         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
514         actual = fs(source, mode=mode)
515         self.assertFormatEqual(expected, actual)
516         black.assert_equivalent(source, actual)
517         black.assert_stable(source, actual, mode)
518
519     @patch("black.dump_to_file", dump_to_stderr)
520     def test_numeric_literals_ignoring_underscores(self) -> None:
521         source, expected = read_data("numeric_literals_skip_underscores")
522         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
523         actual = fs(source, mode=mode)
524         self.assertFormatEqual(expected, actual)
525         black.assert_equivalent(source, actual)
526         black.assert_stable(source, actual, mode)
527
528     @patch("black.dump_to_file", dump_to_stderr)
529     def test_numeric_literals_py2(self) -> None:
530         source, expected = read_data("numeric_literals_py2")
531         actual = fs(source)
532         self.assertFormatEqual(expected, actual)
533         black.assert_stable(source, actual, black.FileMode())
534
535     @patch("black.dump_to_file", dump_to_stderr)
536     def test_python2(self) -> None:
537         source, expected = read_data("python2")
538         actual = fs(source)
539         self.assertFormatEqual(expected, actual)
540         black.assert_equivalent(source, actual)
541         black.assert_stable(source, actual, black.FileMode())
542
543     @patch("black.dump_to_file", dump_to_stderr)
544     def test_python2_print_function(self) -> None:
545         source, expected = read_data("python2_print_function")
546         mode = black.FileMode(target_versions={TargetVersion.PY27})
547         actual = fs(source, mode=mode)
548         self.assertFormatEqual(expected, actual)
549         black.assert_equivalent(source, actual)
550         black.assert_stable(source, actual, mode)
551
552     @patch("black.dump_to_file", dump_to_stderr)
553     def test_python2_unicode_literals(self) -> None:
554         source, expected = read_data("python2_unicode_literals")
555         actual = fs(source)
556         self.assertFormatEqual(expected, actual)
557         black.assert_equivalent(source, actual)
558         black.assert_stable(source, actual, black.FileMode())
559
560     @patch("black.dump_to_file", dump_to_stderr)
561     def test_stub(self) -> None:
562         mode = black.FileMode(is_pyi=True)
563         source, expected = read_data("stub.pyi")
564         actual = fs(source, mode=mode)
565         self.assertFormatEqual(expected, actual)
566         black.assert_stable(source, actual, mode)
567
568     @patch("black.dump_to_file", dump_to_stderr)
569     def test_async_as_identifier(self) -> None:
570         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
571         source, expected = read_data("async_as_identifier")
572         actual = fs(source)
573         self.assertFormatEqual(expected, actual)
574         major, minor = sys.version_info[:2]
575         if major < 3 or (major <= 3 and minor < 7):
576             black.assert_equivalent(source, actual)
577         black.assert_stable(source, actual, black.FileMode())
578         # ensure black can parse this when the target is 3.6
579         self.invokeBlack([str(source_path), "--target-version", "py36"])
580         # but not on 3.7, because async/await is no longer an identifier
581         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
582
583     @patch("black.dump_to_file", dump_to_stderr)
584     def test_python37(self) -> None:
585         source_path = (THIS_DIR / "data" / "python37.py").resolve()
586         source, expected = read_data("python37")
587         actual = fs(source)
588         self.assertFormatEqual(expected, actual)
589         major, minor = sys.version_info[:2]
590         if major > 3 or (major == 3 and minor >= 7):
591             black.assert_equivalent(source, actual)
592         black.assert_stable(source, actual, black.FileMode())
593         # ensure black can parse this when the target is 3.7
594         self.invokeBlack([str(source_path), "--target-version", "py37"])
595         # but not on 3.6, because we use async as a reserved keyword
596         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
597
598     @patch("black.dump_to_file", dump_to_stderr)
599     def test_fmtonoff(self) -> None:
600         source, expected = read_data("fmtonoff")
601         actual = fs(source)
602         self.assertFormatEqual(expected, actual)
603         black.assert_equivalent(source, actual)
604         black.assert_stable(source, actual, black.FileMode())
605
606     @patch("black.dump_to_file", dump_to_stderr)
607     def test_fmtonoff2(self) -> None:
608         source, expected = read_data("fmtonoff2")
609         actual = fs(source)
610         self.assertFormatEqual(expected, actual)
611         black.assert_equivalent(source, actual)
612         black.assert_stable(source, actual, black.FileMode())
613
614     @patch("black.dump_to_file", dump_to_stderr)
615     def test_remove_empty_parentheses_after_class(self) -> None:
616         source, expected = read_data("class_blank_parentheses")
617         actual = fs(source)
618         self.assertFormatEqual(expected, actual)
619         black.assert_equivalent(source, actual)
620         black.assert_stable(source, actual, black.FileMode())
621
622     @patch("black.dump_to_file", dump_to_stderr)
623     def test_new_line_between_class_and_code(self) -> None:
624         source, expected = read_data("class_methods_new_line")
625         actual = fs(source)
626         self.assertFormatEqual(expected, actual)
627         black.assert_equivalent(source, actual)
628         black.assert_stable(source, actual, black.FileMode())
629
630     @patch("black.dump_to_file", dump_to_stderr)
631     def test_bracket_match(self) -> None:
632         source, expected = read_data("bracketmatch")
633         actual = fs(source)
634         self.assertFormatEqual(expected, actual)
635         black.assert_equivalent(source, actual)
636         black.assert_stable(source, actual, black.FileMode())
637
638     @patch("black.dump_to_file", dump_to_stderr)
639     def test_tuple_assign(self) -> None:
640         source, expected = read_data("tupleassign")
641         actual = fs(source)
642         self.assertFormatEqual(expected, actual)
643         black.assert_equivalent(source, actual)
644         black.assert_stable(source, actual, black.FileMode())
645
646     @patch("black.dump_to_file", dump_to_stderr)
647     def test_beginning_backslash(self) -> None:
648         source, expected = read_data("beginning_backslash")
649         actual = fs(source)
650         self.assertFormatEqual(expected, actual)
651         black.assert_equivalent(source, actual)
652         black.assert_stable(source, actual, black.FileMode())
653
654     def test_tab_comment_indentation(self) -> None:
655         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
656         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
657         self.assertFormatEqual(contents_spc, fs(contents_spc))
658         self.assertFormatEqual(contents_spc, fs(contents_tab))
659
660         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
661         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
662         self.assertFormatEqual(contents_spc, fs(contents_spc))
663         self.assertFormatEqual(contents_spc, fs(contents_tab))
664
665         # mixed tabs and spaces (valid Python 2 code)
666         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
667         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
668         self.assertFormatEqual(contents_spc, fs(contents_spc))
669         self.assertFormatEqual(contents_spc, fs(contents_tab))
670
671         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
672         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
673         self.assertFormatEqual(contents_spc, fs(contents_spc))
674         self.assertFormatEqual(contents_spc, fs(contents_tab))
675
    def test_report_verbose(self) -> None:
        """Walk a verbose ``black.Report`` through every kind of event.

        ``black.out`` and ``black.err`` are replaced with collectors so each step
        can check which stream received a message, the latest message text, the
        rendered summary (``unstyle(str(report))``) and ``report.return_code``.
        In verbose mode every event is echoed, including unchanged, cached and
        ignored files.
        """
        report = black.Report(verbose=True)
        # Messages captured from the patched black.out / black.err.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            # Keep only the message text; styling kwargs are ignored.
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file is counted as unchanged but gets its own message.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, a pending reformat turns the result into 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures are written to err and force return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are echoed in verbose mode but not counted in the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # In check mode the summary switches to conditional wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )
753             self.assertEqual(report.return_code, 123)
754             report.done(Path("f4"), black.Changed.NO)
755             self.assertEqual(len(out_lines), 6)
756             self.assertEqual(len(err_lines), 2)
757             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
758             self.assertEqual(
759                 unstyle(str(report)),
760                 "2 files reformatted, 3 files left unchanged, "
761                 "2 files failed to reformat.",
762             )
763             self.assertEqual(report.return_code, 123)
764             report.check = True
765             self.assertEqual(
766                 unstyle(str(report)),
767                 "2 files would be reformatted, 3 files would be left unchanged, "
768                 "2 files would fail to reformat.",
769             )
770
771     def test_report_quiet(self) -> None:
772         report = black.Report(quiet=True)
773         out_lines = []
774         err_lines = []
775
776         def out(msg: str, **kwargs: Any) -> None:
777             out_lines.append(msg)
778
779         def err(msg: str, **kwargs: Any) -> None:
780             err_lines.append(msg)
781
782         with patch("black.out", out), patch("black.err", err):
783             report.done(Path("f1"), black.Changed.NO)
784             self.assertEqual(len(out_lines), 0)
785             self.assertEqual(len(err_lines), 0)
786             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
787             self.assertEqual(report.return_code, 0)
788             report.done(Path("f2"), black.Changed.YES)
789             self.assertEqual(len(out_lines), 0)
790             self.assertEqual(len(err_lines), 0)
791             self.assertEqual(
792                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
793             )
794             report.done(Path("f3"), black.Changed.CACHED)
795             self.assertEqual(len(out_lines), 0)
796             self.assertEqual(len(err_lines), 0)
797             self.assertEqual(
798                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
799             )
800             self.assertEqual(report.return_code, 0)
801             report.check = True
802             self.assertEqual(report.return_code, 1)
803             report.check = False
804             report.failed(Path("e1"), "boom")
805             self.assertEqual(len(out_lines), 0)
806             self.assertEqual(len(err_lines), 1)
807             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
808             self.assertEqual(
809                 unstyle(str(report)),
810                 "1 file reformatted, 2 files left unchanged, "
811                 "1 file failed to reformat.",
812             )
813             self.assertEqual(report.return_code, 123)
814             report.done(Path("f3"), black.Changed.YES)
815             self.assertEqual(len(out_lines), 0)
816             self.assertEqual(len(err_lines), 1)
817             self.assertEqual(
818                 unstyle(str(report)),
819                 "2 files reformatted, 2 files left unchanged, "
820                 "1 file failed to reformat.",
821             )
822             self.assertEqual(report.return_code, 123)
823             report.failed(Path("e2"), "boom")
824             self.assertEqual(len(out_lines), 0)
825             self.assertEqual(len(err_lines), 2)
826             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
827             self.assertEqual(
828                 unstyle(str(report)),
829                 "2 files reformatted, 2 files left unchanged, "
830                 "2 files failed to reformat.",
831             )
832             self.assertEqual(report.return_code, 123)
833             report.path_ignored(Path("wat"), "no match")
834             self.assertEqual(len(out_lines), 0)
835             self.assertEqual(len(err_lines), 2)
836             self.assertEqual(
837                 unstyle(str(report)),
838                 "2 files reformatted, 2 files left unchanged, "
839                 "2 files failed to reformat.",
840             )
841             self.assertEqual(report.return_code, 123)
842             report.done(Path("f4"), black.Changed.NO)
843             self.assertEqual(len(out_lines), 0)
844             self.assertEqual(len(err_lines), 2)
845             self.assertEqual(
846                 unstyle(str(report)),
847                 "2 files reformatted, 3 files left unchanged, "
848                 "2 files failed to reformat.",
849             )
850             self.assertEqual(report.return_code, 123)
851             report.check = True
852             self.assertEqual(
853                 unstyle(str(report)),
854                 "2 files would be reformatted, 3 files would be left unchanged, "
855                 "2 files would fail to reformat.",
856             )
857
858     def test_report_normal(self) -> None:
859         report = black.Report()
860         out_lines = []
861         err_lines = []
862
863         def out(msg: str, **kwargs: Any) -> None:
864             out_lines.append(msg)
865
866         def err(msg: str, **kwargs: Any) -> None:
867             err_lines.append(msg)
868
869         with patch("black.out", out), patch("black.err", err):
870             report.done(Path("f1"), black.Changed.NO)
871             self.assertEqual(len(out_lines), 0)
872             self.assertEqual(len(err_lines), 0)
873             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
874             self.assertEqual(report.return_code, 0)
875             report.done(Path("f2"), black.Changed.YES)
876             self.assertEqual(len(out_lines), 1)
877             self.assertEqual(len(err_lines), 0)
878             self.assertEqual(out_lines[-1], "reformatted f2")
879             self.assertEqual(
880                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
881             )
882             report.done(Path("f3"), black.Changed.CACHED)
883             self.assertEqual(len(out_lines), 1)
884             self.assertEqual(len(err_lines), 0)
885             self.assertEqual(out_lines[-1], "reformatted f2")
886             self.assertEqual(
887                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
888             )
889             self.assertEqual(report.return_code, 0)
890             report.check = True
891             self.assertEqual(report.return_code, 1)
892             report.check = False
893             report.failed(Path("e1"), "boom")
894             self.assertEqual(len(out_lines), 1)
895             self.assertEqual(len(err_lines), 1)
896             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
897             self.assertEqual(
898                 unstyle(str(report)),
899                 "1 file reformatted, 2 files left unchanged, "
900                 "1 file failed to reformat.",
901             )
902             self.assertEqual(report.return_code, 123)
903             report.done(Path("f3"), black.Changed.YES)
904             self.assertEqual(len(out_lines), 2)
905             self.assertEqual(len(err_lines), 1)
906             self.assertEqual(out_lines[-1], "reformatted f3")
907             self.assertEqual(
908                 unstyle(str(report)),
909                 "2 files reformatted, 2 files left unchanged, "
910                 "1 file failed to reformat.",
911             )
912             self.assertEqual(report.return_code, 123)
913             report.failed(Path("e2"), "boom")
914             self.assertEqual(len(out_lines), 2)
915             self.assertEqual(len(err_lines), 2)
916             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
917             self.assertEqual(
918                 unstyle(str(report)),
919                 "2 files reformatted, 2 files left unchanged, "
920                 "2 files failed to reformat.",
921             )
922             self.assertEqual(report.return_code, 123)
923             report.path_ignored(Path("wat"), "no match")
924             self.assertEqual(len(out_lines), 2)
925             self.assertEqual(len(err_lines), 2)
926             self.assertEqual(
927                 unstyle(str(report)),
928                 "2 files reformatted, 2 files left unchanged, "
929                 "2 files failed to reformat.",
930             )
931             self.assertEqual(report.return_code, 123)
932             report.done(Path("f4"), black.Changed.NO)
933             self.assertEqual(len(out_lines), 2)
934             self.assertEqual(len(err_lines), 2)
935             self.assertEqual(
936                 unstyle(str(report)),
937                 "2 files reformatted, 3 files left unchanged, "
938                 "2 files failed to reformat.",
939             )
940             self.assertEqual(report.return_code, 123)
941             report.check = True
942             self.assertEqual(
943                 unstyle(str(report)),
944                 "2 files would be reformatted, 3 files would be left unchanged, "
945                 "2 files would fail to reformat.",
946             )
947
948     def test_lib2to3_parse(self) -> None:
949         with self.assertRaises(black.InvalidInput):
950             black.lib2to3_parse("invalid syntax")
951
952         straddling = "x + y"
953         black.lib2to3_parse(straddling)
954         black.lib2to3_parse(straddling, {TargetVersion.PY27})
955         black.lib2to3_parse(straddling, {TargetVersion.PY36})
956         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
957
958         py2_only = "print x"
959         black.lib2to3_parse(py2_only)
960         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
961         with self.assertRaises(black.InvalidInput):
962             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
963         with self.assertRaises(black.InvalidInput):
964             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
965
966         py3_only = "exec(x, end=y)"
967         black.lib2to3_parse(py3_only)
968         with self.assertRaises(black.InvalidInput):
969             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
970         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
971         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
972
973     def test_get_features_used(self) -> None:
974         node = black.lib2to3_parse("def f(*, arg): ...\n")
975         self.assertEqual(black.get_features_used(node), set())
976         node = black.lib2to3_parse("def f(*, arg,): ...\n")
977         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
978         node = black.lib2to3_parse("f(*arg,)\n")
979         self.assertEqual(
980             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
981         )
982         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
983         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
984         node = black.lib2to3_parse("123_456\n")
985         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
986         node = black.lib2to3_parse("123456\n")
987         self.assertEqual(black.get_features_used(node), set())
988         source, expected = read_data("function")
989         node = black.lib2to3_parse(source)
990         expected_features = {
991             Feature.TRAILING_COMMA_IN_CALL,
992             Feature.TRAILING_COMMA_IN_DEF,
993             Feature.F_STRINGS,
994         }
995         self.assertEqual(black.get_features_used(node), expected_features)
996         node = black.lib2to3_parse(expected)
997         self.assertEqual(black.get_features_used(node), expected_features)
998         source, expected = read_data("expression")
999         node = black.lib2to3_parse(source)
1000         self.assertEqual(black.get_features_used(node), set())
1001         node = black.lib2to3_parse(expected)
1002         self.assertEqual(black.get_features_used(node), set())
1003
1004     def test_get_future_imports(self) -> None:
1005         node = black.lib2to3_parse("\n")
1006         self.assertEqual(set(), black.get_future_imports(node))
1007         node = black.lib2to3_parse("from __future__ import black\n")
1008         self.assertEqual({"black"}, black.get_future_imports(node))
1009         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
1010         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
1011         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
1012         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
1013         node = black.lib2to3_parse(
1014             "from __future__ import multiple\nfrom __future__ import imports\n"
1015         )
1016         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
1017         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
1018         self.assertEqual({"black"}, black.get_future_imports(node))
1019         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
1020         self.assertEqual({"black"}, black.get_future_imports(node))
1021         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
1022         self.assertEqual(set(), black.get_future_imports(node))
1023         node = black.lib2to3_parse("from some.module import black\n")
1024         self.assertEqual(set(), black.get_future_imports(node))
1025         node = black.lib2to3_parse(
1026             "from __future__ import unicode_literals as _unicode_literals"
1027         )
1028         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
1029         node = black.lib2to3_parse(
1030             "from __future__ import unicode_literals as _lol, print"
1031         )
1032         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
1033
1034     def test_debug_visitor(self) -> None:
1035         source, _ = read_data("debug_visitor.py")
1036         expected, _ = read_data("debug_visitor.out")
1037         out_lines = []
1038         err_lines = []
1039
1040         def out(msg: str, **kwargs: Any) -> None:
1041             out_lines.append(msg)
1042
1043         def err(msg: str, **kwargs: Any) -> None:
1044             err_lines.append(msg)
1045
1046         with patch("black.out", out), patch("black.err", err):
1047             black.DebugVisitor.show(source)
1048         actual = "\n".join(out_lines) + "\n"
1049         log_name = ""
1050         if expected != actual:
1051             log_name = black.dump_to_file(*out_lines)
1052         self.assertEqual(
1053             expected,
1054             actual,
1055             f"AST print out is different. Actual version dumped to {log_name}",
1056         )
1057
1058     def test_format_file_contents(self) -> None:
1059         empty = ""
1060         mode = black.FileMode()
1061         with self.assertRaises(black.NothingChanged):
1062             black.format_file_contents(empty, mode=mode, fast=False)
1063         just_nl = "\n"
1064         with self.assertRaises(black.NothingChanged):
1065             black.format_file_contents(just_nl, mode=mode, fast=False)
1066         same = "j = [1, 2, 3]\n"
1067         with self.assertRaises(black.NothingChanged):
1068             black.format_file_contents(same, mode=mode, fast=False)
1069         different = "j = [1,2,3]"
1070         expected = same
1071         actual = black.format_file_contents(different, mode=mode, fast=False)
1072         self.assertEqual(expected, actual)
1073         invalid = "return if you can"
1074         with self.assertRaises(black.InvalidInput) as e:
1075             black.format_file_contents(invalid, mode=mode, fast=False)
1076         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1077
1078     def test_endmarker(self) -> None:
1079         n = black.lib2to3_parse("\n")
1080         self.assertEqual(n.type, black.syms.file_input)
1081         self.assertEqual(len(n.children), 1)
1082         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1083
1084     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1085     def test_assertFormatEqual(self) -> None:
1086         out_lines = []
1087         err_lines = []
1088
1089         def out(msg: str, **kwargs: Any) -> None:
1090             out_lines.append(msg)
1091
1092         def err(msg: str, **kwargs: Any) -> None:
1093             err_lines.append(msg)
1094
1095         with patch("black.out", out), patch("black.err", err):
1096             with self.assertRaises(AssertionError):
1097                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1098
1099         out_str = "".join(out_lines)
1100         self.assertTrue("Expected tree:" in out_str)
1101         self.assertTrue("Actual tree:" in out_str)
1102         self.assertEqual("".join(err_lines), "")
1103
1104     def test_cache_broken_file(self) -> None:
1105         mode = black.FileMode()
1106         with cache_dir() as workspace:
1107             cache_file = black.get_cache_file(mode)
1108             with cache_file.open("w") as fobj:
1109                 fobj.write("this is not a pickle")
1110             self.assertEqual(black.read_cache(mode), {})
1111             src = (workspace / "test.py").resolve()
1112             with src.open("w") as fobj:
1113                 fobj.write("print('hello')")
1114             self.invokeBlack([str(src)])
1115             cache = black.read_cache(mode)
1116             self.assertIn(src, cache)
1117
1118     def test_cache_single_file_already_cached(self) -> None:
1119         mode = black.FileMode()
1120         with cache_dir() as workspace:
1121             src = (workspace / "test.py").resolve()
1122             with src.open("w") as fobj:
1123                 fobj.write("print('hello')")
1124             black.write_cache({}, [src], mode)
1125             self.invokeBlack([str(src)])
1126             with src.open("r") as fobj:
1127                 self.assertEqual(fobj.read(), "print('hello')")
1128
1129     @event_loop(close=False)
1130     def test_cache_multiple_files(self) -> None:
1131         mode = black.FileMode()
1132         with cache_dir() as workspace, patch(
1133             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1134         ):
1135             one = (workspace / "one.py").resolve()
1136             with one.open("w") as fobj:
1137                 fobj.write("print('hello')")
1138             two = (workspace / "two.py").resolve()
1139             with two.open("w") as fobj:
1140                 fobj.write("print('hello')")
1141             black.write_cache({}, [one], mode)
1142             self.invokeBlack([str(workspace)])
1143             with one.open("r") as fobj:
1144                 self.assertEqual(fobj.read(), "print('hello')")
1145             with two.open("r") as fobj:
1146                 self.assertEqual(fobj.read(), 'print("hello")\n')
1147             cache = black.read_cache(mode)
1148             self.assertIn(one, cache)
1149             self.assertIn(two, cache)
1150
1151     def test_no_cache_when_writeback_diff(self) -> None:
1152         mode = black.FileMode()
1153         with cache_dir() as workspace:
1154             src = (workspace / "test.py").resolve()
1155             with src.open("w") as fobj:
1156                 fobj.write("print('hello')")
1157             self.invokeBlack([str(src), "--diff"])
1158             cache_file = black.get_cache_file(mode)
1159             self.assertFalse(cache_file.exists())
1160
1161     def test_no_cache_when_stdin(self) -> None:
1162         mode = black.FileMode()
1163         with cache_dir():
1164             result = CliRunner().invoke(
1165                 black.main, ["-"], input=BytesIO(b"print('hello')")
1166             )
1167             self.assertEqual(result.exit_code, 0)
1168             cache_file = black.get_cache_file(mode)
1169             self.assertFalse(cache_file.exists())
1170
1171     def test_read_cache_no_cachefile(self) -> None:
1172         mode = black.FileMode()
1173         with cache_dir():
1174             self.assertEqual(black.read_cache(mode), {})
1175
1176     def test_write_cache_read_cache(self) -> None:
1177         mode = black.FileMode()
1178         with cache_dir() as workspace:
1179             src = (workspace / "test.py").resolve()
1180             src.touch()
1181             black.write_cache({}, [src], mode)
1182             cache = black.read_cache(mode)
1183             self.assertIn(src, cache)
1184             self.assertEqual(cache[src], black.get_cache_info(src))
1185
1186     def test_filter_cached(self) -> None:
1187         with TemporaryDirectory() as workspace:
1188             path = Path(workspace)
1189             uncached = (path / "uncached").resolve()
1190             cached = (path / "cached").resolve()
1191             cached_but_changed = (path / "changed").resolve()
1192             uncached.touch()
1193             cached.touch()
1194             cached_but_changed.touch()
1195             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
1196             todo, done = black.filter_cached(
1197                 cache, {uncached, cached, cached_but_changed}
1198             )
1199             self.assertEqual(todo, {uncached, cached_but_changed})
1200             self.assertEqual(done, {cached})
1201
1202     def test_write_cache_creates_directory_if_needed(self) -> None:
1203         mode = black.FileMode()
1204         with cache_dir(exists=False) as workspace:
1205             self.assertFalse(workspace.exists())
1206             black.write_cache({}, [], mode)
1207             self.assertTrue(workspace.exists())
1208
1209     @event_loop(close=False)
1210     def test_failed_formatting_does_not_get_cached(self) -> None:
1211         mode = black.FileMode()
1212         with cache_dir() as workspace, patch(
1213             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1214         ):
1215             failing = (workspace / "failing.py").resolve()
1216             with failing.open("w") as fobj:
1217                 fobj.write("not actually python")
1218             clean = (workspace / "clean.py").resolve()
1219             with clean.open("w") as fobj:
1220                 fobj.write('print("hello")\n')
1221             self.invokeBlack([str(workspace)], exit_code=123)
1222             cache = black.read_cache(mode)
1223             self.assertNotIn(failing, cache)
1224             self.assertIn(clean, cache)
1225
1226     def test_write_cache_write_fail(self) -> None:
1227         mode = black.FileMode()
1228         with cache_dir(), patch.object(Path, "open") as mock:
1229             mock.side_effect = OSError
1230             black.write_cache({}, [], mode)
1231
1232     @event_loop(close=False)
1233     def test_check_diff_use_together(self) -> None:
1234         with cache_dir():
1235             # Files which will be reformatted.
1236             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1237             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1238             # Files which will not be reformatted.
1239             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1240             self.invokeBlack([str(src2), "--diff", "--check"])
1241             # Multi file command.
1242             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1243
1244     def test_no_files(self) -> None:
1245         with cache_dir():
1246             # Without an argument, black exits with error code 0.
1247             self.invokeBlack([])
1248
1249     def test_broken_symlink(self) -> None:
1250         with cache_dir() as workspace:
1251             symlink = workspace / "broken_link.py"
1252             try:
1253                 symlink.symlink_to("nonexistent.py")
1254             except OSError as e:
1255                 self.skipTest(f"Can't create symlinks: {e}")
1256             self.invokeBlack([str(workspace.resolve())])
1257
1258     def test_read_cache_line_lengths(self) -> None:
1259         mode = black.FileMode()
1260         short_mode = black.FileMode(line_length=1)
1261         with cache_dir() as workspace:
1262             path = (workspace / "file.py").resolve()
1263             path.touch()
1264             black.write_cache({}, [path], mode)
1265             one = black.read_cache(mode)
1266             self.assertIn(path, one)
1267             two = black.read_cache(short_mode)
1268             self.assertNotIn(path, two)
1269
1270     def test_tricky_unicode_symbols(self) -> None:
1271         source, expected = read_data("tricky_unicode_symbols")
1272         actual = fs(source)
1273         self.assertFormatEqual(expected, actual)
1274         black.assert_equivalent(source, actual)
1275         black.assert_stable(source, actual, black.FileMode())
1276
1277     def test_single_file_force_pyi(self) -> None:
1278         reg_mode = black.FileMode()
1279         pyi_mode = black.FileMode(is_pyi=True)
1280         contents, expected = read_data("force_pyi")
1281         with cache_dir() as workspace:
1282             path = (workspace / "file.py").resolve()
1283             with open(path, "w") as fh:
1284                 fh.write(contents)
1285             self.invokeBlack([str(path), "--pyi"])
1286             with open(path, "r") as fh:
1287                 actual = fh.read()
1288             # verify cache with --pyi is separate
1289             pyi_cache = black.read_cache(pyi_mode)
1290             self.assertIn(path, pyi_cache)
1291             normal_cache = black.read_cache(reg_mode)
1292             self.assertNotIn(path, normal_cache)
1293         self.assertEqual(actual, expected)
1294
1295     @event_loop(close=False)
1296     def test_multi_file_force_pyi(self) -> None:
1297         reg_mode = black.FileMode()
1298         pyi_mode = black.FileMode(is_pyi=True)
1299         contents, expected = read_data("force_pyi")
1300         with cache_dir() as workspace:
1301             paths = [
1302                 (workspace / "file1.py").resolve(),
1303                 (workspace / "file2.py").resolve(),
1304             ]
1305             for path in paths:
1306                 with open(path, "w") as fh:
1307                     fh.write(contents)
1308             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1309             for path in paths:
1310                 with open(path, "r") as fh:
1311                     actual = fh.read()
1312                 self.assertEqual(actual, expected)
1313             # verify cache with --pyi is separate
1314             pyi_cache = black.read_cache(pyi_mode)
1315             normal_cache = black.read_cache(reg_mode)
1316             for path in paths:
1317                 self.assertIn(path, pyi_cache)
1318                 self.assertNotIn(path, normal_cache)
1319
1320     def test_pipe_force_pyi(self) -> None:
1321         source, expected = read_data("force_pyi")
1322         result = CliRunner().invoke(
1323             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1324         )
1325         self.assertEqual(result.exit_code, 0)
1326         actual = result.output
1327         self.assertFormatEqual(actual, expected)
1328
1329     def test_single_file_force_py36(self) -> None:
1330         reg_mode = black.FileMode()
1331         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1332         source, expected = read_data("force_py36")
1333         with cache_dir() as workspace:
1334             path = (workspace / "file.py").resolve()
1335             with open(path, "w") as fh:
1336                 fh.write(source)
1337             self.invokeBlack([str(path), *PY36_ARGS])
1338             with open(path, "r") as fh:
1339                 actual = fh.read()
1340             # verify cache with --target-version is separate
1341             py36_cache = black.read_cache(py36_mode)
1342             self.assertIn(path, py36_cache)
1343             normal_cache = black.read_cache(reg_mode)
1344             self.assertNotIn(path, normal_cache)
1345         self.assertEqual(actual, expected)
1346
1347     @event_loop(close=False)
1348     def test_multi_file_force_py36(self) -> None:
1349         reg_mode = black.FileMode()
1350         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1351         source, expected = read_data("force_py36")
1352         with cache_dir() as workspace:
1353             paths = [
1354                 (workspace / "file1.py").resolve(),
1355                 (workspace / "file2.py").resolve(),
1356             ]
1357             for path in paths:
1358                 with open(path, "w") as fh:
1359                     fh.write(source)
1360             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1361             for path in paths:
1362                 with open(path, "r") as fh:
1363                     actual = fh.read()
1364                 self.assertEqual(actual, expected)
1365             # verify cache with --target-version is separate
1366             pyi_cache = black.read_cache(py36_mode)
1367             normal_cache = black.read_cache(reg_mode)
1368             for path in paths:
1369                 self.assertIn(path, pyi_cache)
1370                 self.assertNotIn(path, normal_cache)
1371
1372     def test_pipe_force_py36(self) -> None:
1373         source, expected = read_data("force_py36")
1374         result = CliRunner().invoke(
1375             black.main,
1376             ["-", "-q", "--target-version=py36"],
1377             input=BytesIO(source.encode("utf8")),
1378         )
1379         self.assertEqual(result.exit_code, 0)
1380         actual = result.output
1381         self.assertFormatEqual(actual, expected)
1382
1383     def test_include_exclude(self) -> None:
1384         path = THIS_DIR / "data" / "include_exclude_tests"
1385         include = re.compile(r"\.pyi?$")
1386         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1387         report = black.Report()
1388         sources: List[Path] = []
1389         expected = [
1390             Path(path / "b/dont_exclude/a.py"),
1391             Path(path / "b/dont_exclude/a.pyi"),
1392         ]
1393         this_abs = THIS_DIR.resolve()
1394         sources.extend(
1395             black.gen_python_files_in_dir(path, this_abs, include, exclude, report)
1396         )
1397         self.assertEqual(sorted(expected), sorted(sources))
1398
1399     def test_empty_include(self) -> None:
1400         path = THIS_DIR / "data" / "include_exclude_tests"
1401         report = black.Report()
1402         empty = re.compile(r"")
1403         sources: List[Path] = []
1404         expected = [
1405             Path(path / "b/exclude/a.pie"),
1406             Path(path / "b/exclude/a.py"),
1407             Path(path / "b/exclude/a.pyi"),
1408             Path(path / "b/dont_exclude/a.pie"),
1409             Path(path / "b/dont_exclude/a.py"),
1410             Path(path / "b/dont_exclude/a.pyi"),
1411             Path(path / "b/.definitely_exclude/a.pie"),
1412             Path(path / "b/.definitely_exclude/a.py"),
1413             Path(path / "b/.definitely_exclude/a.pyi"),
1414         ]
1415         this_abs = THIS_DIR.resolve()
1416         sources.extend(
1417             black.gen_python_files_in_dir(
1418                 path, this_abs, empty, re.compile(black.DEFAULT_EXCLUDES), report
1419             )
1420         )
1421         self.assertEqual(sorted(expected), sorted(sources))
1422
1423     def test_empty_exclude(self) -> None:
1424         path = THIS_DIR / "data" / "include_exclude_tests"
1425         report = black.Report()
1426         empty = re.compile(r"")
1427         sources: List[Path] = []
1428         expected = [
1429             Path(path / "b/dont_exclude/a.py"),
1430             Path(path / "b/dont_exclude/a.pyi"),
1431             Path(path / "b/exclude/a.py"),
1432             Path(path / "b/exclude/a.pyi"),
1433             Path(path / "b/.definitely_exclude/a.py"),
1434             Path(path / "b/.definitely_exclude/a.pyi"),
1435         ]
1436         this_abs = THIS_DIR.resolve()
1437         sources.extend(
1438             black.gen_python_files_in_dir(
1439                 path, this_abs, re.compile(black.DEFAULT_INCLUDES), empty, report
1440             )
1441         )
1442         self.assertEqual(sorted(expected), sorted(sources))
1443
1444     def test_invalid_include_exclude(self) -> None:
1445         for option in ["--include", "--exclude"]:
1446             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1447
1448     def test_preserves_line_endings(self) -> None:
1449         with TemporaryDirectory() as workspace:
1450             test_file = Path(workspace) / "test.py"
1451             for nl in ["\n", "\r\n"]:
1452                 contents = nl.join(["def f(  ):", "    pass"])
1453                 test_file.write_bytes(contents.encode())
1454                 ff(test_file, write_back=black.WriteBack.YES)
1455                 updated_contents: bytes = test_file.read_bytes()
1456                 self.assertIn(nl.encode(), updated_contents)
1457                 if nl == "\n":
1458                     self.assertNotIn(b"\r\n", updated_contents)
1459
1460     def test_preserves_line_endings_via_stdin(self) -> None:
1461         for nl in ["\n", "\r\n"]:
1462             contents = nl.join(["def f(  ):", "    pass"])
1463             runner = BlackRunner()
1464             result = runner.invoke(
1465                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1466             )
1467             self.assertEqual(result.exit_code, 0)
1468             output = runner.stdout_bytes
1469             self.assertIn(nl.encode("utf8"), output)
1470             if nl == "\n":
1471                 self.assertNotIn(b"\r\n", output)
1472
1473     def test_assert_equivalent_different_asts(self) -> None:
1474         with self.assertRaises(AssertionError):
1475             black.assert_equivalent("{}", "None")
1476
1477     def test_symlink_out_of_root_directory(self) -> None:
1478         path = MagicMock()
1479         root = THIS_DIR
1480         child = MagicMock()
1481         include = re.compile(black.DEFAULT_INCLUDES)
1482         exclude = re.compile(black.DEFAULT_EXCLUDES)
1483         report = black.Report()
1484         # `child` should behave like a symlink which resolved path is clearly
1485         # outside of the `root` directory.
1486         path.iterdir.return_value = [child]
1487         child.resolve.return_value = Path("/a/b/c")
1488         child.is_symlink.return_value = True
1489         try:
1490             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1491         except ValueError as ve:
1492             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1493         path.iterdir.assert_called_once()
1494         child.resolve.assert_called_once()
1495         child.is_symlink.assert_called_once()
1496         # `child` should behave like a strange file which resolved path is clearly
1497         # outside of the `root` directory.
1498         child.is_symlink.return_value = False
1499         with self.assertRaises(ValueError):
1500             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1501         path.iterdir.assert_called()
1502         self.assertEqual(path.iterdir.call_count, 2)
1503         child.resolve.assert_called()
1504         self.assertEqual(child.resolve.call_count, 2)
1505         child.is_symlink.assert_called()
1506         self.assertEqual(child.is_symlink.call_count, 2)
1507
1508     def test_shhh_click(self) -> None:
1509         try:
1510             from click import _unicodefun  # type: ignore
1511         except ModuleNotFoundError:
1512             self.skipTest("Incompatible Click version")
1513         if not hasattr(_unicodefun, "_verify_python3_env"):
1514             self.skipTest("Incompatible Click version")
1515         # First, let's see if Click is crashing with a preferred ASCII charset.
1516         with patch("locale.getpreferredencoding") as gpe:
1517             gpe.return_value = "ASCII"
1518             with self.assertRaises(RuntimeError):
1519                 _unicodefun._verify_python3_env()
1520         # Now, let's silence Click...
1521         black.patch_click()
1522         # ...and confirm it's silent.
1523         with patch("locale.getpreferredencoding") as gpe:
1524             gpe.return_value = "ASCII"
1525             try:
1526                 _unicodefun._verify_python3_env()
1527             except RuntimeError as re:
1528                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1529
1530     def test_root_logger_not_used_directly(self) -> None:
1531         def fail(*args: Any, **kwargs: Any) -> None:
1532             self.fail("Record created with root logger")
1533
1534         with patch.multiple(
1535             logging.root,
1536             debug=fail,
1537             info=fail,
1538             warning=fail,
1539             error=fail,
1540             critical=fail,
1541             log=fail,
1542         ):
1543             ff(THIS_FILE)
1544
1545     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1546     def test_blackd_main(self) -> None:
1547         with patch("blackd.web.run_app"):
1548             result = CliRunner().invoke(blackd.main, [])
1549             if result.exception is not None:
1550                 raise result.exception
1551             self.assertEqual(result.exit_code, 0)
1552
1553
if has_blackd_deps:
    # AioHTTPTestCase, unittest_run_loop and web are only bound when the
    # optional blackd imports succeed. Defining this class unconditionally
    # raises NameError at import time without them, which per-test skipUnless
    # decorators cannot prevent, so the whole class is guarded instead.

    class BlackDTestCase(AioHTTPTestCase):
        """HTTP-level tests for the application built by ``blackd.make_app()``."""

        async def get_application(self) -> web.Application:
            # AioHTTPTestCase hook: the app under test, reached via self.client.
            return blackd.make_app()

        # TODO: remove these decorators once the below is released
        # https://github.com/aio-libs/aiohttp/pull/3727
        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_request_needs_formatting(self) -> None:
            """Input that needs changes returns 200 with the formatted code."""
            response = await self.client.post("/", data=b"print('hello world')")
            self.assertEqual(response.status, 200)
            self.assertEqual(response.charset, "utf8")
            self.assertEqual(await response.read(), b'print("hello world")\n')

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_request_no_change(self) -> None:
            """Already-formatted input returns 204 with an empty body."""
            response = await self.client.post("/", data=b'print("hello world")\n')
            self.assertEqual(response.status, 204)
            self.assertEqual(await response.read(), b"")

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_request_syntax_error(self) -> None:
            """Unparseable input returns 400 with a 'Cannot parse' message."""
            response = await self.client.post("/", data=b"what even ( is")
            self.assertEqual(response.status, 400)
            content = await response.text()
            self.assertTrue(
                content.startswith("Cannot parse"),
                msg=f"Expected error to start with 'Cannot parse', got {repr(content)}",
            )

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_unsupported_version(self) -> None:
            """Protocol version "2" is answered with 501."""
            response = await self.client.post(
                "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "2"}
            )
            self.assertEqual(response.status, 501)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_supported_version(self) -> None:
            """Protocol version "1" is accepted (200)."""
            response = await self.client.post(
                "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "1"}
            )
            self.assertEqual(response.status, 200)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_invalid_python_variant(self) -> None:
            """Malformed or unsupported Python-variant headers are answered with 400."""

            async def check(header_value: str, expected_status: int = 400) -> None:
                response = await self.client.post(
                    "/",
                    data=b"what",
                    headers={blackd.PYTHON_VARIANT_HEADER: header_value},
                )
                self.assertEqual(response.status, expected_status)

            await check("lol")
            await check("ruby3.5")
            await check("pyi3.6")
            await check("py1.5")
            await check("2.8")
            await check("py2.8")
            await check("3.0")
            await check("pypy3.0")
            await check("jython3.4")

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_pyi(self) -> None:
            """The "pyi" variant formats stub source to match the stub.pyi fixture."""
            source, expected = read_data("stub.pyi")
            response = await self.client.post(
                "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
            )
            self.assertEqual(response.status, 200)
            self.assertEqual(await response.text(), expected)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_python_variant(self) -> None:
            """3.6+ targets get 200 for this code; targets with older versions get 204."""
            code = (
                "def f(\n"
                "    and_has_a_bunch_of,\n"
                "    very_long_arguments_too,\n"
                "    and_lots_of_them_as_well_lol,\n"
                "    **and_very_long_keyword_arguments\n"
                "):\n"
                "    pass\n"
            )

            async def check(header_value: str, expected_status: int) -> None:
                response = await self.client.post(
                    "/",
                    data=code,
                    headers={blackd.PYTHON_VARIANT_HEADER: header_value},
                )
                self.assertEqual(
                    response.status, expected_status, msg=await response.text()
                )

            await check("3.6", 200)
            await check("py3.6", 200)
            await check("3.6,3.7", 200)
            await check("3.6,py3.7", 200)
            await check("py36,py37", 200)
            await check("36", 200)
            await check("3.6.4", 200)

            await check("2", 204)
            await check("2.7", 204)
            await check("py2.7", 204)
            await check("3.4", 204)
            await check("py3.4", 204)
            await check("py34,py36", 204)
            await check("34", 204)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_line_length(self) -> None:
            """A numeric line-length header is accepted (200)."""
            response = await self.client.post(
                "/",
                data=b'print("hello")\n',
                headers={blackd.LINE_LENGTH_HEADER: "7"},
            )
            self.assertEqual(response.status, 200)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_invalid_line_length(self) -> None:
            """A non-numeric line-length header is rejected with 400."""
            response = await self.client.post(
                "/",
                data=b'print("hello")\n',
                headers={blackd.LINE_LENGTH_HEADER: "NaN"},
            )
            self.assertEqual(response.status, 400)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_response_black_version_header(self) -> None:
            """Responses carry the black version header."""
            response = await self.client.post("/")
            self.assertIsNotNone(response.headers.get(blackd.BLACK_VERSION_HEADER))
1700
1701
if __name__ == "__main__":
    # Run this test module when the file is executed directly. The module is
    # named explicitly rather than relying on __main__ — presumably so the
    # tests load under their importable name; TODO confirm the intent.
    unittest.main(module="test_black")