git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Switch from versioneer to setuptools-scm (#1008)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 import logging
4 from concurrent.futures import ThreadPoolExecutor
5 from contextlib import contextmanager
6 from functools import partial
7 from io import BytesIO, TextIOWrapper
8 import os
9 from pathlib import Path
10 import re
11 import sys
12 from tempfile import TemporaryDirectory
13 from typing import Any, BinaryIO, Generator, List, Tuple, Iterator, TypeVar
14 import unittest
15 from unittest.mock import patch, MagicMock
16
17 from click import unstyle
18 from click.testing import CliRunner
19
20 import black
21 from black import Feature, TargetVersion
22
23 try:
24     import blackd
25     from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop
26     from aiohttp import web
27 except ImportError:
28     has_blackd_deps = False
29 else:
30     has_blackd_deps = True
31
# Shorthands bound to the default FileMode: `ff` formats a file in place (with
# fast=True), `fs` formats a source string and returns the result.
ff = partial(black.format_file_in_place, mode=black.FileMode(), fast=True)
fs = partial(black.format_str, mode=black.FileMode())
# This test module and the directory containing it (fixtures live in data/).
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# Marker text that read_data() deletes from every fixture line.
# NOTE(review): presumably built from two literals so that this file's own
# source (which test_self feeds through read_data) never contains the full
# marker -- confirm before merging it into a single literal.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
# One --target-version=<name> flag per version in black.PY36_VERSIONS.
PY36_ARGS = [
    f"--target-version={version.name.lower()}" for version in black.PY36_VERSIONS
]
# Generic type variables (presumably for helper annotations later in the file).
T = TypeVar("T")
R = TypeVar("R")
42
43
def dump_to_stderr(*output: str) -> str:
    """Join *output* with newlines and pad the result with one newline each side.

    Patched in for ``black.dump_to_file`` so that callers receive the dumped
    text itself rather than a path to a file.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
46
47
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """Return the (input, expected output) pair stored in a test fixture.

    *name* gets a ``.py`` suffix unless it already ends in .py, .pyi, .out or
    .diff. The file is read from the ``data`` directory next to this module
    when *data* is true, otherwise from this module's own directory.

    EMPTY_LINE is deleted from every line. Lines after a line reading
    ``# output`` form the expected output. If that part is missing or empty
    while the input is not, the input doubles as the expected output. Both
    results are stripped and terminated by a single newline.

    read_data('test_name') -> 'input', 'output'
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    source_dir = THIS_DIR / "data" if data else THIS_DIR
    with open(source_dir / name, "r", encoding="utf8") as handle:
        raw_lines = handle.readlines()

    before: List[str] = []
    after: List[str] = []
    target = before
    for raw in raw_lines:
        cleaned = raw.replace(EMPTY_LINE, "")
        if cleaned.rstrip() == "# output":
            # Switch collection to the expected-output half; drop the marker.
            target = after
            continue
        target.append(cleaned)

    if before and not after:
        # No output section: the fixture is expected to be left unchanged.
        after = list(before)
    return "".join(before).strip() + "\n", "".join(after).strip() + "\n"
69
70
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point ``black.CACHE_DIR`` at a fresh temporary directory for the block.

    With ``exists=False`` the yielded (and patched-in) path is a ``new``
    subdirectory of the temporary directory that has not been created.
    The temporary directory is removed when the block exits.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
79
80
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Install a brand-new event loop as the current loop for the block.

    The loop comes from the current event loop policy. When *close* is true
    it is closed on exit (even if the block raised); otherwise it is left
    open.
    """
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        if close:
            fresh_loop.close()
92
93
@contextmanager
def skip_if_exception(e: str) -> Iterator[None]:
    """Turn an exception whose class name is *e* into a skipped test.

    If the ``with`` body raises an exception whose class ``__name__`` equals
    *e*, ``unittest.SkipTest`` is raised in its place, so the running test is
    reported as skipped. Any other exception propagates unchanged.
    """
    try:
        yield
    except Exception as exc:
        if exc.__class__.__name__ != e:
            raise
        # unittest.skip() merely builds a decorator; calling it here would
        # swallow the exception and let the test pass silently. Raising
        # SkipTest is what actually marks the current test as skipped.
        raise unittest.SkipTest(
            f"Encountered expected exception {exc}, skipping"
        ) from exc
103
104
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    This is a hack that can be removed once we depend on Click 7.x.

    After an ``invoke()`` the raw bytes written to each stream during the run
    are available as ``stdout_bytes`` and ``stderr_bytes``.
    """

    def __init__(self) -> None:
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        with super().isolation(*args, **kwargs) as output:
            hold_stderr = sys.stderr
            stderr_wrapper = TextIOWrapper(self.stderrbuf, encoding=self.charset)
            sys.stderr = stderr_wrapper
            try:
                yield output
            finally:
                try:
                    # TextIOWrapper buffers text internally; flush both text layers
                    # so everything written has reached the byte buffers before they
                    # are snapshotted.
                    sys.stdout.flush()
                    stderr_wrapper.flush()
                    self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                    self.stderr_bytes = self.stderrbuf.getvalue()
                finally:
                    # Always hand the real stderr back, even if snapshotting failed.
                    sys.stderr = hold_stderr
128
129
130 class BlackTestCase(unittest.TestCase):
131     maxDiff = None
132
133     def assertFormatEqual(self, expected: str, actual: str) -> None:
134         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
135             bdv: black.DebugVisitor[Any]
136             black.out("Expected tree:", fg="green")
137             try:
138                 exp_node = black.lib2to3_parse(expected)
139                 bdv = black.DebugVisitor()
140                 list(bdv.visit(exp_node))
141             except Exception as ve:
142                 black.err(str(ve))
143             black.out("Actual tree:", fg="red")
144             try:
145                 exp_node = black.lib2to3_parse(actual)
146                 bdv = black.DebugVisitor()
147                 list(bdv.visit(exp_node))
148             except Exception as ve:
149                 black.err(str(ve))
150         self.assertEqual(expected, actual)
151
152     def invokeBlack(
153         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
154     ) -> None:
155         runner = BlackRunner()
156         if ignore_config:
157             args = ["--config", str(THIS_DIR / "empty.toml"), *args]
158         result = runner.invoke(black.main, args)
159         self.assertEqual(result.exit_code, exit_code, msg=runner.stderr_bytes.decode())
160
161     @patch("black.dump_to_file", dump_to_stderr)
162     def test_empty(self) -> None:
163         source = expected = ""
164         actual = fs(source)
165         self.assertFormatEqual(expected, actual)
166         black.assert_equivalent(source, actual)
167         black.assert_stable(source, actual, black.FileMode())
168
169     def test_empty_ff(self) -> None:
170         expected = ""
171         tmp_file = Path(black.dump_to_file())
172         try:
173             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
174             with open(tmp_file, encoding="utf8") as f:
175                 actual = f.read()
176         finally:
177             os.unlink(tmp_file)
178         self.assertFormatEqual(expected, actual)
179
180     @patch("black.dump_to_file", dump_to_stderr)
181     def test_self(self) -> None:
182         source, expected = read_data("test_black", data=False)
183         actual = fs(source)
184         self.assertFormatEqual(expected, actual)
185         black.assert_equivalent(source, actual)
186         black.assert_stable(source, actual, black.FileMode())
187         self.assertFalse(ff(THIS_FILE))
188
189     @patch("black.dump_to_file", dump_to_stderr)
190     def test_black(self) -> None:
191         source, expected = read_data("../black", data=False)
192         actual = fs(source)
193         self.assertFormatEqual(expected, actual)
194         black.assert_equivalent(source, actual)
195         black.assert_stable(source, actual, black.FileMode())
196         self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
197
198     def test_piping(self) -> None:
199         source, expected = read_data("../black", data=False)
200         result = BlackRunner().invoke(
201             black.main,
202             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
203             input=BytesIO(source.encode("utf8")),
204         )
205         self.assertEqual(result.exit_code, 0)
206         self.assertFormatEqual(expected, result.output)
207         black.assert_equivalent(source, result.output)
208         black.assert_stable(source, result.output, black.FileMode())
209
210     def test_piping_diff(self) -> None:
211         diff_header = re.compile(
212             rf"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d "
213             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
214         )
215         source, _ = read_data("expression.py")
216         expected, _ = read_data("expression.diff")
217         config = THIS_DIR / "data" / "empty_pyproject.toml"
218         args = [
219             "-",
220             "--fast",
221             f"--line-length={black.DEFAULT_LINE_LENGTH}",
222             "--diff",
223             f"--config={config}",
224         ]
225         result = BlackRunner().invoke(
226             black.main, args, input=BytesIO(source.encode("utf8"))
227         )
228         self.assertEqual(result.exit_code, 0)
229         actual = diff_header.sub("[Deterministic header]", result.output)
230         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
231         self.assertEqual(expected, actual)
232
233     @patch("black.dump_to_file", dump_to_stderr)
234     def test_setup(self) -> None:
235         source, expected = read_data("../setup", data=False)
236         actual = fs(source)
237         self.assertFormatEqual(expected, actual)
238         black.assert_equivalent(source, actual)
239         black.assert_stable(source, actual, black.FileMode())
240         self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
241
242     @patch("black.dump_to_file", dump_to_stderr)
243     def test_function(self) -> None:
244         source, expected = read_data("function")
245         actual = fs(source)
246         self.assertFormatEqual(expected, actual)
247         black.assert_equivalent(source, actual)
248         black.assert_stable(source, actual, black.FileMode())
249
250     @patch("black.dump_to_file", dump_to_stderr)
251     def test_function2(self) -> None:
252         source, expected = read_data("function2")
253         actual = fs(source)
254         self.assertFormatEqual(expected, actual)
255         black.assert_equivalent(source, actual)
256         black.assert_stable(source, actual, black.FileMode())
257
258     @patch("black.dump_to_file", dump_to_stderr)
259     def test_function_trailing_comma(self) -> None:
260         source, expected = read_data("function_trailing_comma")
261         actual = fs(source)
262         self.assertFormatEqual(expected, actual)
263         black.assert_equivalent(source, actual)
264         black.assert_stable(source, actual, black.FileMode())
265
266     @patch("black.dump_to_file", dump_to_stderr)
267     def test_expression(self) -> None:
268         source, expected = read_data("expression")
269         actual = fs(source)
270         self.assertFormatEqual(expected, actual)
271         black.assert_equivalent(source, actual)
272         black.assert_stable(source, actual, black.FileMode())
273
274     @patch("black.dump_to_file", dump_to_stderr)
275     def test_pep_572(self) -> None:
276         source, expected = read_data("pep_572")
277         actual = fs(source)
278         self.assertFormatEqual(expected, actual)
279         black.assert_stable(source, actual, black.FileMode())
280         if sys.version_info >= (3, 8):
281             black.assert_equivalent(source, actual)
282
283     def test_pep_572_version_detection(self) -> None:
284         source, _ = read_data("pep_572")
285         root = black.lib2to3_parse(source)
286         features = black.get_features_used(root)
287         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
288         versions = black.detect_target_versions(root)
289         self.assertIn(black.TargetVersion.PY38, versions)
290
291     def test_expression_ff(self) -> None:
292         source, expected = read_data("expression")
293         tmp_file = Path(black.dump_to_file(source))
294         try:
295             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
296             with open(tmp_file, encoding="utf8") as f:
297                 actual = f.read()
298         finally:
299             os.unlink(tmp_file)
300         self.assertFormatEqual(expected, actual)
301         with patch("black.dump_to_file", dump_to_stderr):
302             black.assert_equivalent(source, actual)
303             black.assert_stable(source, actual, black.FileMode())
304
305     def test_expression_diff(self) -> None:
306         source, _ = read_data("expression.py")
307         expected, _ = read_data("expression.diff")
308         tmp_file = Path(black.dump_to_file(source))
309         diff_header = re.compile(
310             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
311             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
312         )
313         try:
314             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
315             self.assertEqual(result.exit_code, 0)
316         finally:
317             os.unlink(tmp_file)
318         actual = result.output
319         actual = diff_header.sub("[Deterministic header]", actual)
320         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
321         if expected != actual:
322             dump = black.dump_to_file(actual)
323             msg = (
324                 f"Expected diff isn't equal to the actual. If you made changes "
325                 f"to expression.py and this is an anticipated difference, "
326                 f"overwrite tests/data/expression.diff with {dump}"
327             )
328             self.assertEqual(expected, actual, msg)
329
330     @patch("black.dump_to_file", dump_to_stderr)
331     def test_fstring(self) -> None:
332         source, expected = read_data("fstring")
333         actual = fs(source)
334         self.assertFormatEqual(expected, actual)
335         black.assert_equivalent(source, actual)
336         black.assert_stable(source, actual, black.FileMode())
337
338     @patch("black.dump_to_file", dump_to_stderr)
339     def test_pep_570(self) -> None:
340         source, expected = read_data("pep_570")
341         actual = fs(source)
342         self.assertFormatEqual(expected, actual)
343         black.assert_stable(source, actual, black.FileMode())
344         if sys.version_info >= (3, 8):
345             black.assert_equivalent(source, actual)
346
347     def test_detect_pos_only_arguments(self) -> None:
348         source, _ = read_data("pep_570")
349         root = black.lib2to3_parse(source)
350         features = black.get_features_used(root)
351         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
352         versions = black.detect_target_versions(root)
353         self.assertIn(black.TargetVersion.PY38, versions)
354
355     @patch("black.dump_to_file", dump_to_stderr)
356     def test_string_quotes(self) -> None:
357         source, expected = read_data("string_quotes")
358         actual = fs(source)
359         self.assertFormatEqual(expected, actual)
360         black.assert_equivalent(source, actual)
361         black.assert_stable(source, actual, black.FileMode())
362         mode = black.FileMode(string_normalization=False)
363         not_normalized = fs(source, mode=mode)
364         self.assertFormatEqual(source, not_normalized)
365         black.assert_equivalent(source, not_normalized)
366         black.assert_stable(source, not_normalized, mode=mode)
367
368     @patch("black.dump_to_file", dump_to_stderr)
369     def test_slices(self) -> None:
370         source, expected = read_data("slices")
371         actual = fs(source)
372         self.assertFormatEqual(expected, actual)
373         black.assert_equivalent(source, actual)
374         black.assert_stable(source, actual, black.FileMode())
375
376     @patch("black.dump_to_file", dump_to_stderr)
377     def test_comments(self) -> None:
378         source, expected = read_data("comments")
379         actual = fs(source)
380         self.assertFormatEqual(expected, actual)
381         black.assert_equivalent(source, actual)
382         black.assert_stable(source, actual, black.FileMode())
383
384     @patch("black.dump_to_file", dump_to_stderr)
385     def test_comments2(self) -> None:
386         source, expected = read_data("comments2")
387         actual = fs(source)
388         self.assertFormatEqual(expected, actual)
389         black.assert_equivalent(source, actual)
390         black.assert_stable(source, actual, black.FileMode())
391
392     @patch("black.dump_to_file", dump_to_stderr)
393     def test_comments3(self) -> None:
394         source, expected = read_data("comments3")
395         actual = fs(source)
396         self.assertFormatEqual(expected, actual)
397         black.assert_equivalent(source, actual)
398         black.assert_stable(source, actual, black.FileMode())
399
400     @patch("black.dump_to_file", dump_to_stderr)
401     def test_comments4(self) -> None:
402         source, expected = read_data("comments4")
403         actual = fs(source)
404         self.assertFormatEqual(expected, actual)
405         black.assert_equivalent(source, actual)
406         black.assert_stable(source, actual, black.FileMode())
407
408     @patch("black.dump_to_file", dump_to_stderr)
409     def test_comments5(self) -> None:
410         source, expected = read_data("comments5")
411         actual = fs(source)
412         self.assertFormatEqual(expected, actual)
413         black.assert_equivalent(source, actual)
414         black.assert_stable(source, actual, black.FileMode())
415
416     @patch("black.dump_to_file", dump_to_stderr)
417     def test_comments6(self) -> None:
418         source, expected = read_data("comments6")
419         actual = fs(source)
420         self.assertFormatEqual(expected, actual)
421         black.assert_equivalent(source, actual)
422         black.assert_stable(source, actual, black.FileMode())
423
424     @patch("black.dump_to_file", dump_to_stderr)
425     def test_comments7(self) -> None:
426         source, expected = read_data("comments7")
427         actual = fs(source)
428         self.assertFormatEqual(expected, actual)
429         black.assert_equivalent(source, actual)
430         black.assert_stable(source, actual, black.FileMode())
431
432     @patch("black.dump_to_file", dump_to_stderr)
433     def test_comment_after_escaped_newline(self) -> None:
434         source, expected = read_data("comment_after_escaped_newline")
435         actual = fs(source)
436         self.assertFormatEqual(expected, actual)
437         black.assert_equivalent(source, actual)
438         black.assert_stable(source, actual, black.FileMode())
439
440     @patch("black.dump_to_file", dump_to_stderr)
441     def test_cantfit(self) -> None:
442         source, expected = read_data("cantfit")
443         actual = fs(source)
444         self.assertFormatEqual(expected, actual)
445         black.assert_equivalent(source, actual)
446         black.assert_stable(source, actual, black.FileMode())
447
448     @patch("black.dump_to_file", dump_to_stderr)
449     def test_import_spacing(self) -> None:
450         source, expected = read_data("import_spacing")
451         actual = fs(source)
452         self.assertFormatEqual(expected, actual)
453         black.assert_equivalent(source, actual)
454         black.assert_stable(source, actual, black.FileMode())
455
456     @patch("black.dump_to_file", dump_to_stderr)
457     def test_composition(self) -> None:
458         source, expected = read_data("composition")
459         actual = fs(source)
460         self.assertFormatEqual(expected, actual)
461         black.assert_equivalent(source, actual)
462         black.assert_stable(source, actual, black.FileMode())
463
464     @patch("black.dump_to_file", dump_to_stderr)
465     def test_empty_lines(self) -> None:
466         source, expected = read_data("empty_lines")
467         actual = fs(source)
468         self.assertFormatEqual(expected, actual)
469         black.assert_equivalent(source, actual)
470         black.assert_stable(source, actual, black.FileMode())
471
472     @patch("black.dump_to_file", dump_to_stderr)
473     def test_remove_parens(self) -> None:
474         source, expected = read_data("remove_parens")
475         actual = fs(source)
476         self.assertFormatEqual(expected, actual)
477         black.assert_equivalent(source, actual)
478         black.assert_stable(source, actual, black.FileMode())
479
480     @patch("black.dump_to_file", dump_to_stderr)
481     def test_string_prefixes(self) -> None:
482         source, expected = read_data("string_prefixes")
483         actual = fs(source)
484         self.assertFormatEqual(expected, actual)
485         black.assert_equivalent(source, actual)
486         black.assert_stable(source, actual, black.FileMode())
487
488     @patch("black.dump_to_file", dump_to_stderr)
489     def test_numeric_literals(self) -> None:
490         source, expected = read_data("numeric_literals")
491         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
492         actual = fs(source, mode=mode)
493         self.assertFormatEqual(expected, actual)
494         black.assert_equivalent(source, actual)
495         black.assert_stable(source, actual, mode)
496
497     @patch("black.dump_to_file", dump_to_stderr)
498     def test_numeric_literals_ignoring_underscores(self) -> None:
499         source, expected = read_data("numeric_literals_skip_underscores")
500         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
501         actual = fs(source, mode=mode)
502         self.assertFormatEqual(expected, actual)
503         black.assert_equivalent(source, actual)
504         black.assert_stable(source, actual, mode)
505
506     @patch("black.dump_to_file", dump_to_stderr)
507     def test_numeric_literals_py2(self) -> None:
508         source, expected = read_data("numeric_literals_py2")
509         actual = fs(source)
510         self.assertFormatEqual(expected, actual)
511         black.assert_stable(source, actual, black.FileMode())
512
513     @patch("black.dump_to_file", dump_to_stderr)
514     def test_python2(self) -> None:
515         source, expected = read_data("python2")
516         actual = fs(source)
517         self.assertFormatEqual(expected, actual)
518         black.assert_equivalent(source, actual)
519         black.assert_stable(source, actual, black.FileMode())
520
521     @patch("black.dump_to_file", dump_to_stderr)
522     def test_python2_print_function(self) -> None:
523         source, expected = read_data("python2_print_function")
524         mode = black.FileMode(target_versions={TargetVersion.PY27})
525         actual = fs(source, mode=mode)
526         self.assertFormatEqual(expected, actual)
527         black.assert_equivalent(source, actual)
528         black.assert_stable(source, actual, mode)
529
530     @patch("black.dump_to_file", dump_to_stderr)
531     def test_python2_unicode_literals(self) -> None:
532         source, expected = read_data("python2_unicode_literals")
533         actual = fs(source)
534         self.assertFormatEqual(expected, actual)
535         black.assert_equivalent(source, actual)
536         black.assert_stable(source, actual, black.FileMode())
537
538     @patch("black.dump_to_file", dump_to_stderr)
539     def test_stub(self) -> None:
540         mode = black.FileMode(is_pyi=True)
541         source, expected = read_data("stub.pyi")
542         actual = fs(source, mode=mode)
543         self.assertFormatEqual(expected, actual)
544         black.assert_stable(source, actual, mode)
545
546     @patch("black.dump_to_file", dump_to_stderr)
547     def test_async_as_identifier(self) -> None:
548         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
549         source, expected = read_data("async_as_identifier")
550         actual = fs(source)
551         self.assertFormatEqual(expected, actual)
552         major, minor = sys.version_info[:2]
553         if major < 3 or (major <= 3 and minor < 7):
554             black.assert_equivalent(source, actual)
555         black.assert_stable(source, actual, black.FileMode())
556         # ensure black can parse this when the target is 3.6
557         self.invokeBlack([str(source_path), "--target-version", "py36"])
558         # but not on 3.7, because async/await is no longer an identifier
559         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
560
561     @patch("black.dump_to_file", dump_to_stderr)
562     def test_python37(self) -> None:
563         source_path = (THIS_DIR / "data" / "python37.py").resolve()
564         source, expected = read_data("python37")
565         actual = fs(source)
566         self.assertFormatEqual(expected, actual)
567         major, minor = sys.version_info[:2]
568         if major > 3 or (major == 3 and minor >= 7):
569             black.assert_equivalent(source, actual)
570         black.assert_stable(source, actual, black.FileMode())
571         # ensure black can parse this when the target is 3.7
572         self.invokeBlack([str(source_path), "--target-version", "py37"])
573         # but not on 3.6, because we use async as a reserved keyword
574         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
575
576     @patch("black.dump_to_file", dump_to_stderr)
577     def test_fmtonoff(self) -> None:
578         source, expected = read_data("fmtonoff")
579         actual = fs(source)
580         self.assertFormatEqual(expected, actual)
581         black.assert_equivalent(source, actual)
582         black.assert_stable(source, actual, black.FileMode())
583
584     @patch("black.dump_to_file", dump_to_stderr)
585     def test_fmtonoff2(self) -> None:
586         source, expected = read_data("fmtonoff2")
587         actual = fs(source)
588         self.assertFormatEqual(expected, actual)
589         black.assert_equivalent(source, actual)
590         black.assert_stable(source, actual, black.FileMode())
591
592     @patch("black.dump_to_file", dump_to_stderr)
593     def test_remove_empty_parentheses_after_class(self) -> None:
594         source, expected = read_data("class_blank_parentheses")
595         actual = fs(source)
596         self.assertFormatEqual(expected, actual)
597         black.assert_equivalent(source, actual)
598         black.assert_stable(source, actual, black.FileMode())
599
600     @patch("black.dump_to_file", dump_to_stderr)
601     def test_new_line_between_class_and_code(self) -> None:
602         source, expected = read_data("class_methods_new_line")
603         actual = fs(source)
604         self.assertFormatEqual(expected, actual)
605         black.assert_equivalent(source, actual)
606         black.assert_stable(source, actual, black.FileMode())
607
608     @patch("black.dump_to_file", dump_to_stderr)
609     def test_bracket_match(self) -> None:
610         source, expected = read_data("bracketmatch")
611         actual = fs(source)
612         self.assertFormatEqual(expected, actual)
613         black.assert_equivalent(source, actual)
614         black.assert_stable(source, actual, black.FileMode())
615
616     @patch("black.dump_to_file", dump_to_stderr)
617     def test_tuple_assign(self) -> None:
618         source, expected = read_data("tupleassign")
619         actual = fs(source)
620         self.assertFormatEqual(expected, actual)
621         black.assert_equivalent(source, actual)
622         black.assert_stable(source, actual, black.FileMode())
623
624     @patch("black.dump_to_file", dump_to_stderr)
625     def test_beginning_backslash(self) -> None:
626         source, expected = read_data("beginning_backslash")
627         actual = fs(source)
628         self.assertFormatEqual(expected, actual)
629         black.assert_equivalent(source, actual)
630         black.assert_stable(source, actual, black.FileMode())
631
632     def test_tab_comment_indentation(self) -> None:
633         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
634         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
635         self.assertFormatEqual(contents_spc, fs(contents_spc))
636         self.assertFormatEqual(contents_spc, fs(contents_tab))
637
638         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
639         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
640         self.assertFormatEqual(contents_spc, fs(contents_spc))
641         self.assertFormatEqual(contents_spc, fs(contents_tab))
642
643         # mixed tabs and spaces (valid Python 2 code)
644         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
645         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
646         self.assertFormatEqual(contents_spc, fs(contents_spc))
647         self.assertFormatEqual(contents_spc, fs(contents_tab))
648
649         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
650         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
651         self.assertFormatEqual(contents_spc, fs(contents_spc))
652         self.assertFormatEqual(contents_spc, fs(contents_tab))
653
654     def test_report_verbose(self) -> None:
655         report = black.Report(verbose=True)
656         out_lines = []
657         err_lines = []
658
659         def out(msg: str, **kwargs: Any) -> None:
660             out_lines.append(msg)
661
662         def err(msg: str, **kwargs: Any) -> None:
663             err_lines.append(msg)
664
665         with patch("black.out", out), patch("black.err", err):
666             report.done(Path("f1"), black.Changed.NO)
667             self.assertEqual(len(out_lines), 1)
668             self.assertEqual(len(err_lines), 0)
669             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
670             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
671             self.assertEqual(report.return_code, 0)
672             report.done(Path("f2"), black.Changed.YES)
673             self.assertEqual(len(out_lines), 2)
674             self.assertEqual(len(err_lines), 0)
675             self.assertEqual(out_lines[-1], "reformatted f2")
676             self.assertEqual(
677                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
678             )
679             report.done(Path("f3"), black.Changed.CACHED)
680             self.assertEqual(len(out_lines), 3)
681             self.assertEqual(len(err_lines), 0)
682             self.assertEqual(
683                 out_lines[-1], "f3 wasn't modified on disk since last run."
684             )
685             self.assertEqual(
686                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
687             )
688             self.assertEqual(report.return_code, 0)
689             report.check = True
690             self.assertEqual(report.return_code, 1)
691             report.check = False
692             report.failed(Path("e1"), "boom")
693             self.assertEqual(len(out_lines), 3)
694             self.assertEqual(len(err_lines), 1)
695             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
696             self.assertEqual(
697                 unstyle(str(report)),
698                 "1 file reformatted, 2 files left unchanged, "
699                 "1 file failed to reformat.",
700             )
701             self.assertEqual(report.return_code, 123)
702             report.done(Path("f3"), black.Changed.YES)
703             self.assertEqual(len(out_lines), 4)
704             self.assertEqual(len(err_lines), 1)
705             self.assertEqual(out_lines[-1], "reformatted f3")
706             self.assertEqual(
707                 unstyle(str(report)),
708                 "2 files reformatted, 2 files left unchanged, "
709                 "1 file failed to reformat.",
710             )
711             self.assertEqual(report.return_code, 123)
712             report.failed(Path("e2"), "boom")
713             self.assertEqual(len(out_lines), 4)
714             self.assertEqual(len(err_lines), 2)
715             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
716             self.assertEqual(
717                 unstyle(str(report)),
718                 "2 files reformatted, 2 files left unchanged, "
719                 "2 files failed to reformat.",
720             )
721             self.assertEqual(report.return_code, 123)
722             report.path_ignored(Path("wat"), "no match")
723             self.assertEqual(len(out_lines), 5)
724             self.assertEqual(len(err_lines), 2)
725             self.assertEqual(out_lines[-1], "wat ignored: no match")
726             self.assertEqual(
727                 unstyle(str(report)),
728                 "2 files reformatted, 2 files left unchanged, "
729                 "2 files failed to reformat.",
730             )
731             self.assertEqual(report.return_code, 123)
732             report.done(Path("f4"), black.Changed.NO)
733             self.assertEqual(len(out_lines), 6)
734             self.assertEqual(len(err_lines), 2)
735             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
736             self.assertEqual(
737                 unstyle(str(report)),
738                 "2 files reformatted, 3 files left unchanged, "
739                 "2 files failed to reformat.",
740             )
741             self.assertEqual(report.return_code, 123)
742             report.check = True
743             self.assertEqual(
744                 unstyle(str(report)),
745                 "2 files would be reformatted, 3 files would be left unchanged, "
746                 "2 files would fail to reformat.",
747             )
748
749     def test_report_quiet(self) -> None:
750         report = black.Report(quiet=True)
751         out_lines = []
752         err_lines = []
753
754         def out(msg: str, **kwargs: Any) -> None:
755             out_lines.append(msg)
756
757         def err(msg: str, **kwargs: Any) -> None:
758             err_lines.append(msg)
759
760         with patch("black.out", out), patch("black.err", err):
761             report.done(Path("f1"), black.Changed.NO)
762             self.assertEqual(len(out_lines), 0)
763             self.assertEqual(len(err_lines), 0)
764             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
765             self.assertEqual(report.return_code, 0)
766             report.done(Path("f2"), black.Changed.YES)
767             self.assertEqual(len(out_lines), 0)
768             self.assertEqual(len(err_lines), 0)
769             self.assertEqual(
770                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
771             )
772             report.done(Path("f3"), black.Changed.CACHED)
773             self.assertEqual(len(out_lines), 0)
774             self.assertEqual(len(err_lines), 0)
775             self.assertEqual(
776                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
777             )
778             self.assertEqual(report.return_code, 0)
779             report.check = True
780             self.assertEqual(report.return_code, 1)
781             report.check = False
782             report.failed(Path("e1"), "boom")
783             self.assertEqual(len(out_lines), 0)
784             self.assertEqual(len(err_lines), 1)
785             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
786             self.assertEqual(
787                 unstyle(str(report)),
788                 "1 file reformatted, 2 files left unchanged, "
789                 "1 file failed to reformat.",
790             )
791             self.assertEqual(report.return_code, 123)
792             report.done(Path("f3"), black.Changed.YES)
793             self.assertEqual(len(out_lines), 0)
794             self.assertEqual(len(err_lines), 1)
795             self.assertEqual(
796                 unstyle(str(report)),
797                 "2 files reformatted, 2 files left unchanged, "
798                 "1 file failed to reformat.",
799             )
800             self.assertEqual(report.return_code, 123)
801             report.failed(Path("e2"), "boom")
802             self.assertEqual(len(out_lines), 0)
803             self.assertEqual(len(err_lines), 2)
804             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
805             self.assertEqual(
806                 unstyle(str(report)),
807                 "2 files reformatted, 2 files left unchanged, "
808                 "2 files failed to reformat.",
809             )
810             self.assertEqual(report.return_code, 123)
811             report.path_ignored(Path("wat"), "no match")
812             self.assertEqual(len(out_lines), 0)
813             self.assertEqual(len(err_lines), 2)
814             self.assertEqual(
815                 unstyle(str(report)),
816                 "2 files reformatted, 2 files left unchanged, "
817                 "2 files failed to reformat.",
818             )
819             self.assertEqual(report.return_code, 123)
820             report.done(Path("f4"), black.Changed.NO)
821             self.assertEqual(len(out_lines), 0)
822             self.assertEqual(len(err_lines), 2)
823             self.assertEqual(
824                 unstyle(str(report)),
825                 "2 files reformatted, 3 files left unchanged, "
826                 "2 files failed to reformat.",
827             )
828             self.assertEqual(report.return_code, 123)
829             report.check = True
830             self.assertEqual(
831                 unstyle(str(report)),
832                 "2 files would be reformatted, 3 files would be left unchanged, "
833                 "2 files would fail to reformat.",
834             )
835
836     def test_report_normal(self) -> None:
837         report = black.Report()
838         out_lines = []
839         err_lines = []
840
841         def out(msg: str, **kwargs: Any) -> None:
842             out_lines.append(msg)
843
844         def err(msg: str, **kwargs: Any) -> None:
845             err_lines.append(msg)
846
847         with patch("black.out", out), patch("black.err", err):
848             report.done(Path("f1"), black.Changed.NO)
849             self.assertEqual(len(out_lines), 0)
850             self.assertEqual(len(err_lines), 0)
851             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
852             self.assertEqual(report.return_code, 0)
853             report.done(Path("f2"), black.Changed.YES)
854             self.assertEqual(len(out_lines), 1)
855             self.assertEqual(len(err_lines), 0)
856             self.assertEqual(out_lines[-1], "reformatted f2")
857             self.assertEqual(
858                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
859             )
860             report.done(Path("f3"), black.Changed.CACHED)
861             self.assertEqual(len(out_lines), 1)
862             self.assertEqual(len(err_lines), 0)
863             self.assertEqual(out_lines[-1], "reformatted f2")
864             self.assertEqual(
865                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
866             )
867             self.assertEqual(report.return_code, 0)
868             report.check = True
869             self.assertEqual(report.return_code, 1)
870             report.check = False
871             report.failed(Path("e1"), "boom")
872             self.assertEqual(len(out_lines), 1)
873             self.assertEqual(len(err_lines), 1)
874             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
875             self.assertEqual(
876                 unstyle(str(report)),
877                 "1 file reformatted, 2 files left unchanged, "
878                 "1 file failed to reformat.",
879             )
880             self.assertEqual(report.return_code, 123)
881             report.done(Path("f3"), black.Changed.YES)
882             self.assertEqual(len(out_lines), 2)
883             self.assertEqual(len(err_lines), 1)
884             self.assertEqual(out_lines[-1], "reformatted f3")
885             self.assertEqual(
886                 unstyle(str(report)),
887                 "2 files reformatted, 2 files left unchanged, "
888                 "1 file failed to reformat.",
889             )
890             self.assertEqual(report.return_code, 123)
891             report.failed(Path("e2"), "boom")
892             self.assertEqual(len(out_lines), 2)
893             self.assertEqual(len(err_lines), 2)
894             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
895             self.assertEqual(
896                 unstyle(str(report)),
897                 "2 files reformatted, 2 files left unchanged, "
898                 "2 files failed to reformat.",
899             )
900             self.assertEqual(report.return_code, 123)
901             report.path_ignored(Path("wat"), "no match")
902             self.assertEqual(len(out_lines), 2)
903             self.assertEqual(len(err_lines), 2)
904             self.assertEqual(
905                 unstyle(str(report)),
906                 "2 files reformatted, 2 files left unchanged, "
907                 "2 files failed to reformat.",
908             )
909             self.assertEqual(report.return_code, 123)
910             report.done(Path("f4"), black.Changed.NO)
911             self.assertEqual(len(out_lines), 2)
912             self.assertEqual(len(err_lines), 2)
913             self.assertEqual(
914                 unstyle(str(report)),
915                 "2 files reformatted, 3 files left unchanged, "
916                 "2 files failed to reformat.",
917             )
918             self.assertEqual(report.return_code, 123)
919             report.check = True
920             self.assertEqual(
921                 unstyle(str(report)),
922                 "2 files would be reformatted, 3 files would be left unchanged, "
923                 "2 files would fail to reformat.",
924             )
925
926     def test_lib2to3_parse(self) -> None:
927         with self.assertRaises(black.InvalidInput):
928             black.lib2to3_parse("invalid syntax")
929
930         straddling = "x + y"
931         black.lib2to3_parse(straddling)
932         black.lib2to3_parse(straddling, {TargetVersion.PY27})
933         black.lib2to3_parse(straddling, {TargetVersion.PY36})
934         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
935
936         py2_only = "print x"
937         black.lib2to3_parse(py2_only)
938         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
939         with self.assertRaises(black.InvalidInput):
940             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
941         with self.assertRaises(black.InvalidInput):
942             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
943
944         py3_only = "exec(x, end=y)"
945         black.lib2to3_parse(py3_only)
946         with self.assertRaises(black.InvalidInput):
947             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
948         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
949         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
950
951     def test_get_features_used(self) -> None:
952         node = black.lib2to3_parse("def f(*, arg): ...\n")
953         self.assertEqual(black.get_features_used(node), set())
954         node = black.lib2to3_parse("def f(*, arg,): ...\n")
955         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
956         node = black.lib2to3_parse("f(*arg,)\n")
957         self.assertEqual(
958             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
959         )
960         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
961         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
962         node = black.lib2to3_parse("123_456\n")
963         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
964         node = black.lib2to3_parse("123456\n")
965         self.assertEqual(black.get_features_used(node), set())
966         source, expected = read_data("function")
967         node = black.lib2to3_parse(source)
968         expected_features = {
969             Feature.TRAILING_COMMA_IN_CALL,
970             Feature.TRAILING_COMMA_IN_DEF,
971             Feature.F_STRINGS,
972         }
973         self.assertEqual(black.get_features_used(node), expected_features)
974         node = black.lib2to3_parse(expected)
975         self.assertEqual(black.get_features_used(node), expected_features)
976         source, expected = read_data("expression")
977         node = black.lib2to3_parse(source)
978         self.assertEqual(black.get_features_used(node), set())
979         node = black.lib2to3_parse(expected)
980         self.assertEqual(black.get_features_used(node), set())
981
982     def test_get_future_imports(self) -> None:
983         node = black.lib2to3_parse("\n")
984         self.assertEqual(set(), black.get_future_imports(node))
985         node = black.lib2to3_parse("from __future__ import black\n")
986         self.assertEqual({"black"}, black.get_future_imports(node))
987         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
988         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
989         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
990         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
991         node = black.lib2to3_parse(
992             "from __future__ import multiple\nfrom __future__ import imports\n"
993         )
994         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
995         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
996         self.assertEqual({"black"}, black.get_future_imports(node))
997         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
998         self.assertEqual({"black"}, black.get_future_imports(node))
999         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
1000         self.assertEqual(set(), black.get_future_imports(node))
1001         node = black.lib2to3_parse("from some.module import black\n")
1002         self.assertEqual(set(), black.get_future_imports(node))
1003         node = black.lib2to3_parse(
1004             "from __future__ import unicode_literals as _unicode_literals"
1005         )
1006         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
1007         node = black.lib2to3_parse(
1008             "from __future__ import unicode_literals as _lol, print"
1009         )
1010         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
1011
1012     def test_debug_visitor(self) -> None:
1013         source, _ = read_data("debug_visitor.py")
1014         expected, _ = read_data("debug_visitor.out")
1015         out_lines = []
1016         err_lines = []
1017
1018         def out(msg: str, **kwargs: Any) -> None:
1019             out_lines.append(msg)
1020
1021         def err(msg: str, **kwargs: Any) -> None:
1022             err_lines.append(msg)
1023
1024         with patch("black.out", out), patch("black.err", err):
1025             black.DebugVisitor.show(source)
1026         actual = "\n".join(out_lines) + "\n"
1027         log_name = ""
1028         if expected != actual:
1029             log_name = black.dump_to_file(*out_lines)
1030         self.assertEqual(
1031             expected,
1032             actual,
1033             f"AST print out is different. Actual version dumped to {log_name}",
1034         )
1035
1036     def test_format_file_contents(self) -> None:
1037         empty = ""
1038         mode = black.FileMode()
1039         with self.assertRaises(black.NothingChanged):
1040             black.format_file_contents(empty, mode=mode, fast=False)
1041         just_nl = "\n"
1042         with self.assertRaises(black.NothingChanged):
1043             black.format_file_contents(just_nl, mode=mode, fast=False)
1044         same = "j = [1, 2, 3]\n"
1045         with self.assertRaises(black.NothingChanged):
1046             black.format_file_contents(same, mode=mode, fast=False)
1047         different = "j = [1,2,3]"
1048         expected = same
1049         actual = black.format_file_contents(different, mode=mode, fast=False)
1050         self.assertEqual(expected, actual)
1051         invalid = "return if you can"
1052         with self.assertRaises(black.InvalidInput) as e:
1053             black.format_file_contents(invalid, mode=mode, fast=False)
1054         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1055
1056     def test_endmarker(self) -> None:
1057         n = black.lib2to3_parse("\n")
1058         self.assertEqual(n.type, black.syms.file_input)
1059         self.assertEqual(len(n.children), 1)
1060         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1061
1062     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1063     def test_assertFormatEqual(self) -> None:
1064         out_lines = []
1065         err_lines = []
1066
1067         def out(msg: str, **kwargs: Any) -> None:
1068             out_lines.append(msg)
1069
1070         def err(msg: str, **kwargs: Any) -> None:
1071             err_lines.append(msg)
1072
1073         with patch("black.out", out), patch("black.err", err):
1074             with self.assertRaises(AssertionError):
1075                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1076
1077         out_str = "".join(out_lines)
1078         self.assertTrue("Expected tree:" in out_str)
1079         self.assertTrue("Actual tree:" in out_str)
1080         self.assertEqual("".join(err_lines), "")
1081
1082     def test_cache_broken_file(self) -> None:
1083         mode = black.FileMode()
1084         with cache_dir() as workspace:
1085             cache_file = black.get_cache_file(mode)
1086             with cache_file.open("w") as fobj:
1087                 fobj.write("this is not a pickle")
1088             self.assertEqual(black.read_cache(mode), {})
1089             src = (workspace / "test.py").resolve()
1090             with src.open("w") as fobj:
1091                 fobj.write("print('hello')")
1092             self.invokeBlack([str(src)])
1093             cache = black.read_cache(mode)
1094             self.assertIn(src, cache)
1095
1096     def test_cache_single_file_already_cached(self) -> None:
1097         mode = black.FileMode()
1098         with cache_dir() as workspace:
1099             src = (workspace / "test.py").resolve()
1100             with src.open("w") as fobj:
1101                 fobj.write("print('hello')")
1102             black.write_cache({}, [src], mode)
1103             self.invokeBlack([str(src)])
1104             with src.open("r") as fobj:
1105                 self.assertEqual(fobj.read(), "print('hello')")
1106
1107     @event_loop(close=False)
1108     def test_cache_multiple_files(self) -> None:
1109         mode = black.FileMode()
1110         with cache_dir() as workspace, patch(
1111             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1112         ):
1113             one = (workspace / "one.py").resolve()
1114             with one.open("w") as fobj:
1115                 fobj.write("print('hello')")
1116             two = (workspace / "two.py").resolve()
1117             with two.open("w") as fobj:
1118                 fobj.write("print('hello')")
1119             black.write_cache({}, [one], mode)
1120             self.invokeBlack([str(workspace)])
1121             with one.open("r") as fobj:
1122                 self.assertEqual(fobj.read(), "print('hello')")
1123             with two.open("r") as fobj:
1124                 self.assertEqual(fobj.read(), 'print("hello")\n')
1125             cache = black.read_cache(mode)
1126             self.assertIn(one, cache)
1127             self.assertIn(two, cache)
1128
1129     def test_no_cache_when_writeback_diff(self) -> None:
1130         mode = black.FileMode()
1131         with cache_dir() as workspace:
1132             src = (workspace / "test.py").resolve()
1133             with src.open("w") as fobj:
1134                 fobj.write("print('hello')")
1135             self.invokeBlack([str(src), "--diff"])
1136             cache_file = black.get_cache_file(mode)
1137             self.assertFalse(cache_file.exists())
1138
1139     def test_no_cache_when_stdin(self) -> None:
1140         mode = black.FileMode()
1141         with cache_dir():
1142             result = CliRunner().invoke(
1143                 black.main, ["-"], input=BytesIO(b"print('hello')")
1144             )
1145             self.assertEqual(result.exit_code, 0)
1146             cache_file = black.get_cache_file(mode)
1147             self.assertFalse(cache_file.exists())
1148
1149     def test_read_cache_no_cachefile(self) -> None:
1150         mode = black.FileMode()
1151         with cache_dir():
1152             self.assertEqual(black.read_cache(mode), {})
1153
1154     def test_write_cache_read_cache(self) -> None:
1155         mode = black.FileMode()
1156         with cache_dir() as workspace:
1157             src = (workspace / "test.py").resolve()
1158             src.touch()
1159             black.write_cache({}, [src], mode)
1160             cache = black.read_cache(mode)
1161             self.assertIn(src, cache)
1162             self.assertEqual(cache[src], black.get_cache_info(src))
1163
1164     def test_filter_cached(self) -> None:
1165         with TemporaryDirectory() as workspace:
1166             path = Path(workspace)
1167             uncached = (path / "uncached").resolve()
1168             cached = (path / "cached").resolve()
1169             cached_but_changed = (path / "changed").resolve()
1170             uncached.touch()
1171             cached.touch()
1172             cached_but_changed.touch()
1173             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
1174             todo, done = black.filter_cached(
1175                 cache, {uncached, cached, cached_but_changed}
1176             )
1177             self.assertEqual(todo, {uncached, cached_but_changed})
1178             self.assertEqual(done, {cached})
1179
1180     def test_write_cache_creates_directory_if_needed(self) -> None:
1181         mode = black.FileMode()
1182         with cache_dir(exists=False) as workspace:
1183             self.assertFalse(workspace.exists())
1184             black.write_cache({}, [], mode)
1185             self.assertTrue(workspace.exists())
1186
1187     @event_loop(close=False)
1188     def test_failed_formatting_does_not_get_cached(self) -> None:
1189         mode = black.FileMode()
1190         with cache_dir() as workspace, patch(
1191             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1192         ):
1193             failing = (workspace / "failing.py").resolve()
1194             with failing.open("w") as fobj:
1195                 fobj.write("not actually python")
1196             clean = (workspace / "clean.py").resolve()
1197             with clean.open("w") as fobj:
1198                 fobj.write('print("hello")\n')
1199             self.invokeBlack([str(workspace)], exit_code=123)
1200             cache = black.read_cache(mode)
1201             self.assertNotIn(failing, cache)
1202             self.assertIn(clean, cache)
1203
1204     def test_write_cache_write_fail(self) -> None:
1205         mode = black.FileMode()
1206         with cache_dir(), patch.object(Path, "open") as mock:
1207             mock.side_effect = OSError
1208             black.write_cache({}, [], mode)
1209
1210     @event_loop(close=False)
1211     def test_check_diff_use_together(self) -> None:
1212         with cache_dir():
1213             # Files which will be reformatted.
1214             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1215             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1216             # Files which will not be reformatted.
1217             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1218             self.invokeBlack([str(src2), "--diff", "--check"])
1219             # Multi file command.
1220             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1221
1222     def test_no_files(self) -> None:
1223         with cache_dir():
1224             # Without an argument, black exits with error code 0.
1225             self.invokeBlack([])
1226
1227     def test_broken_symlink(self) -> None:
1228         with cache_dir() as workspace:
1229             symlink = workspace / "broken_link.py"
1230             try:
1231                 symlink.symlink_to("nonexistent.py")
1232             except OSError as e:
1233                 self.skipTest(f"Can't create symlinks: {e}")
1234             self.invokeBlack([str(workspace.resolve())])
1235
1236     def test_read_cache_line_lengths(self) -> None:
1237         mode = black.FileMode()
1238         short_mode = black.FileMode(line_length=1)
1239         with cache_dir() as workspace:
1240             path = (workspace / "file.py").resolve()
1241             path.touch()
1242             black.write_cache({}, [path], mode)
1243             one = black.read_cache(mode)
1244             self.assertIn(path, one)
1245             two = black.read_cache(short_mode)
1246             self.assertNotIn(path, two)
1247
1248     def test_single_file_force_pyi(self) -> None:
1249         reg_mode = black.FileMode()
1250         pyi_mode = black.FileMode(is_pyi=True)
1251         contents, expected = read_data("force_pyi")
1252         with cache_dir() as workspace:
1253             path = (workspace / "file.py").resolve()
1254             with open(path, "w") as fh:
1255                 fh.write(contents)
1256             self.invokeBlack([str(path), "--pyi"])
1257             with open(path, "r") as fh:
1258                 actual = fh.read()
1259             # verify cache with --pyi is separate
1260             pyi_cache = black.read_cache(pyi_mode)
1261             self.assertIn(path, pyi_cache)
1262             normal_cache = black.read_cache(reg_mode)
1263             self.assertNotIn(path, normal_cache)
1264         self.assertEqual(actual, expected)
1265
1266     @event_loop(close=False)
1267     def test_multi_file_force_pyi(self) -> None:
1268         reg_mode = black.FileMode()
1269         pyi_mode = black.FileMode(is_pyi=True)
1270         contents, expected = read_data("force_pyi")
1271         with cache_dir() as workspace:
1272             paths = [
1273                 (workspace / "file1.py").resolve(),
1274                 (workspace / "file2.py").resolve(),
1275             ]
1276             for path in paths:
1277                 with open(path, "w") as fh:
1278                     fh.write(contents)
1279             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1280             for path in paths:
1281                 with open(path, "r") as fh:
1282                     actual = fh.read()
1283                 self.assertEqual(actual, expected)
1284             # verify cache with --pyi is separate
1285             pyi_cache = black.read_cache(pyi_mode)
1286             normal_cache = black.read_cache(reg_mode)
1287             for path in paths:
1288                 self.assertIn(path, pyi_cache)
1289                 self.assertNotIn(path, normal_cache)
1290
1291     def test_pipe_force_pyi(self) -> None:
1292         source, expected = read_data("force_pyi")
1293         result = CliRunner().invoke(
1294             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1295         )
1296         self.assertEqual(result.exit_code, 0)
1297         actual = result.output
1298         self.assertFormatEqual(actual, expected)
1299
1300     def test_single_file_force_py36(self) -> None:
1301         reg_mode = black.FileMode()
1302         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1303         source, expected = read_data("force_py36")
1304         with cache_dir() as workspace:
1305             path = (workspace / "file.py").resolve()
1306             with open(path, "w") as fh:
1307                 fh.write(source)
1308             self.invokeBlack([str(path), *PY36_ARGS])
1309             with open(path, "r") as fh:
1310                 actual = fh.read()
1311             # verify cache with --target-version is separate
1312             py36_cache = black.read_cache(py36_mode)
1313             self.assertIn(path, py36_cache)
1314             normal_cache = black.read_cache(reg_mode)
1315             self.assertNotIn(path, normal_cache)
1316         self.assertEqual(actual, expected)
1317
1318     @event_loop(close=False)
1319     def test_multi_file_force_py36(self) -> None:
1320         reg_mode = black.FileMode()
1321         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1322         source, expected = read_data("force_py36")
1323         with cache_dir() as workspace:
1324             paths = [
1325                 (workspace / "file1.py").resolve(),
1326                 (workspace / "file2.py").resolve(),
1327             ]
1328             for path in paths:
1329                 with open(path, "w") as fh:
1330                     fh.write(source)
1331             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1332             for path in paths:
1333                 with open(path, "r") as fh:
1334                     actual = fh.read()
1335                 self.assertEqual(actual, expected)
1336             # verify cache with --target-version is separate
1337             pyi_cache = black.read_cache(py36_mode)
1338             normal_cache = black.read_cache(reg_mode)
1339             for path in paths:
1340                 self.assertIn(path, pyi_cache)
1341                 self.assertNotIn(path, normal_cache)
1342
1343     def test_pipe_force_py36(self) -> None:
1344         source, expected = read_data("force_py36")
1345         result = CliRunner().invoke(
1346             black.main,
1347             ["-", "-q", "--target-version=py36"],
1348             input=BytesIO(source.encode("utf8")),
1349         )
1350         self.assertEqual(result.exit_code, 0)
1351         actual = result.output
1352         self.assertFormatEqual(actual, expected)
1353
1354     def test_include_exclude(self) -> None:
1355         path = THIS_DIR / "data" / "include_exclude_tests"
1356         include = re.compile(r"\.pyi?$")
1357         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1358         report = black.Report()
1359         sources: List[Path] = []
1360         expected = [
1361             Path(path / "b/dont_exclude/a.py"),
1362             Path(path / "b/dont_exclude/a.pyi"),
1363         ]
1364         this_abs = THIS_DIR.resolve()
1365         sources.extend(
1366             black.gen_python_files_in_dir(path, this_abs, include, exclude, report)
1367         )
1368         self.assertEqual(sorted(expected), sorted(sources))
1369
1370     def test_empty_include(self) -> None:
1371         path = THIS_DIR / "data" / "include_exclude_tests"
1372         report = black.Report()
1373         empty = re.compile(r"")
1374         sources: List[Path] = []
1375         expected = [
1376             Path(path / "b/exclude/a.pie"),
1377             Path(path / "b/exclude/a.py"),
1378             Path(path / "b/exclude/a.pyi"),
1379             Path(path / "b/dont_exclude/a.pie"),
1380             Path(path / "b/dont_exclude/a.py"),
1381             Path(path / "b/dont_exclude/a.pyi"),
1382             Path(path / "b/.definitely_exclude/a.pie"),
1383             Path(path / "b/.definitely_exclude/a.py"),
1384             Path(path / "b/.definitely_exclude/a.pyi"),
1385         ]
1386         this_abs = THIS_DIR.resolve()
1387         sources.extend(
1388             black.gen_python_files_in_dir(
1389                 path, this_abs, empty, re.compile(black.DEFAULT_EXCLUDES), report
1390             )
1391         )
1392         self.assertEqual(sorted(expected), sorted(sources))
1393
1394     def test_empty_exclude(self) -> None:
1395         path = THIS_DIR / "data" / "include_exclude_tests"
1396         report = black.Report()
1397         empty = re.compile(r"")
1398         sources: List[Path] = []
1399         expected = [
1400             Path(path / "b/dont_exclude/a.py"),
1401             Path(path / "b/dont_exclude/a.pyi"),
1402             Path(path / "b/exclude/a.py"),
1403             Path(path / "b/exclude/a.pyi"),
1404             Path(path / "b/.definitely_exclude/a.py"),
1405             Path(path / "b/.definitely_exclude/a.pyi"),
1406         ]
1407         this_abs = THIS_DIR.resolve()
1408         sources.extend(
1409             black.gen_python_files_in_dir(
1410                 path, this_abs, re.compile(black.DEFAULT_INCLUDES), empty, report
1411             )
1412         )
1413         self.assertEqual(sorted(expected), sorted(sources))
1414
1415     def test_invalid_include_exclude(self) -> None:
1416         for option in ["--include", "--exclude"]:
1417             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1418
1419     def test_preserves_line_endings(self) -> None:
1420         with TemporaryDirectory() as workspace:
1421             test_file = Path(workspace) / "test.py"
1422             for nl in ["\n", "\r\n"]:
1423                 contents = nl.join(["def f(  ):", "    pass"])
1424                 test_file.write_bytes(contents.encode())
1425                 ff(test_file, write_back=black.WriteBack.YES)
1426                 updated_contents: bytes = test_file.read_bytes()
1427                 self.assertIn(nl.encode(), updated_contents)
1428                 if nl == "\n":
1429                     self.assertNotIn(b"\r\n", updated_contents)
1430
1431     def test_preserves_line_endings_via_stdin(self) -> None:
1432         for nl in ["\n", "\r\n"]:
1433             contents = nl.join(["def f(  ):", "    pass"])
1434             runner = BlackRunner()
1435             result = runner.invoke(
1436                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1437             )
1438             self.assertEqual(result.exit_code, 0)
1439             output = runner.stdout_bytes
1440             self.assertIn(nl.encode("utf8"), output)
1441             if nl == "\n":
1442                 self.assertNotIn(b"\r\n", output)
1443
1444     def test_assert_equivalent_different_asts(self) -> None:
1445         with self.assertRaises(AssertionError):
1446             black.assert_equivalent("{}", "None")
1447
1448     def test_symlink_out_of_root_directory(self) -> None:
1449         path = MagicMock()
1450         root = THIS_DIR
1451         child = MagicMock()
1452         include = re.compile(black.DEFAULT_INCLUDES)
1453         exclude = re.compile(black.DEFAULT_EXCLUDES)
1454         report = black.Report()
1455         # `child` should behave like a symlink which resolved path is clearly
1456         # outside of the `root` directory.
1457         path.iterdir.return_value = [child]
1458         child.resolve.return_value = Path("/a/b/c")
1459         child.is_symlink.return_value = True
1460         try:
1461             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1462         except ValueError as ve:
1463             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1464         path.iterdir.assert_called_once()
1465         child.resolve.assert_called_once()
1466         child.is_symlink.assert_called_once()
1467         # `child` should behave like a strange file which resolved path is clearly
1468         # outside of the `root` directory.
1469         child.is_symlink.return_value = False
1470         with self.assertRaises(ValueError):
1471             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1472         path.iterdir.assert_called()
1473         self.assertEqual(path.iterdir.call_count, 2)
1474         child.resolve.assert_called()
1475         self.assertEqual(child.resolve.call_count, 2)
1476         child.is_symlink.assert_called()
1477         self.assertEqual(child.is_symlink.call_count, 2)
1478
1479     def test_shhh_click(self) -> None:
1480         try:
1481             from click import _unicodefun  # type: ignore
1482         except ModuleNotFoundError:
1483             self.skipTest("Incompatible Click version")
1484         if not hasattr(_unicodefun, "_verify_python3_env"):
1485             self.skipTest("Incompatible Click version")
1486         # First, let's see if Click is crashing with a preferred ASCII charset.
1487         with patch("locale.getpreferredencoding") as gpe:
1488             gpe.return_value = "ASCII"
1489             with self.assertRaises(RuntimeError):
1490                 _unicodefun._verify_python3_env()
1491         # Now, let's silence Click...
1492         black.patch_click()
1493         # ...and confirm it's silent.
1494         with patch("locale.getpreferredencoding") as gpe:
1495             gpe.return_value = "ASCII"
1496             try:
1497                 _unicodefun._verify_python3_env()
1498             except RuntimeError as re:
1499                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1500
1501     def test_root_logger_not_used_directly(self) -> None:
1502         def fail(*args: Any, **kwargs: Any) -> None:
1503             self.fail("Record created with root logger")
1504
1505         with patch.multiple(
1506             logging.root,
1507             debug=fail,
1508             info=fail,
1509             warning=fail,
1510             error=fail,
1511             critical=fail,
1512             log=fail,
1513         ):
1514             ff(THIS_FILE)
1515
1516     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1517     def test_blackd_main(self) -> None:
1518         with patch("blackd.web.run_app"):
1519             result = CliRunner().invoke(blackd.main, [])
1520             if result.exception is not None:
1521                 raise result.exception
1522             self.assertEqual(result.exit_code, 0)
1523
1524
# `AioHTTPTestCase`, `unittest_run_loop` and `web` are bound only when the
# optional blackd imports at the top of this module succeed. Evaluating this
# class statement without them raised NameError at import time, which broke
# the whole module and made per-method `skipUnless` decorators ineffective.
# The test case is therefore only defined when those dependencies are present.
if has_blackd_deps:

    class BlackDTestCase(AioHTTPTestCase):
        """HTTP tests against the application returned by ``blackd.make_app()``."""

        async def get_application(self) -> web.Application:
            """Build the blackd application under test."""
            return blackd.make_app()

        # TODO: remove the skip_if_exception decorators once the below is released
        # https://github.com/aio-libs/aiohttp/pull/3727
        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_request_needs_formatting(self) -> None:
            """Unformatted input is answered with 200 and the formatted source."""
            response = await self.client.post("/", data=b"print('hello world')")
            self.assertEqual(response.status, 200)
            self.assertEqual(response.charset, "utf8")
            self.assertEqual(await response.read(), b'print("hello world")\n')

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_request_no_change(self) -> None:
            """Already-formatted input is answered with 204 and an empty body."""
            response = await self.client.post("/", data=b'print("hello world")\n')
            self.assertEqual(response.status, 204)
            self.assertEqual(await response.read(), b"")

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_request_syntax_error(self) -> None:
            """Unparseable input is answered with 400 and a 'Cannot parse' message."""
            response = await self.client.post("/", data=b"what even ( is")
            self.assertEqual(response.status, 400)
            content = await response.text()
            self.assertTrue(
                content.startswith("Cannot parse"),
                msg=f"Expected error to start with 'Cannot parse', got {content!r}",
            )

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_unsupported_version(self) -> None:
            """Protocol version "2" in the version header is answered with 501."""
            response = await self.client.post(
                "/", data=b"what", headers={blackd.VERSION_HEADER: "2"}
            )
            self.assertEqual(response.status, 501)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_supported_version(self) -> None:
            """Protocol version "1" in the version header is accepted (200)."""
            response = await self.client.post(
                "/", data=b"what", headers={blackd.VERSION_HEADER: "1"}
            )
            self.assertEqual(response.status, 200)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_invalid_python_variant(self) -> None:
            """Malformed or unsupported python-variant header values yield 400."""

            async def check(header_value: str, expected_status: int = 400) -> None:
                response = await self.client.post(
                    "/",
                    data=b"what",
                    headers={blackd.PYTHON_VARIANT_HEADER: header_value},
                )
                self.assertEqual(response.status, expected_status)

            await check("lol")
            await check("ruby3.5")
            await check("pyi3.6")
            await check("py1.5")
            await check("2.8")
            await check("py2.8")
            await check("3.0")
            await check("pypy3.0")
            await check("jython3.4")

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_pyi(self) -> None:
            """The "pyi" variant formats the stub fixture as expected."""
            source, expected = read_data("stub.pyi")
            response = await self.client.post(
                "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
            )
            self.assertEqual(response.status, 200)
            self.assertEqual(await response.text(), expected)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_python_variant(self) -> None:
            """py3.6+ targets reformat the snippet (200); py2/py3.4 leave it (204)."""
            code = (
                "def f(\n"
                "    and_has_a_bunch_of,\n"
                "    very_long_arguments_too,\n"
                "    and_lots_of_them_as_well_lol,\n"
                "    **and_very_long_keyword_arguments\n"
                "):\n"
                "    pass\n"
            )

            async def check(header_value: str, expected_status: int) -> None:
                response = await self.client.post(
                    "/", data=code, headers={blackd.PYTHON_VARIANT_HEADER: header_value}
                )
                self.assertEqual(
                    response.status, expected_status, msg=await response.text()
                )

            await check("3.6", 200)
            await check("py3.6", 200)
            await check("3.6,3.7", 200)
            await check("3.6,py3.7", 200)
            await check("py36,py37", 200)
            await check("36", 200)
            await check("3.6.4", 200)

            await check("2", 204)
            await check("2.7", 204)
            await check("py2.7", 204)
            await check("3.4", 204)
            await check("py3.4", 204)
            await check("py34,py36", 204)
            await check("34", 204)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_line_length(self) -> None:
            """A numeric line-length header is accepted (200)."""
            response = await self.client.post(
                "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "7"}
            )
            self.assertEqual(response.status, 200)

        @skip_if_exception("ClientOSError")
        @unittest_run_loop
        async def test_blackd_invalid_line_length(self) -> None:
            """A non-numeric line-length header is rejected with 400."""
            response = await self.client.post(
                "/",
                data=b'print("hello")\n',
                headers={blackd.LINE_LENGTH_HEADER: "NaN"},
            )
            self.assertEqual(response.status, 400)
1664
1665
if __name__ == "__main__":
    # Load the tests by importing the module named "test_black" instead of
    # collecting them from __main__.
    unittest.main(module="test_black")