git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Add support for walrus operator (#935)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 import logging
4 from concurrent.futures import ThreadPoolExecutor
5 from contextlib import contextmanager
6 from functools import partial, wraps
7 from io import BytesIO, TextIOWrapper
8 import os
9 from pathlib import Path
10 import re
11 import sys
12 from tempfile import TemporaryDirectory
13 from typing import (
14     Any,
15     BinaryIO,
16     Callable,
17     Coroutine,
18     Generator,
19     List,
20     Tuple,
21     Iterator,
22     TypeVar,
23 )
24 import unittest
25 from unittest.mock import patch, MagicMock
26
27 from click import unstyle
28 from click.testing import CliRunner
29
30 import black
31 from black import Feature, TargetVersion
32
33 try:
34     import blackd
35     from aiohttp.test_utils import TestClient, TestServer
36 except ImportError:
37     has_blackd_deps = False
38 else:
39     has_blackd_deps = True
40
ff = partial(black.format_file_in_place, mode=black.FileMode(), fast=True)
fs = partial(black.format_str, mode=black.FileMode())
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# Deliberately built by concatenation: read_data() deletes this marker from every
# line it reads, and this module reads itself (test_self). Presumably the split
# keeps the complete marker text from ever appearing literally in this file.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
PY36_ARGS = [f"--target-version={v.name.lower()}" for v in black.PY36_VERSIONS]
T = TypeVar("T")
R = TypeVar("R")
51
52
def dump_to_stderr(*output: str) -> str:
    """Join ``output`` with newlines, adding a leading and a trailing newline.

    Tests patch this in place of ``black.dump_to_file`` so that debug output is
    returned as a string instead of being written to a temporary file.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
55
56
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """Load a test fixture and split it into ``(input, expected_output)``.

    ``name`` gets a ``.py`` suffix unless it already ends with ``.py``,
    ``.pyi``, ``.out`` or ``.diff``. The file is read from ``THIS_DIR / "data"``
    when ``data`` is true, otherwise from ``THIS_DIR`` itself. Occurrences of
    the ``EMPTY_LINE`` marker are deleted from every line. Lines before the
    first line that reads ``# output`` form the input and later lines form the
    expected output. If the output section is missing or empty, the input
    doubles as the expected output. Both parts are stripped and then
    terminated by exactly one newline.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    base_dir = THIS_DIR / "data" if data else THIS_DIR
    input_lines: List[str] = []
    output_lines: List[str] = []
    target = input_lines
    with open(base_dir / name, "r", encoding="utf8") as fixture:
        for raw_line in fixture:
            line = raw_line.replace(EMPTY_LINE, "")
            if line.rstrip() == "# output":
                target = output_lines
                continue
            target.append(line)
    if input_lines and not output_lines:
        # No output section: the fixture is expected to be already formatted.
        output_lines = input_lines[:]
    expected_input = "".join(input_lines).strip() + "\n"
    expected_output = "".join(output_lines).strip() + "\n"
    return expected_input, expected_output
78
79
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point ``black.CACHE_DIR`` at a throwaway directory while the block runs.

    With ``exists=False`` the patched path is a not-yet-created ``new``
    subdirectory of the temporary workspace. The workspace is deleted, and
    ``black.CACHE_DIR`` is restored, when the context exits.
    """
    with TemporaryDirectory() as workspace:
        workspace_path = Path(workspace)
        target = workspace_path if exists else workspace_path / "new"
        with patch("black.CACHE_DIR", target):
            yield target
88
89
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Run the enclosed code with a brand-new asyncio event loop installed.

    The loop that was current beforehand (if there was one) is reinstalled on
    exit. When ``close`` is true the temporary loop is closed afterwards;
    otherwise it is left open for the caller to deal with.
    """
    from typing import Optional

    policy = asyncio.get_event_loop_policy()
    old_loop: Optional[asyncio.AbstractEventLoop]
    try:
        old_loop = policy.get_event_loop()
    except RuntimeError:
        # get_event_loop() raises when no loop is set and one cannot be created
        # implicitly (e.g. outside the main thread, or on newer Pythons); there
        # is simply no previous loop to restore in that case.
        old_loop = None
    loop = policy.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        yield

    finally:
        policy.set_event_loop(old_loop)
        if close:
            loop.close()
103
104
def async_test(f: Callable[..., Coroutine[Any, None, R]]) -> Callable[..., None]:
    """Turn the coroutine function ``f`` into a plain synchronous callable.

    Every call runs ``f`` to completion on a fresh event loop provided by
    ``event_loop(close=True)``; that loop is closed once the call returns.
    """

    def run_sync(*args: Any, **kwargs: Any) -> None:
        current = asyncio.get_event_loop()
        current.run_until_complete(f(*args, **kwargs))

    return event_loop(close=True)(wraps(f)(run_sync))
112
113
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    This is a hack that can be removed once we depend on Click 7.x"""

    def __init__(self) -> None:
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Run click's isolation while capturing stderr in its own buffer.

        Once the context exits, ``stdout_bytes`` and ``stderr_bytes`` hold what
        was written to stdout and stderr during the invocation, and
        ``sys.stderr`` is restored even if reading the captured output fails.
        """
        with super().isolation(*args, **kwargs) as output:
            hold_stderr = sys.stderr
            stderr_wrapper = TextIOWrapper(self.stderrbuf, encoding=self.charset)
            sys.stderr = stderr_wrapper
            try:
                yield output
            finally:
                try:
                    # TextIOWrapper keeps pending text in its own buffer; flush it
                    # so output from writers that never flush is not lost.
                    stderr_wrapper.flush()
                    self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                    self.stderr_bytes = self.stderrbuf.getvalue()
                finally:
                    sys.stderr = hold_stderr
137
138
139 class BlackTestCase(unittest.TestCase):
140     maxDiff = None
141
142     def assertFormatEqual(self, expected: str, actual: str) -> None:
143         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
144             bdv: black.DebugVisitor[Any]
145             black.out("Expected tree:", fg="green")
146             try:
147                 exp_node = black.lib2to3_parse(expected)
148                 bdv = black.DebugVisitor()
149                 list(bdv.visit(exp_node))
150             except Exception as ve:
151                 black.err(str(ve))
152             black.out("Actual tree:", fg="red")
153             try:
154                 exp_node = black.lib2to3_parse(actual)
155                 bdv = black.DebugVisitor()
156                 list(bdv.visit(exp_node))
157             except Exception as ve:
158                 black.err(str(ve))
159         self.assertEqual(expected, actual)
160
161     def invokeBlack(
162         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
163     ) -> None:
164         runner = BlackRunner()
165         if ignore_config:
166             args = ["--config", str(THIS_DIR / "empty.toml"), *args]
167         result = runner.invoke(black.main, args)
168         self.assertEqual(result.exit_code, exit_code, msg=runner.stderr_bytes.decode())
169
170     @patch("black.dump_to_file", dump_to_stderr)
171     def test_empty(self) -> None:
172         source = expected = ""
173         actual = fs(source)
174         self.assertFormatEqual(expected, actual)
175         black.assert_equivalent(source, actual)
176         black.assert_stable(source, actual, black.FileMode())
177
178     def test_empty_ff(self) -> None:
179         expected = ""
180         tmp_file = Path(black.dump_to_file())
181         try:
182             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
183             with open(tmp_file, encoding="utf8") as f:
184                 actual = f.read()
185         finally:
186             os.unlink(tmp_file)
187         self.assertFormatEqual(expected, actual)
188
189     @patch("black.dump_to_file", dump_to_stderr)
190     def test_self(self) -> None:
191         source, expected = read_data("test_black", data=False)
192         actual = fs(source)
193         self.assertFormatEqual(expected, actual)
194         black.assert_equivalent(source, actual)
195         black.assert_stable(source, actual, black.FileMode())
196         self.assertFalse(ff(THIS_FILE))
197
198     @patch("black.dump_to_file", dump_to_stderr)
199     def test_black(self) -> None:
200         source, expected = read_data("../black", data=False)
201         actual = fs(source)
202         self.assertFormatEqual(expected, actual)
203         black.assert_equivalent(source, actual)
204         black.assert_stable(source, actual, black.FileMode())
205         self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
206
207     def test_piping(self) -> None:
208         source, expected = read_data("../black", data=False)
209         result = BlackRunner().invoke(
210             black.main,
211             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
212             input=BytesIO(source.encode("utf8")),
213         )
214         self.assertEqual(result.exit_code, 0)
215         self.assertFormatEqual(expected, result.output)
216         black.assert_equivalent(source, result.output)
217         black.assert_stable(source, result.output, black.FileMode())
218
219     def test_piping_diff(self) -> None:
220         diff_header = re.compile(
221             rf"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d "
222             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
223         )
224         source, _ = read_data("expression.py")
225         expected, _ = read_data("expression.diff")
226         config = THIS_DIR / "data" / "empty_pyproject.toml"
227         args = [
228             "-",
229             "--fast",
230             f"--line-length={black.DEFAULT_LINE_LENGTH}",
231             "--diff",
232             f"--config={config}",
233         ]
234         result = BlackRunner().invoke(
235             black.main, args, input=BytesIO(source.encode("utf8"))
236         )
237         self.assertEqual(result.exit_code, 0)
238         actual = diff_header.sub("[Deterministic header]", result.output)
239         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
240         self.assertEqual(expected, actual)
241
242     @patch("black.dump_to_file", dump_to_stderr)
243     def test_setup(self) -> None:
244         source, expected = read_data("../setup", data=False)
245         actual = fs(source)
246         self.assertFormatEqual(expected, actual)
247         black.assert_equivalent(source, actual)
248         black.assert_stable(source, actual, black.FileMode())
249         self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
250
251     @patch("black.dump_to_file", dump_to_stderr)
252     def test_function(self) -> None:
253         source, expected = read_data("function")
254         actual = fs(source)
255         self.assertFormatEqual(expected, actual)
256         black.assert_equivalent(source, actual)
257         black.assert_stable(source, actual, black.FileMode())
258
259     @patch("black.dump_to_file", dump_to_stderr)
260     def test_function2(self) -> None:
261         source, expected = read_data("function2")
262         actual = fs(source)
263         self.assertFormatEqual(expected, actual)
264         black.assert_equivalent(source, actual)
265         black.assert_stable(source, actual, black.FileMode())
266
267     @patch("black.dump_to_file", dump_to_stderr)
268     def test_function_trailing_comma(self) -> None:
269         source, expected = read_data("function_trailing_comma")
270         actual = fs(source)
271         self.assertFormatEqual(expected, actual)
272         black.assert_equivalent(source, actual)
273         black.assert_stable(source, actual, black.FileMode())
274
275     @patch("black.dump_to_file", dump_to_stderr)
276     def test_expression(self) -> None:
277         source, expected = read_data("expression")
278         actual = fs(source)
279         self.assertFormatEqual(expected, actual)
280         black.assert_equivalent(source, actual)
281         black.assert_stable(source, actual, black.FileMode())
282
283     @patch("black.dump_to_file", dump_to_stderr)
284     def test_pep_572(self) -> None:
285         source, expected = read_data("pep_572")
286         actual = fs(source)
287         self.assertFormatEqual(expected, actual)
288         black.assert_stable(source, actual, black.FileMode())
289         if sys.version_info >= (3, 8):
290             black.assert_equivalent(source, actual)
291
292     def test_pep_572_version_detection(self) -> None:
293         source, _ = read_data("pep_572")
294         root = black.lib2to3_parse(source)
295         features = black.get_features_used(root)
296         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
297         versions = black.detect_target_versions(root)
298         self.assertIn(black.TargetVersion.PY38, versions)
299
300     def test_expression_ff(self) -> None:
301         source, expected = read_data("expression")
302         tmp_file = Path(black.dump_to_file(source))
303         try:
304             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
305             with open(tmp_file, encoding="utf8") as f:
306                 actual = f.read()
307         finally:
308             os.unlink(tmp_file)
309         self.assertFormatEqual(expected, actual)
310         with patch("black.dump_to_file", dump_to_stderr):
311             black.assert_equivalent(source, actual)
312             black.assert_stable(source, actual, black.FileMode())
313
314     def test_expression_diff(self) -> None:
315         source, _ = read_data("expression.py")
316         expected, _ = read_data("expression.diff")
317         tmp_file = Path(black.dump_to_file(source))
318         diff_header = re.compile(
319             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
320             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
321         )
322         try:
323             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
324             self.assertEqual(result.exit_code, 0)
325         finally:
326             os.unlink(tmp_file)
327         actual = result.output
328         actual = diff_header.sub("[Deterministic header]", actual)
329         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
330         if expected != actual:
331             dump = black.dump_to_file(actual)
332             msg = (
333                 f"Expected diff isn't equal to the actual. If you made changes "
334                 f"to expression.py and this is an anticipated difference, "
335                 f"overwrite tests/data/expression.diff with {dump}"
336             )
337             self.assertEqual(expected, actual, msg)
338
339     @patch("black.dump_to_file", dump_to_stderr)
340     def test_fstring(self) -> None:
341         source, expected = read_data("fstring")
342         actual = fs(source)
343         self.assertFormatEqual(expected, actual)
344         black.assert_equivalent(source, actual)
345         black.assert_stable(source, actual, black.FileMode())
346
347     @patch("black.dump_to_file", dump_to_stderr)
348     def test_string_quotes(self) -> None:
349         source, expected = read_data("string_quotes")
350         actual = fs(source)
351         self.assertFormatEqual(expected, actual)
352         black.assert_equivalent(source, actual)
353         black.assert_stable(source, actual, black.FileMode())
354         mode = black.FileMode(string_normalization=False)
355         not_normalized = fs(source, mode=mode)
356         self.assertFormatEqual(source, not_normalized)
357         black.assert_equivalent(source, not_normalized)
358         black.assert_stable(source, not_normalized, mode=mode)
359
360     @patch("black.dump_to_file", dump_to_stderr)
361     def test_slices(self) -> None:
362         source, expected = read_data("slices")
363         actual = fs(source)
364         self.assertFormatEqual(expected, actual)
365         black.assert_equivalent(source, actual)
366         black.assert_stable(source, actual, black.FileMode())
367
368     @patch("black.dump_to_file", dump_to_stderr)
369     def test_comments(self) -> None:
370         source, expected = read_data("comments")
371         actual = fs(source)
372         self.assertFormatEqual(expected, actual)
373         black.assert_equivalent(source, actual)
374         black.assert_stable(source, actual, black.FileMode())
375
376     @patch("black.dump_to_file", dump_to_stderr)
377     def test_comments2(self) -> None:
378         source, expected = read_data("comments2")
379         actual = fs(source)
380         self.assertFormatEqual(expected, actual)
381         black.assert_equivalent(source, actual)
382         black.assert_stable(source, actual, black.FileMode())
383
384     @patch("black.dump_to_file", dump_to_stderr)
385     def test_comments3(self) -> None:
386         source, expected = read_data("comments3")
387         actual = fs(source)
388         self.assertFormatEqual(expected, actual)
389         black.assert_equivalent(source, actual)
390         black.assert_stable(source, actual, black.FileMode())
391
392     @patch("black.dump_to_file", dump_to_stderr)
393     def test_comments4(self) -> None:
394         source, expected = read_data("comments4")
395         actual = fs(source)
396         self.assertFormatEqual(expected, actual)
397         black.assert_equivalent(source, actual)
398         black.assert_stable(source, actual, black.FileMode())
399
400     @patch("black.dump_to_file", dump_to_stderr)
401     def test_comments5(self) -> None:
402         source, expected = read_data("comments5")
403         actual = fs(source)
404         self.assertFormatEqual(expected, actual)
405         black.assert_equivalent(source, actual)
406         black.assert_stable(source, actual, black.FileMode())
407
408     @patch("black.dump_to_file", dump_to_stderr)
409     def test_comments6(self) -> None:
410         source, expected = read_data("comments6")
411         actual = fs(source)
412         self.assertFormatEqual(expected, actual)
413         black.assert_equivalent(source, actual)
414         black.assert_stable(source, actual, black.FileMode())
415
416     @patch("black.dump_to_file", dump_to_stderr)
417     def test_comments7(self) -> None:
418         source, expected = read_data("comments7")
419         actual = fs(source)
420         self.assertFormatEqual(expected, actual)
421         black.assert_equivalent(source, actual)
422         black.assert_stable(source, actual, black.FileMode())
423
424     @patch("black.dump_to_file", dump_to_stderr)
425     def test_comment_after_escaped_newline(self) -> None:
426         source, expected = read_data("comment_after_escaped_newline")
427         actual = fs(source)
428         self.assertFormatEqual(expected, actual)
429         black.assert_equivalent(source, actual)
430         black.assert_stable(source, actual, black.FileMode())
431
432     @patch("black.dump_to_file", dump_to_stderr)
433     def test_cantfit(self) -> None:
434         source, expected = read_data("cantfit")
435         actual = fs(source)
436         self.assertFormatEqual(expected, actual)
437         black.assert_equivalent(source, actual)
438         black.assert_stable(source, actual, black.FileMode())
439
440     @patch("black.dump_to_file", dump_to_stderr)
441     def test_import_spacing(self) -> None:
442         source, expected = read_data("import_spacing")
443         actual = fs(source)
444         self.assertFormatEqual(expected, actual)
445         black.assert_equivalent(source, actual)
446         black.assert_stable(source, actual, black.FileMode())
447
448     @patch("black.dump_to_file", dump_to_stderr)
449     def test_composition(self) -> None:
450         source, expected = read_data("composition")
451         actual = fs(source)
452         self.assertFormatEqual(expected, actual)
453         black.assert_equivalent(source, actual)
454         black.assert_stable(source, actual, black.FileMode())
455
456     @patch("black.dump_to_file", dump_to_stderr)
457     def test_empty_lines(self) -> None:
458         source, expected = read_data("empty_lines")
459         actual = fs(source)
460         self.assertFormatEqual(expected, actual)
461         black.assert_equivalent(source, actual)
462         black.assert_stable(source, actual, black.FileMode())
463
464     @patch("black.dump_to_file", dump_to_stderr)
465     def test_remove_parens(self) -> None:
466         source, expected = read_data("remove_parens")
467         actual = fs(source)
468         self.assertFormatEqual(expected, actual)
469         black.assert_equivalent(source, actual)
470         black.assert_stable(source, actual, black.FileMode())
471
472     @patch("black.dump_to_file", dump_to_stderr)
473     def test_string_prefixes(self) -> None:
474         source, expected = read_data("string_prefixes")
475         actual = fs(source)
476         self.assertFormatEqual(expected, actual)
477         black.assert_equivalent(source, actual)
478         black.assert_stable(source, actual, black.FileMode())
479
480     @patch("black.dump_to_file", dump_to_stderr)
481     def test_numeric_literals(self) -> None:
482         source, expected = read_data("numeric_literals")
483         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
484         actual = fs(source, mode=mode)
485         self.assertFormatEqual(expected, actual)
486         black.assert_equivalent(source, actual)
487         black.assert_stable(source, actual, mode)
488
489     @patch("black.dump_to_file", dump_to_stderr)
490     def test_numeric_literals_ignoring_underscores(self) -> None:
491         source, expected = read_data("numeric_literals_skip_underscores")
492         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
493         actual = fs(source, mode=mode)
494         self.assertFormatEqual(expected, actual)
495         black.assert_equivalent(source, actual)
496         black.assert_stable(source, actual, mode)
497
498     @patch("black.dump_to_file", dump_to_stderr)
499     def test_numeric_literals_py2(self) -> None:
500         source, expected = read_data("numeric_literals_py2")
501         actual = fs(source)
502         self.assertFormatEqual(expected, actual)
503         black.assert_stable(source, actual, black.FileMode())
504
505     @patch("black.dump_to_file", dump_to_stderr)
506     def test_python2(self) -> None:
507         source, expected = read_data("python2")
508         actual = fs(source)
509         self.assertFormatEqual(expected, actual)
510         black.assert_equivalent(source, actual)
511         black.assert_stable(source, actual, black.FileMode())
512
513     @patch("black.dump_to_file", dump_to_stderr)
514     def test_python2_print_function(self) -> None:
515         source, expected = read_data("python2_print_function")
516         mode = black.FileMode(target_versions={TargetVersion.PY27})
517         actual = fs(source, mode=mode)
518         self.assertFormatEqual(expected, actual)
519         black.assert_equivalent(source, actual)
520         black.assert_stable(source, actual, mode)
521
522     @patch("black.dump_to_file", dump_to_stderr)
523     def test_python2_unicode_literals(self) -> None:
524         source, expected = read_data("python2_unicode_literals")
525         actual = fs(source)
526         self.assertFormatEqual(expected, actual)
527         black.assert_equivalent(source, actual)
528         black.assert_stable(source, actual, black.FileMode())
529
530     @patch("black.dump_to_file", dump_to_stderr)
531     def test_stub(self) -> None:
532         mode = black.FileMode(is_pyi=True)
533         source, expected = read_data("stub.pyi")
534         actual = fs(source, mode=mode)
535         self.assertFormatEqual(expected, actual)
536         black.assert_stable(source, actual, mode)
537
538     @patch("black.dump_to_file", dump_to_stderr)
539     def test_async_as_identifier(self) -> None:
540         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
541         source, expected = read_data("async_as_identifier")
542         actual = fs(source)
543         self.assertFormatEqual(expected, actual)
544         major, minor = sys.version_info[:2]
545         if major < 3 or (major <= 3 and minor < 7):
546             black.assert_equivalent(source, actual)
547         black.assert_stable(source, actual, black.FileMode())
548         # ensure black can parse this when the target is 3.6
549         self.invokeBlack([str(source_path), "--target-version", "py36"])
550         # but not on 3.7, because async/await is no longer an identifier
551         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
552
553     @patch("black.dump_to_file", dump_to_stderr)
554     def test_python37(self) -> None:
555         source_path = (THIS_DIR / "data" / "python37.py").resolve()
556         source, expected = read_data("python37")
557         actual = fs(source)
558         self.assertFormatEqual(expected, actual)
559         major, minor = sys.version_info[:2]
560         if major > 3 or (major == 3 and minor >= 7):
561             black.assert_equivalent(source, actual)
562         black.assert_stable(source, actual, black.FileMode())
563         # ensure black can parse this when the target is 3.7
564         self.invokeBlack([str(source_path), "--target-version", "py37"])
565         # but not on 3.6, because we use async as a reserved keyword
566         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
567
568     @patch("black.dump_to_file", dump_to_stderr)
569     def test_fmtonoff(self) -> None:
570         source, expected = read_data("fmtonoff")
571         actual = fs(source)
572         self.assertFormatEqual(expected, actual)
573         black.assert_equivalent(source, actual)
574         black.assert_stable(source, actual, black.FileMode())
575
576     @patch("black.dump_to_file", dump_to_stderr)
577     def test_fmtonoff2(self) -> None:
578         source, expected = read_data("fmtonoff2")
579         actual = fs(source)
580         self.assertFormatEqual(expected, actual)
581         black.assert_equivalent(source, actual)
582         black.assert_stable(source, actual, black.FileMode())
583
584     @patch("black.dump_to_file", dump_to_stderr)
585     def test_remove_empty_parentheses_after_class(self) -> None:
586         source, expected = read_data("class_blank_parentheses")
587         actual = fs(source)
588         self.assertFormatEqual(expected, actual)
589         black.assert_equivalent(source, actual)
590         black.assert_stable(source, actual, black.FileMode())
591
592     @patch("black.dump_to_file", dump_to_stderr)
593     def test_new_line_between_class_and_code(self) -> None:
594         source, expected = read_data("class_methods_new_line")
595         actual = fs(source)
596         self.assertFormatEqual(expected, actual)
597         black.assert_equivalent(source, actual)
598         black.assert_stable(source, actual, black.FileMode())
599
600     @patch("black.dump_to_file", dump_to_stderr)
601     def test_bracket_match(self) -> None:
602         source, expected = read_data("bracketmatch")
603         actual = fs(source)
604         self.assertFormatEqual(expected, actual)
605         black.assert_equivalent(source, actual)
606         black.assert_stable(source, actual, black.FileMode())
607
608     @patch("black.dump_to_file", dump_to_stderr)
609     def test_tuple_assign(self) -> None:
610         source, expected = read_data("tupleassign")
611         actual = fs(source)
612         self.assertFormatEqual(expected, actual)
613         black.assert_equivalent(source, actual)
614         black.assert_stable(source, actual, black.FileMode())
615
616     def test_tab_comment_indentation(self) -> None:
617         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
618         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
619         self.assertFormatEqual(contents_spc, fs(contents_spc))
620         self.assertFormatEqual(contents_spc, fs(contents_tab))
621
622         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
623         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
624         self.assertFormatEqual(contents_spc, fs(contents_spc))
625         self.assertFormatEqual(contents_spc, fs(contents_tab))
626
627         # mixed tabs and spaces (valid Python 2 code)
628         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
629         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
630         self.assertFormatEqual(contents_spc, fs(contents_spc))
631         self.assertFormatEqual(contents_spc, fs(contents_tab))
632
633         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
634         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
635         self.assertFormatEqual(contents_spc, fs(contents_spc))
636         self.assertFormatEqual(contents_spc, fs(contents_tab))
637
    def test_report_verbose(self) -> None:
        """Exercise ``black.Report`` in verbose mode step by step.

        ``black.out``/``black.err`` are replaced with collectors so that each
        ``done``/``failed``/``path_ignored`` call can be checked for the single
        line it emits, alongside the running summary (``str(report)``) and the
        ``return_code`` after every step.
        """
        report = black.Report(verbose=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        # Collectors standing in for black.out / black.err; styling keyword
        # arguments are accepted and ignored.
        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file is reported separately but counted as unchanged.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, having reformatted files yields return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures go to err and force return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Reporting f3 again adds to the totals: counts are per call, not
            # per distinct path.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are logged but leave the summary untouched.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Check mode switches the summary to conditional wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )
732
    def test_report_quiet(self) -> None:
        """A quiet Report suppresses non-error output but keeps its tallies.

        One ``black.Report(quiet=True)`` is driven through a fixed sequence of
        ``done``/``failed``/``path_ignored`` calls. After each call the test
        checks how many messages reached the patched ``black.out`` and
        ``black.err``, the unstyled summary string, and ``return_code``.
        The checks are cumulative: each step relies on the Report state left
        by the previous ones, so the order of the calls matters.
        """
        report = black.Report(quiet=True)
        # Messages captured from the patched black.out / black.err below.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # Even a reformatted file produces no output in quiet mode.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file is counted as left unchanged.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled, having reformatted a file makes the exit code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures still reach err, even in quiet mode, and force exit code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Reporting f3 again, now as changed, bumps the reformatted count
            # without reducing the "left unchanged" count.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths add no output and do not appear in the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # In check mode the summary switches to the "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )
819
    def test_report_normal(self) -> None:
        """A default Report echoes only reformatted files and errors.

        One ``black.Report()`` is driven through a fixed sequence of
        ``done``/``failed``/``path_ignored`` calls. After each call the test
        checks the messages captured from the patched ``black.out`` and
        ``black.err``, the unstyled summary string, and ``return_code``.
        The checks are cumulative, so the order of the calls matters.
        """
        report = black.Report()
        # Messages captured from the patched black.out / black.err below.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            # An unchanged file is counted but not echoed.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # A reformatted file is echoed as "reformatted <path>".
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cached file adds no output and is counted as left unchanged.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled, having reformatted a file makes the exit code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures go to err and force exit code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Reporting f3 again, now as changed, bumps the reformatted count
            # without reducing the "left unchanged" count.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths add no output and do not appear in the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # In check mode the summary switches to the "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )
909
910     def test_lib2to3_parse(self) -> None:
911         with self.assertRaises(black.InvalidInput):
912             black.lib2to3_parse("invalid syntax")
913
914         straddling = "x + y"
915         black.lib2to3_parse(straddling)
916         black.lib2to3_parse(straddling, {TargetVersion.PY27})
917         black.lib2to3_parse(straddling, {TargetVersion.PY36})
918         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
919
920         py2_only = "print x"
921         black.lib2to3_parse(py2_only)
922         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
923         with self.assertRaises(black.InvalidInput):
924             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
925         with self.assertRaises(black.InvalidInput):
926             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
927
928         py3_only = "exec(x, end=y)"
929         black.lib2to3_parse(py3_only)
930         with self.assertRaises(black.InvalidInput):
931             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
932         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
933         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
934
935     def test_get_features_used(self) -> None:
936         node = black.lib2to3_parse("def f(*, arg): ...\n")
937         self.assertEqual(black.get_features_used(node), set())
938         node = black.lib2to3_parse("def f(*, arg,): ...\n")
939         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
940         node = black.lib2to3_parse("f(*arg,)\n")
941         self.assertEqual(
942             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
943         )
944         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
945         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
946         node = black.lib2to3_parse("123_456\n")
947         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
948         node = black.lib2to3_parse("123456\n")
949         self.assertEqual(black.get_features_used(node), set())
950         source, expected = read_data("function")
951         node = black.lib2to3_parse(source)
952         expected_features = {
953             Feature.TRAILING_COMMA_IN_CALL,
954             Feature.TRAILING_COMMA_IN_DEF,
955             Feature.F_STRINGS,
956         }
957         self.assertEqual(black.get_features_used(node), expected_features)
958         node = black.lib2to3_parse(expected)
959         self.assertEqual(black.get_features_used(node), expected_features)
960         source, expected = read_data("expression")
961         node = black.lib2to3_parse(source)
962         self.assertEqual(black.get_features_used(node), set())
963         node = black.lib2to3_parse(expected)
964         self.assertEqual(black.get_features_used(node), set())
965
966     def test_get_future_imports(self) -> None:
967         node = black.lib2to3_parse("\n")
968         self.assertEqual(set(), black.get_future_imports(node))
969         node = black.lib2to3_parse("from __future__ import black\n")
970         self.assertEqual({"black"}, black.get_future_imports(node))
971         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
972         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
973         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
974         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
975         node = black.lib2to3_parse(
976             "from __future__ import multiple\nfrom __future__ import imports\n"
977         )
978         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
979         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
980         self.assertEqual({"black"}, black.get_future_imports(node))
981         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
982         self.assertEqual({"black"}, black.get_future_imports(node))
983         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
984         self.assertEqual(set(), black.get_future_imports(node))
985         node = black.lib2to3_parse("from some.module import black\n")
986         self.assertEqual(set(), black.get_future_imports(node))
987         node = black.lib2to3_parse(
988             "from __future__ import unicode_literals as _unicode_literals"
989         )
990         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
991         node = black.lib2to3_parse(
992             "from __future__ import unicode_literals as _lol, print"
993         )
994         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
995
996     def test_debug_visitor(self) -> None:
997         source, _ = read_data("debug_visitor.py")
998         expected, _ = read_data("debug_visitor.out")
999         out_lines = []
1000         err_lines = []
1001
1002         def out(msg: str, **kwargs: Any) -> None:
1003             out_lines.append(msg)
1004
1005         def err(msg: str, **kwargs: Any) -> None:
1006             err_lines.append(msg)
1007
1008         with patch("black.out", out), patch("black.err", err):
1009             black.DebugVisitor.show(source)
1010         actual = "\n".join(out_lines) + "\n"
1011         log_name = ""
1012         if expected != actual:
1013             log_name = black.dump_to_file(*out_lines)
1014         self.assertEqual(
1015             expected,
1016             actual,
1017             f"AST print out is different. Actual version dumped to {log_name}",
1018         )
1019
1020     def test_format_file_contents(self) -> None:
1021         empty = ""
1022         mode = black.FileMode()
1023         with self.assertRaises(black.NothingChanged):
1024             black.format_file_contents(empty, mode=mode, fast=False)
1025         just_nl = "\n"
1026         with self.assertRaises(black.NothingChanged):
1027             black.format_file_contents(just_nl, mode=mode, fast=False)
1028         same = "l = [1, 2, 3]\n"
1029         with self.assertRaises(black.NothingChanged):
1030             black.format_file_contents(same, mode=mode, fast=False)
1031         different = "l = [1,2,3]"
1032         expected = same
1033         actual = black.format_file_contents(different, mode=mode, fast=False)
1034         self.assertEqual(expected, actual)
1035         invalid = "return if you can"
1036         with self.assertRaises(black.InvalidInput) as e:
1037             black.format_file_contents(invalid, mode=mode, fast=False)
1038         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1039
1040     def test_endmarker(self) -> None:
1041         n = black.lib2to3_parse("\n")
1042         self.assertEqual(n.type, black.syms.file_input)
1043         self.assertEqual(len(n.children), 1)
1044         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1045
1046     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1047     def test_assertFormatEqual(self) -> None:
1048         out_lines = []
1049         err_lines = []
1050
1051         def out(msg: str, **kwargs: Any) -> None:
1052             out_lines.append(msg)
1053
1054         def err(msg: str, **kwargs: Any) -> None:
1055             err_lines.append(msg)
1056
1057         with patch("black.out", out), patch("black.err", err):
1058             with self.assertRaises(AssertionError):
1059                 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
1060
1061         out_str = "".join(out_lines)
1062         self.assertTrue("Expected tree:" in out_str)
1063         self.assertTrue("Actual tree:" in out_str)
1064         self.assertEqual("".join(err_lines), "")
1065
1066     def test_cache_broken_file(self) -> None:
1067         mode = black.FileMode()
1068         with cache_dir() as workspace:
1069             cache_file = black.get_cache_file(mode)
1070             with cache_file.open("w") as fobj:
1071                 fobj.write("this is not a pickle")
1072             self.assertEqual(black.read_cache(mode), {})
1073             src = (workspace / "test.py").resolve()
1074             with src.open("w") as fobj:
1075                 fobj.write("print('hello')")
1076             self.invokeBlack([str(src)])
1077             cache = black.read_cache(mode)
1078             self.assertIn(src, cache)
1079
1080     def test_cache_single_file_already_cached(self) -> None:
1081         mode = black.FileMode()
1082         with cache_dir() as workspace:
1083             src = (workspace / "test.py").resolve()
1084             with src.open("w") as fobj:
1085                 fobj.write("print('hello')")
1086             black.write_cache({}, [src], mode)
1087             self.invokeBlack([str(src)])
1088             with src.open("r") as fobj:
1089                 self.assertEqual(fobj.read(), "print('hello')")
1090
1091     @event_loop(close=False)
1092     def test_cache_multiple_files(self) -> None:
1093         mode = black.FileMode()
1094         with cache_dir() as workspace, patch(
1095             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1096         ):
1097             one = (workspace / "one.py").resolve()
1098             with one.open("w") as fobj:
1099                 fobj.write("print('hello')")
1100             two = (workspace / "two.py").resolve()
1101             with two.open("w") as fobj:
1102                 fobj.write("print('hello')")
1103             black.write_cache({}, [one], mode)
1104             self.invokeBlack([str(workspace)])
1105             with one.open("r") as fobj:
1106                 self.assertEqual(fobj.read(), "print('hello')")
1107             with two.open("r") as fobj:
1108                 self.assertEqual(fobj.read(), 'print("hello")\n')
1109             cache = black.read_cache(mode)
1110             self.assertIn(one, cache)
1111             self.assertIn(two, cache)
1112
1113     def test_no_cache_when_writeback_diff(self) -> None:
1114         mode = black.FileMode()
1115         with cache_dir() as workspace:
1116             src = (workspace / "test.py").resolve()
1117             with src.open("w") as fobj:
1118                 fobj.write("print('hello')")
1119             self.invokeBlack([str(src), "--diff"])
1120             cache_file = black.get_cache_file(mode)
1121             self.assertFalse(cache_file.exists())
1122
1123     def test_no_cache_when_stdin(self) -> None:
1124         mode = black.FileMode()
1125         with cache_dir():
1126             result = CliRunner().invoke(
1127                 black.main, ["-"], input=BytesIO(b"print('hello')")
1128             )
1129             self.assertEqual(result.exit_code, 0)
1130             cache_file = black.get_cache_file(mode)
1131             self.assertFalse(cache_file.exists())
1132
1133     def test_read_cache_no_cachefile(self) -> None:
1134         mode = black.FileMode()
1135         with cache_dir():
1136             self.assertEqual(black.read_cache(mode), {})
1137
1138     def test_write_cache_read_cache(self) -> None:
1139         mode = black.FileMode()
1140         with cache_dir() as workspace:
1141             src = (workspace / "test.py").resolve()
1142             src.touch()
1143             black.write_cache({}, [src], mode)
1144             cache = black.read_cache(mode)
1145             self.assertIn(src, cache)
1146             self.assertEqual(cache[src], black.get_cache_info(src))
1147
1148     def test_filter_cached(self) -> None:
1149         with TemporaryDirectory() as workspace:
1150             path = Path(workspace)
1151             uncached = (path / "uncached").resolve()
1152             cached = (path / "cached").resolve()
1153             cached_but_changed = (path / "changed").resolve()
1154             uncached.touch()
1155             cached.touch()
1156             cached_but_changed.touch()
1157             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
1158             todo, done = black.filter_cached(
1159                 cache, {uncached, cached, cached_but_changed}
1160             )
1161             self.assertEqual(todo, {uncached, cached_but_changed})
1162             self.assertEqual(done, {cached})
1163
1164     def test_write_cache_creates_directory_if_needed(self) -> None:
1165         mode = black.FileMode()
1166         with cache_dir(exists=False) as workspace:
1167             self.assertFalse(workspace.exists())
1168             black.write_cache({}, [], mode)
1169             self.assertTrue(workspace.exists())
1170
1171     @event_loop(close=False)
1172     def test_failed_formatting_does_not_get_cached(self) -> None:
1173         mode = black.FileMode()
1174         with cache_dir() as workspace, patch(
1175             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1176         ):
1177             failing = (workspace / "failing.py").resolve()
1178             with failing.open("w") as fobj:
1179                 fobj.write("not actually python")
1180             clean = (workspace / "clean.py").resolve()
1181             with clean.open("w") as fobj:
1182                 fobj.write('print("hello")\n')
1183             self.invokeBlack([str(workspace)], exit_code=123)
1184             cache = black.read_cache(mode)
1185             self.assertNotIn(failing, cache)
1186             self.assertIn(clean, cache)
1187
1188     def test_write_cache_write_fail(self) -> None:
1189         mode = black.FileMode()
1190         with cache_dir(), patch.object(Path, "open") as mock:
1191             mock.side_effect = OSError
1192             black.write_cache({}, [], mode)
1193
1194     @event_loop(close=False)
1195     def test_check_diff_use_together(self) -> None:
1196         with cache_dir():
1197             # Files which will be reformatted.
1198             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1199             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1200             # Files which will not be reformatted.
1201             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1202             self.invokeBlack([str(src2), "--diff", "--check"])
1203             # Multi file command.
1204             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1205
1206     def test_no_files(self) -> None:
1207         with cache_dir():
1208             # Without an argument, black exits with error code 0.
1209             self.invokeBlack([])
1210
1211     def test_broken_symlink(self) -> None:
1212         with cache_dir() as workspace:
1213             symlink = workspace / "broken_link.py"
1214             try:
1215                 symlink.symlink_to("nonexistent.py")
1216             except OSError as e:
1217                 self.skipTest(f"Can't create symlinks: {e}")
1218             self.invokeBlack([str(workspace.resolve())])
1219
1220     def test_read_cache_line_lengths(self) -> None:
1221         mode = black.FileMode()
1222         short_mode = black.FileMode(line_length=1)
1223         with cache_dir() as workspace:
1224             path = (workspace / "file.py").resolve()
1225             path.touch()
1226             black.write_cache({}, [path], mode)
1227             one = black.read_cache(mode)
1228             self.assertIn(path, one)
1229             two = black.read_cache(short_mode)
1230             self.assertNotIn(path, two)
1231
1232     def test_single_file_force_pyi(self) -> None:
1233         reg_mode = black.FileMode()
1234         pyi_mode = black.FileMode(is_pyi=True)
1235         contents, expected = read_data("force_pyi")
1236         with cache_dir() as workspace:
1237             path = (workspace / "file.py").resolve()
1238             with open(path, "w") as fh:
1239                 fh.write(contents)
1240             self.invokeBlack([str(path), "--pyi"])
1241             with open(path, "r") as fh:
1242                 actual = fh.read()
1243             # verify cache with --pyi is separate
1244             pyi_cache = black.read_cache(pyi_mode)
1245             self.assertIn(path, pyi_cache)
1246             normal_cache = black.read_cache(reg_mode)
1247             self.assertNotIn(path, normal_cache)
1248         self.assertEqual(actual, expected)
1249
1250     @event_loop(close=False)
1251     def test_multi_file_force_pyi(self) -> None:
1252         reg_mode = black.FileMode()
1253         pyi_mode = black.FileMode(is_pyi=True)
1254         contents, expected = read_data("force_pyi")
1255         with cache_dir() as workspace:
1256             paths = [
1257                 (workspace / "file1.py").resolve(),
1258                 (workspace / "file2.py").resolve(),
1259             ]
1260             for path in paths:
1261                 with open(path, "w") as fh:
1262                     fh.write(contents)
1263             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1264             for path in paths:
1265                 with open(path, "r") as fh:
1266                     actual = fh.read()
1267                 self.assertEqual(actual, expected)
1268             # verify cache with --pyi is separate
1269             pyi_cache = black.read_cache(pyi_mode)
1270             normal_cache = black.read_cache(reg_mode)
1271             for path in paths:
1272                 self.assertIn(path, pyi_cache)
1273                 self.assertNotIn(path, normal_cache)
1274
1275     def test_pipe_force_pyi(self) -> None:
1276         source, expected = read_data("force_pyi")
1277         result = CliRunner().invoke(
1278             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1279         )
1280         self.assertEqual(result.exit_code, 0)
1281         actual = result.output
1282         self.assertFormatEqual(actual, expected)
1283
1284     def test_single_file_force_py36(self) -> None:
1285         reg_mode = black.FileMode()
1286         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1287         source, expected = read_data("force_py36")
1288         with cache_dir() as workspace:
1289             path = (workspace / "file.py").resolve()
1290             with open(path, "w") as fh:
1291                 fh.write(source)
1292             self.invokeBlack([str(path), *PY36_ARGS])
1293             with open(path, "r") as fh:
1294                 actual = fh.read()
1295             # verify cache with --target-version is separate
1296             py36_cache = black.read_cache(py36_mode)
1297             self.assertIn(path, py36_cache)
1298             normal_cache = black.read_cache(reg_mode)
1299             self.assertNotIn(path, normal_cache)
1300         self.assertEqual(actual, expected)
1301
1302     @event_loop(close=False)
1303     def test_multi_file_force_py36(self) -> None:
1304         reg_mode = black.FileMode()
1305         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1306         source, expected = read_data("force_py36")
1307         with cache_dir() as workspace:
1308             paths = [
1309                 (workspace / "file1.py").resolve(),
1310                 (workspace / "file2.py").resolve(),
1311             ]
1312             for path in paths:
1313                 with open(path, "w") as fh:
1314                     fh.write(source)
1315             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1316             for path in paths:
1317                 with open(path, "r") as fh:
1318                     actual = fh.read()
1319                 self.assertEqual(actual, expected)
1320             # verify cache with --target-version is separate
1321             pyi_cache = black.read_cache(py36_mode)
1322             normal_cache = black.read_cache(reg_mode)
1323             for path in paths:
1324                 self.assertIn(path, pyi_cache)
1325                 self.assertNotIn(path, normal_cache)
1326
1327     def test_pipe_force_py36(self) -> None:
1328         source, expected = read_data("force_py36")
1329         result = CliRunner().invoke(
1330             black.main,
1331             ["-", "-q", "--target-version=py36"],
1332             input=BytesIO(source.encode("utf8")),
1333         )
1334         self.assertEqual(result.exit_code, 0)
1335         actual = result.output
1336         self.assertFormatEqual(actual, expected)
1337
1338     def test_include_exclude(self) -> None:
1339         path = THIS_DIR / "data" / "include_exclude_tests"
1340         include = re.compile(r"\.pyi?$")
1341         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1342         report = black.Report()
1343         sources: List[Path] = []
1344         expected = [
1345             Path(path / "b/dont_exclude/a.py"),
1346             Path(path / "b/dont_exclude/a.pyi"),
1347         ]
1348         this_abs = THIS_DIR.resolve()
1349         sources.extend(
1350             black.gen_python_files_in_dir(path, this_abs, include, exclude, report)
1351         )
1352         self.assertEqual(sorted(expected), sorted(sources))
1353
1354     def test_empty_include(self) -> None:
1355         path = THIS_DIR / "data" / "include_exclude_tests"
1356         report = black.Report()
1357         empty = re.compile(r"")
1358         sources: List[Path] = []
1359         expected = [
1360             Path(path / "b/exclude/a.pie"),
1361             Path(path / "b/exclude/a.py"),
1362             Path(path / "b/exclude/a.pyi"),
1363             Path(path / "b/dont_exclude/a.pie"),
1364             Path(path / "b/dont_exclude/a.py"),
1365             Path(path / "b/dont_exclude/a.pyi"),
1366             Path(path / "b/.definitely_exclude/a.pie"),
1367             Path(path / "b/.definitely_exclude/a.py"),
1368             Path(path / "b/.definitely_exclude/a.pyi"),
1369         ]
1370         this_abs = THIS_DIR.resolve()
1371         sources.extend(
1372             black.gen_python_files_in_dir(
1373                 path, this_abs, empty, re.compile(black.DEFAULT_EXCLUDES), report
1374             )
1375         )
1376         self.assertEqual(sorted(expected), sorted(sources))
1377
1378     def test_empty_exclude(self) -> None:
1379         path = THIS_DIR / "data" / "include_exclude_tests"
1380         report = black.Report()
1381         empty = re.compile(r"")
1382         sources: List[Path] = []
1383         expected = [
1384             Path(path / "b/dont_exclude/a.py"),
1385             Path(path / "b/dont_exclude/a.pyi"),
1386             Path(path / "b/exclude/a.py"),
1387             Path(path / "b/exclude/a.pyi"),
1388             Path(path / "b/.definitely_exclude/a.py"),
1389             Path(path / "b/.definitely_exclude/a.pyi"),
1390         ]
1391         this_abs = THIS_DIR.resolve()
1392         sources.extend(
1393             black.gen_python_files_in_dir(
1394                 path, this_abs, re.compile(black.DEFAULT_INCLUDES), empty, report
1395             )
1396         )
1397         self.assertEqual(sorted(expected), sorted(sources))
1398
1399     def test_invalid_include_exclude(self) -> None:
1400         for option in ["--include", "--exclude"]:
1401             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1402
1403     def test_preserves_line_endings(self) -> None:
1404         with TemporaryDirectory() as workspace:
1405             test_file = Path(workspace) / "test.py"
1406             for nl in ["\n", "\r\n"]:
1407                 contents = nl.join(["def f(  ):", "    pass"])
1408                 test_file.write_bytes(contents.encode())
1409                 ff(test_file, write_back=black.WriteBack.YES)
1410                 updated_contents: bytes = test_file.read_bytes()
1411                 self.assertIn(nl.encode(), updated_contents)
1412                 if nl == "\n":
1413                     self.assertNotIn(b"\r\n", updated_contents)
1414
1415     def test_preserves_line_endings_via_stdin(self) -> None:
1416         for nl in ["\n", "\r\n"]:
1417             contents = nl.join(["def f(  ):", "    pass"])
1418             runner = BlackRunner()
1419             result = runner.invoke(
1420                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1421             )
1422             self.assertEqual(result.exit_code, 0)
1423             output = runner.stdout_bytes
1424             self.assertIn(nl.encode("utf8"), output)
1425             if nl == "\n":
1426                 self.assertNotIn(b"\r\n", output)
1427
1428     def test_assert_equivalent_different_asts(self) -> None:
1429         with self.assertRaises(AssertionError):
1430             black.assert_equivalent("{}", "None")
1431
1432     def test_symlink_out_of_root_directory(self) -> None:
1433         path = MagicMock()
1434         root = THIS_DIR
1435         child = MagicMock()
1436         include = re.compile(black.DEFAULT_INCLUDES)
1437         exclude = re.compile(black.DEFAULT_EXCLUDES)
1438         report = black.Report()
1439         # `child` should behave like a symlink which resolved path is clearly
1440         # outside of the `root` directory.
1441         path.iterdir.return_value = [child]
1442         child.resolve.return_value = Path("/a/b/c")
1443         child.is_symlink.return_value = True
1444         try:
1445             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1446         except ValueError as ve:
1447             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1448         path.iterdir.assert_called_once()
1449         child.resolve.assert_called_once()
1450         child.is_symlink.assert_called_once()
1451         # `child` should behave like a strange file which resolved path is clearly
1452         # outside of the `root` directory.
1453         child.is_symlink.return_value = False
1454         with self.assertRaises(ValueError):
1455             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1456         path.iterdir.assert_called()
1457         self.assertEqual(path.iterdir.call_count, 2)
1458         child.resolve.assert_called()
1459         self.assertEqual(child.resolve.call_count, 2)
1460         child.is_symlink.assert_called()
1461         self.assertEqual(child.is_symlink.call_count, 2)
1462
1463     def test_shhh_click(self) -> None:
1464         try:
1465             from click import _unicodefun  # type: ignore
1466         except ModuleNotFoundError:
1467             self.skipTest("Incompatible Click version")
1468         if not hasattr(_unicodefun, "_verify_python3_env"):
1469             self.skipTest("Incompatible Click version")
1470         # First, let's see if Click is crashing with a preferred ASCII charset.
1471         with patch("locale.getpreferredencoding") as gpe:
1472             gpe.return_value = "ASCII"
1473             with self.assertRaises(RuntimeError):
1474                 _unicodefun._verify_python3_env()
1475         # Now, let's silence Click...
1476         black.patch_click()
1477         # ...and confirm it's silent.
1478         with patch("locale.getpreferredencoding") as gpe:
1479             gpe.return_value = "ASCII"
1480             try:
1481                 _unicodefun._verify_python3_env()
1482             except RuntimeError as re:
1483                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1484
1485     def test_root_logger_not_used_directly(self) -> None:
1486         def fail(*args: Any, **kwargs: Any) -> None:
1487             self.fail("Record created with root logger")
1488
1489         with patch.multiple(
1490             logging.root,
1491             debug=fail,
1492             info=fail,
1493             warning=fail,
1494             error=fail,
1495             critical=fail,
1496             log=fail,
1497         ):
1498             ff(THIS_FILE)
1499
1500     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1501     @async_test
1502     async def test_blackd_request_needs_formatting(self) -> None:
1503         app = blackd.make_app()
1504         async with TestClient(TestServer(app)) as client:
1505             response = await client.post("/", data=b"print('hello world')")
1506             self.assertEqual(response.status, 200)
1507             self.assertEqual(response.charset, "utf8")
1508             self.assertEqual(await response.read(), b'print("hello world")\n')
1509
1510     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1511     @async_test
1512     async def test_blackd_request_no_change(self) -> None:
1513         app = blackd.make_app()
1514         async with TestClient(TestServer(app)) as client:
1515             response = await client.post("/", data=b'print("hello world")\n')
1516             self.assertEqual(response.status, 204)
1517             self.assertEqual(await response.read(), b"")
1518
1519     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1520     @async_test
1521     async def test_blackd_request_syntax_error(self) -> None:
1522         app = blackd.make_app()
1523         async with TestClient(TestServer(app)) as client:
1524             response = await client.post("/", data=b"what even ( is")
1525             self.assertEqual(response.status, 400)
1526             content = await response.text()
1527             self.assertTrue(
1528                 content.startswith("Cannot parse"),
1529                 msg=f"Expected error to start with 'Cannot parse', got {repr(content)}",
1530             )
1531
1532     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1533     @async_test
1534     async def test_blackd_unsupported_version(self) -> None:
1535         app = blackd.make_app()
1536         async with TestClient(TestServer(app)) as client:
1537             response = await client.post(
1538                 "/", data=b"what", headers={blackd.VERSION_HEADER: "2"}
1539             )
1540             self.assertEqual(response.status, 501)
1541
1542     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1543     @async_test
1544     async def test_blackd_supported_version(self) -> None:
1545         app = blackd.make_app()
1546         async with TestClient(TestServer(app)) as client:
1547             response = await client.post(
1548                 "/", data=b"what", headers={blackd.VERSION_HEADER: "1"}
1549             )
1550             self.assertEqual(response.status, 200)
1551
1552     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1553     @async_test
1554     async def test_blackd_invalid_python_variant(self) -> None:
1555         app = blackd.make_app()
1556         async with TestClient(TestServer(app)) as client:
1557
1558             async def check(header_value: str, expected_status: int = 400) -> None:
1559                 response = await client.post(
1560                     "/",
1561                     data=b"what",
1562                     headers={blackd.PYTHON_VARIANT_HEADER: header_value},
1563                 )
1564                 self.assertEqual(response.status, expected_status)
1565
1566             await check("lol")
1567             await check("ruby3.5")
1568             await check("pyi3.6")
1569             await check("py1.5")
1570             await check("2.8")
1571             await check("py2.8")
1572             await check("3.0")
1573             await check("pypy3.0")
1574             await check("jython3.4")
1575
1576     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1577     @async_test
1578     async def test_blackd_pyi(self) -> None:
1579         app = blackd.make_app()
1580         async with TestClient(TestServer(app)) as client:
1581             source, expected = read_data("stub.pyi")
1582             response = await client.post(
1583                 "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
1584             )
1585             self.assertEqual(response.status, 200)
1586             self.assertEqual(await response.text(), expected)
1587
1588     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1589     @async_test
1590     async def test_blackd_python_variant(self) -> None:
1591         app = blackd.make_app()
1592         code = (
1593             "def f(\n"
1594             "    and_has_a_bunch_of,\n"
1595             "    very_long_arguments_too,\n"
1596             "    and_lots_of_them_as_well_lol,\n"
1597             "    **and_very_long_keyword_arguments\n"
1598             "):\n"
1599             "    pass\n"
1600         )
1601         async with TestClient(TestServer(app)) as client:
1602
1603             async def check(header_value: str, expected_status: int) -> None:
1604                 response = await client.post(
1605                     "/", data=code, headers={blackd.PYTHON_VARIANT_HEADER: header_value}
1606                 )
1607                 self.assertEqual(response.status, expected_status)
1608
1609             await check("3.6", 200)
1610             await check("py3.6", 200)
1611             await check("3.6,3.7", 200)
1612             await check("3.6,py3.7", 200)
1613
1614             await check("2", 204)
1615             await check("2.7", 204)
1616             await check("py2.7", 204)
1617             await check("3.4", 204)
1618             await check("py3.4", 204)
1619
1620     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1621     @async_test
1622     async def test_blackd_line_length(self) -> None:
1623         app = blackd.make_app()
1624         async with TestClient(TestServer(app)) as client:
1625             response = await client.post(
1626                 "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "7"}
1627             )
1628             self.assertEqual(response.status, 200)
1629
1630     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1631     @async_test
1632     async def test_blackd_invalid_line_length(self) -> None:
1633         app = blackd.make_app()
1634         async with TestClient(TestServer(app)) as client:
1635             response = await client.post(
1636                 "/",
1637                 data=b'print("hello")\n',
1638                 headers={blackd.LINE_LENGTH_HEADER: "NaN"},
1639             )
1640             self.assertEqual(response.status, 400)
1641
1642     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1643     def test_blackd_main(self) -> None:
1644         with patch("blackd.web.run_app"):
1645             result = CliRunner().invoke(blackd.main, [])
1646             if result.exception is not None:
1647                 raise result.exception
1648             self.assertEqual(result.exit_code, 0)
1649
1650
if __name__ == "__main__":
    # Run the tests when this file is executed directly. `module="test_black"`
    # makes unittest load the tests from the module named `test_black`
    # instead of from `__main__`.
    unittest.main(module="test_black")