]> git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

add change log entry
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 import logging
4 from concurrent.futures import ThreadPoolExecutor
5 from contextlib import contextmanager
6 from functools import partial, wraps
7 from io import BytesIO, TextIOWrapper
8 import os
9 from pathlib import Path
10 import re
11 import sys
12 from tempfile import TemporaryDirectory
13 from typing import (
14     Any,
15     BinaryIO,
16     Callable,
17     Coroutine,
18     Generator,
19     List,
20     Tuple,
21     Iterator,
22     TypeVar,
23 )
24 import unittest
25 from unittest.mock import patch, MagicMock
26
27 from click import unstyle
28 from click.testing import CliRunner
29
30 import black
31 from black import Feature, TargetVersion
32
33 try:
34     import blackd
35     from aiohttp.test_utils import TestClient, TestServer
36 except ImportError:
37     has_blackd_deps = False
38 else:
39     has_blackd_deps = True
40
# Shorthands: format a file in place (fast mode) / format a string, both with
# the default FileMode.
ff = partial(black.format_file_in_place, mode=black.FileMode(), fast=True)
fs = partial(black.format_str, mode=black.FileMode())

# This test module and the directory containing it.
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent

# Marker removed from every line by read_data(). It is built from two literals,
# presumably so that this very file (which read_data() also loads) never
# contains the complete marker text.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"

# One --target-version=... CLI flag per version in black.PY36_VERSIONS.
PY36_ARGS = [f"--target-version={v.name.lower()}" for v in black.PY36_VERSIONS]

T = TypeVar("T")
R = TypeVar("R")
51
52
def dump_to_stderr(*output: str) -> str:
    """Return *output* joined by newlines, with one newline before and after.

    Tests patch ``black.dump_to_file`` (which elsewhere in this file returns a
    path to a temporary file) with this function, so the dumped text itself is
    returned instead — presumably so it shows up directly in failure messages.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
55
56
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    Split a test-case file into its source half and its expected-output half.

    The halves are separated by a line reading ``# output``. A ``.py`` suffix is
    appended unless ``name`` already ends in .py/.pyi/.out/.diff. Files are read
    from ``THIS_DIR / "data"`` when ``data`` is true, otherwise from ``THIS_DIR``.
    EMPTY_LINE is removed from every line. When the source half has content but
    the output half is empty, the source is reused as the expected output.
    Each returned half is stripped and terminated by exactly one newline.
    """
    known_suffixes = (".py", ".pyi", ".out", ".diff")
    if not name.endswith(known_suffixes):
        name = name + ".py"
    folder = THIS_DIR / "data" if data else THIS_DIR

    source_part: List[str] = []
    expected_part: List[str] = []
    bucket = source_part
    with open(folder / name, "r", encoding="utf8") as handle:
        for raw in handle:
            cleaned = raw.replace(EMPTY_LINE, "")
            if cleaned.rstrip() == "# output":
                bucket = expected_part
            else:
                bucket.append(cleaned)

    if source_part and not expected_part:
        # No output section: the file is treated as already pre-formatted.
        expected_part = list(source_part)
    return (
        "".join(source_part).strip() + "\n",
        "".join(expected_part).strip() + "\n",
    )
78
79
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point ``black.CACHE_DIR`` at a throwaway directory for the ``with`` body.

    With ``exists=False`` the yielded path is a ``new`` subdirectory of the
    temporary directory that has not been created. The temporary directory is
    removed and the patch undone on exit.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
88
89
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Make a brand-new asyncio loop current for the duration of the block.

    The loop that was current on entry is reinstated on exit, even if the block
    raises. The new loop is closed on exit only when ``close`` is true.
    """
    policy = asyncio.get_event_loop_policy()
    previous = policy.get_event_loop()
    fresh = policy.new_event_loop()
    asyncio.set_event_loop(fresh)
    try:
        yield
    finally:
        policy.set_event_loop(previous)
        if close:
            fresh.close()
103
104
def async_test(f: Callable[..., Coroutine[Any, None, R]]) -> Callable[..., None]:
    """Turn coroutine function *f* into a synchronous callable for unittest.

    Every call runs ``f(*args, **kwargs)`` to completion on a fresh event loop
    installed by ``event_loop``; that loop is closed afterwards and the result
    of the coroutine is discarded.
    """

    @wraps(f)
    def wrapper(*args: Any, **kwargs: Any) -> None:
        with event_loop(close=True):
            loop = asyncio.get_event_loop()
            loop.run_until_complete(f(*args, **kwargs))

    return wrapper
112
113
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    This is a hack that can be removed once we depend on Click 7.x"""

    def __init__(self) -> None:
        # Buffer behind the replacement sys.stderr; re-created per invocation.
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        # Raw bytes written to stdout / stderr during the latest invocation.
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Wrap ``CliRunner.isolation`` and additionally capture ``sys.stderr``.

        On exit, ``stdout_bytes`` holds the bytes of the stdout captured by Click
        and ``stderr_bytes`` the bytes written to ``sys.stderr`` inside the block.
        """
        with super().isolation(*args, **kwargs) as output:
            # Use a new buffer every time: a TextIOWrapper closes the buffer it
            # wraps when it is garbage-collected, so reusing one BytesIO would
            # break a second invoke() on the same runner.
            self.stderrbuf = BytesIO()
            stderr_wrapper = TextIOWrapper(self.stderrbuf, encoding=self.charset)
            # Set up before entering ``try`` so that ``finally`` never reads from
            # a stream this method did not install.
            hold_stderr = sys.stderr
            sys.stderr = stderr_wrapper
            try:
                yield output
            finally:
                # TextIOWrapper buffers text internally; flush so the byte
                # buffers contain everything written so far.
                sys.stdout.flush()
                stderr_wrapper.flush()
                self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                self.stderr_bytes = self.stderrbuf.getvalue()
                sys.stderr = hold_stderr
137
138
139 class BlackTestCase(unittest.TestCase):
140     maxDiff = None
141
142     def assertFormatEqual(self, expected: str, actual: str) -> None:
143         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
144             bdv: black.DebugVisitor[Any]
145             black.out("Expected tree:", fg="green")
146             try:
147                 exp_node = black.lib2to3_parse(expected)
148                 bdv = black.DebugVisitor()
149                 list(bdv.visit(exp_node))
150             except Exception as ve:
151                 black.err(str(ve))
152             black.out("Actual tree:", fg="red")
153             try:
154                 exp_node = black.lib2to3_parse(actual)
155                 bdv = black.DebugVisitor()
156                 list(bdv.visit(exp_node))
157             except Exception as ve:
158                 black.err(str(ve))
159         self.assertEqual(expected, actual)
160
161     def invokeBlack(
162         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
163     ) -> None:
164         runner = BlackRunner()
165         if ignore_config:
166             args = ["--config", str(THIS_DIR / "empty.toml"), *args]
167         result = runner.invoke(black.main, args)
168         self.assertEqual(result.exit_code, exit_code, msg=runner.stderr_bytes.decode())
169
170     @patch("black.dump_to_file", dump_to_stderr)
171     def test_empty(self) -> None:
172         source = expected = ""
173         actual = fs(source)
174         self.assertFormatEqual(expected, actual)
175         black.assert_equivalent(source, actual)
176         black.assert_stable(source, actual, black.FileMode())
177
178     def test_empty_ff(self) -> None:
179         expected = ""
180         tmp_file = Path(black.dump_to_file())
181         try:
182             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
183             with open(tmp_file, encoding="utf8") as f:
184                 actual = f.read()
185         finally:
186             os.unlink(tmp_file)
187         self.assertFormatEqual(expected, actual)
188
189     @patch("black.dump_to_file", dump_to_stderr)
190     def test_self(self) -> None:
191         source, expected = read_data("test_black", data=False)
192         actual = fs(source)
193         self.assertFormatEqual(expected, actual)
194         black.assert_equivalent(source, actual)
195         black.assert_stable(source, actual, black.FileMode())
196         self.assertFalse(ff(THIS_FILE))
197
198     @patch("black.dump_to_file", dump_to_stderr)
199     def test_black(self) -> None:
200         source, expected = read_data("../black", data=False)
201         actual = fs(source)
202         self.assertFormatEqual(expected, actual)
203         black.assert_equivalent(source, actual)
204         black.assert_stable(source, actual, black.FileMode())
205         self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
206
207     def test_piping(self) -> None:
208         source, expected = read_data("../black", data=False)
209         result = BlackRunner().invoke(
210             black.main,
211             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
212             input=BytesIO(source.encode("utf8")),
213         )
214         self.assertEqual(result.exit_code, 0)
215         self.assertFormatEqual(expected, result.output)
216         black.assert_equivalent(source, result.output)
217         black.assert_stable(source, result.output, black.FileMode())
218
219     def test_piping_diff(self) -> None:
220         diff_header = re.compile(
221             rf"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d "
222             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
223         )
224         source, _ = read_data("expression.py")
225         expected, _ = read_data("expression.diff")
226         config = THIS_DIR / "data" / "empty_pyproject.toml"
227         args = [
228             "-",
229             "--fast",
230             f"--line-length={black.DEFAULT_LINE_LENGTH}",
231             "--diff",
232             f"--config={config}",
233         ]
234         result = BlackRunner().invoke(
235             black.main, args, input=BytesIO(source.encode("utf8"))
236         )
237         self.assertEqual(result.exit_code, 0)
238         actual = diff_header.sub("[Deterministic header]", result.output)
239         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
240         self.assertEqual(expected, actual)
241
242     @patch("black.dump_to_file", dump_to_stderr)
243     def test_setup(self) -> None:
244         source, expected = read_data("../setup", data=False)
245         actual = fs(source)
246         self.assertFormatEqual(expected, actual)
247         black.assert_equivalent(source, actual)
248         black.assert_stable(source, actual, black.FileMode())
249         self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
250
251     @patch("black.dump_to_file", dump_to_stderr)
252     def test_function(self) -> None:
253         source, expected = read_data("function")
254         actual = fs(source)
255         self.assertFormatEqual(expected, actual)
256         black.assert_equivalent(source, actual)
257         black.assert_stable(source, actual, black.FileMode())
258
259     @patch("black.dump_to_file", dump_to_stderr)
260     def test_function2(self) -> None:
261         source, expected = read_data("function2")
262         actual = fs(source)
263         self.assertFormatEqual(expected, actual)
264         black.assert_equivalent(source, actual)
265         black.assert_stable(source, actual, black.FileMode())
266
267     @patch("black.dump_to_file", dump_to_stderr)
268     def test_function_trailing_comma(self) -> None:
269         source, expected = read_data("function_trailing_comma")
270         actual = fs(source)
271         self.assertFormatEqual(expected, actual)
272         black.assert_equivalent(source, actual)
273         black.assert_stable(source, actual, black.FileMode())
274
275     @patch("black.dump_to_file", dump_to_stderr)
276     def test_expression(self) -> None:
277         source, expected = read_data("expression")
278         actual = fs(source)
279         self.assertFormatEqual(expected, actual)
280         black.assert_equivalent(source, actual)
281         black.assert_stable(source, actual, black.FileMode())
282
283     @patch("black.dump_to_file", dump_to_stderr)
284     def test_pep_572(self) -> None:
285         source, expected = read_data("pep_572")
286         actual = fs(source)
287         self.assertFormatEqual(expected, actual)
288         black.assert_stable(source, actual, black.FileMode())
289         if sys.version_info >= (3, 8):
290             black.assert_equivalent(source, actual)
291
292     def test_pep_572_version_detection(self) -> None:
293         source, _ = read_data("pep_572")
294         root = black.lib2to3_parse(source)
295         features = black.get_features_used(root)
296         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
297         versions = black.detect_target_versions(root)
298         self.assertIn(black.TargetVersion.PY38, versions)
299
300     def test_expression_ff(self) -> None:
301         source, expected = read_data("expression")
302         tmp_file = Path(black.dump_to_file(source))
303         try:
304             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
305             with open(tmp_file, encoding="utf8") as f:
306                 actual = f.read()
307         finally:
308             os.unlink(tmp_file)
309         self.assertFormatEqual(expected, actual)
310         with patch("black.dump_to_file", dump_to_stderr):
311             black.assert_equivalent(source, actual)
312             black.assert_stable(source, actual, black.FileMode())
313
314     def test_expression_diff(self) -> None:
315         source, _ = read_data("expression.py")
316         expected, _ = read_data("expression.diff")
317         tmp_file = Path(black.dump_to_file(source))
318         diff_header = re.compile(
319             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
320             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
321         )
322         try:
323             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
324             self.assertEqual(result.exit_code, 0)
325         finally:
326             os.unlink(tmp_file)
327         actual = result.output
328         actual = diff_header.sub("[Deterministic header]", actual)
329         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
330         if expected != actual:
331             dump = black.dump_to_file(actual)
332             msg = (
333                 f"Expected diff isn't equal to the actual. If you made changes "
334                 f"to expression.py and this is an anticipated difference, "
335                 f"overwrite tests/data/expression.diff with {dump}"
336             )
337             self.assertEqual(expected, actual, msg)
338
339     @patch("black.dump_to_file", dump_to_stderr)
340     def test_fstring(self) -> None:
341         source, expected = read_data("fstring")
342         actual = fs(source)
343         self.assertFormatEqual(expected, actual)
344         black.assert_equivalent(source, actual)
345         black.assert_stable(source, actual, black.FileMode())
346
347     @patch("black.dump_to_file", dump_to_stderr)
348     def test_pep_570(self) -> None:
349         source, expected = read_data("pep_570")
350         actual = fs(source)
351         self.assertFormatEqual(expected, actual)
352         black.assert_stable(source, actual, black.FileMode())
353         if sys.version_info >= (3, 8):
354             black.assert_equivalent(source, actual)
355
356     def test_detect_pos_only_arguments(self) -> None:
357         source, _ = read_data("pep_570")
358         root = black.lib2to3_parse(source)
359         features = black.get_features_used(root)
360         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
361         versions = black.detect_target_versions(root)
362         self.assertIn(black.TargetVersion.PY38, versions)
363
364     @patch("black.dump_to_file", dump_to_stderr)
365     def test_string_quotes(self) -> None:
366         source, expected = read_data("string_quotes")
367         actual = fs(source)
368         self.assertFormatEqual(expected, actual)
369         black.assert_equivalent(source, actual)
370         black.assert_stable(source, actual, black.FileMode())
371         mode = black.FileMode(string_normalization=False)
372         not_normalized = fs(source, mode=mode)
373         self.assertFormatEqual(source, not_normalized)
374         black.assert_equivalent(source, not_normalized)
375         black.assert_stable(source, not_normalized, mode=mode)
376
377     @patch("black.dump_to_file", dump_to_stderr)
378     def test_slices(self) -> None:
379         source, expected = read_data("slices")
380         actual = fs(source)
381         self.assertFormatEqual(expected, actual)
382         black.assert_equivalent(source, actual)
383         black.assert_stable(source, actual, black.FileMode())
384
385     @patch("black.dump_to_file", dump_to_stderr)
386     def test_comments(self) -> None:
387         source, expected = read_data("comments")
388         actual = fs(source)
389         self.assertFormatEqual(expected, actual)
390         black.assert_equivalent(source, actual)
391         black.assert_stable(source, actual, black.FileMode())
392
393     @patch("black.dump_to_file", dump_to_stderr)
394     def test_comments2(self) -> None:
395         source, expected = read_data("comments2")
396         actual = fs(source)
397         self.assertFormatEqual(expected, actual)
398         black.assert_equivalent(source, actual)
399         black.assert_stable(source, actual, black.FileMode())
400
401     @patch("black.dump_to_file", dump_to_stderr)
402     def test_comments3(self) -> None:
403         source, expected = read_data("comments3")
404         actual = fs(source)
405         self.assertFormatEqual(expected, actual)
406         black.assert_equivalent(source, actual)
407         black.assert_stable(source, actual, black.FileMode())
408
409     @patch("black.dump_to_file", dump_to_stderr)
410     def test_comments4(self) -> None:
411         source, expected = read_data("comments4")
412         actual = fs(source)
413         self.assertFormatEqual(expected, actual)
414         black.assert_equivalent(source, actual)
415         black.assert_stable(source, actual, black.FileMode())
416
417     @patch("black.dump_to_file", dump_to_stderr)
418     def test_comments5(self) -> None:
419         source, expected = read_data("comments5")
420         actual = fs(source)
421         self.assertFormatEqual(expected, actual)
422         black.assert_equivalent(source, actual)
423         black.assert_stable(source, actual, black.FileMode())
424
425     @patch("black.dump_to_file", dump_to_stderr)
426     def test_comments6(self) -> None:
427         source, expected = read_data("comments6")
428         actual = fs(source)
429         self.assertFormatEqual(expected, actual)
430         black.assert_equivalent(source, actual)
431         black.assert_stable(source, actual, black.FileMode())
432
433     @patch("black.dump_to_file", dump_to_stderr)
434     def test_comments7(self) -> None:
435         source, expected = read_data("comments7")
436         actual = fs(source)
437         self.assertFormatEqual(expected, actual)
438         black.assert_equivalent(source, actual)
439         black.assert_stable(source, actual, black.FileMode())
440
441     @patch("black.dump_to_file", dump_to_stderr)
442     def test_comment_after_escaped_newline(self) -> None:
443         source, expected = read_data("comment_after_escaped_newline")
444         actual = fs(source)
445         self.assertFormatEqual(expected, actual)
446         black.assert_equivalent(source, actual)
447         black.assert_stable(source, actual, black.FileMode())
448
449     @patch("black.dump_to_file", dump_to_stderr)
450     def test_cantfit(self) -> None:
451         source, expected = read_data("cantfit")
452         actual = fs(source)
453         self.assertFormatEqual(expected, actual)
454         black.assert_equivalent(source, actual)
455         black.assert_stable(source, actual, black.FileMode())
456
457     @patch("black.dump_to_file", dump_to_stderr)
458     def test_import_spacing(self) -> None:
459         source, expected = read_data("import_spacing")
460         actual = fs(source)
461         self.assertFormatEqual(expected, actual)
462         black.assert_equivalent(source, actual)
463         black.assert_stable(source, actual, black.FileMode())
464
465     @patch("black.dump_to_file", dump_to_stderr)
466     def test_composition(self) -> None:
467         source, expected = read_data("composition")
468         actual = fs(source)
469         self.assertFormatEqual(expected, actual)
470         black.assert_equivalent(source, actual)
471         black.assert_stable(source, actual, black.FileMode())
472
473     @patch("black.dump_to_file", dump_to_stderr)
474     def test_empty_lines(self) -> None:
475         source, expected = read_data("empty_lines")
476         actual = fs(source)
477         self.assertFormatEqual(expected, actual)
478         black.assert_equivalent(source, actual)
479         black.assert_stable(source, actual, black.FileMode())
480
481     @patch("black.dump_to_file", dump_to_stderr)
482     def test_remove_parens(self) -> None:
483         source, expected = read_data("remove_parens")
484         actual = fs(source)
485         self.assertFormatEqual(expected, actual)
486         black.assert_equivalent(source, actual)
487         black.assert_stable(source, actual, black.FileMode())
488
489     @patch("black.dump_to_file", dump_to_stderr)
490     def test_string_prefixes(self) -> None:
491         source, expected = read_data("string_prefixes")
492         actual = fs(source)
493         self.assertFormatEqual(expected, actual)
494         black.assert_equivalent(source, actual)
495         black.assert_stable(source, actual, black.FileMode())
496
497     @patch("black.dump_to_file", dump_to_stderr)
498     def test_numeric_literals(self) -> None:
499         source, expected = read_data("numeric_literals")
500         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
501         actual = fs(source, mode=mode)
502         self.assertFormatEqual(expected, actual)
503         black.assert_equivalent(source, actual)
504         black.assert_stable(source, actual, mode)
505
506     @patch("black.dump_to_file", dump_to_stderr)
507     def test_numeric_literals_ignoring_underscores(self) -> None:
508         source, expected = read_data("numeric_literals_skip_underscores")
509         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
510         actual = fs(source, mode=mode)
511         self.assertFormatEqual(expected, actual)
512         black.assert_equivalent(source, actual)
513         black.assert_stable(source, actual, mode)
514
515     @patch("black.dump_to_file", dump_to_stderr)
516     def test_numeric_literals_py2(self) -> None:
517         source, expected = read_data("numeric_literals_py2")
518         actual = fs(source)
519         self.assertFormatEqual(expected, actual)
520         black.assert_stable(source, actual, black.FileMode())
521
522     @patch("black.dump_to_file", dump_to_stderr)
523     def test_python2(self) -> None:
524         source, expected = read_data("python2")
525         actual = fs(source)
526         self.assertFormatEqual(expected, actual)
527         black.assert_equivalent(source, actual)
528         black.assert_stable(source, actual, black.FileMode())
529
530     @patch("black.dump_to_file", dump_to_stderr)
531     def test_python2_print_function(self) -> None:
532         source, expected = read_data("python2_print_function")
533         mode = black.FileMode(target_versions={TargetVersion.PY27})
534         actual = fs(source, mode=mode)
535         self.assertFormatEqual(expected, actual)
536         black.assert_equivalent(source, actual)
537         black.assert_stable(source, actual, mode)
538
539     @patch("black.dump_to_file", dump_to_stderr)
540     def test_python2_unicode_literals(self) -> None:
541         source, expected = read_data("python2_unicode_literals")
542         actual = fs(source)
543         self.assertFormatEqual(expected, actual)
544         black.assert_equivalent(source, actual)
545         black.assert_stable(source, actual, black.FileMode())
546
547     @patch("black.dump_to_file", dump_to_stderr)
548     def test_stub(self) -> None:
549         mode = black.FileMode(is_pyi=True)
550         source, expected = read_data("stub.pyi")
551         actual = fs(source, mode=mode)
552         self.assertFormatEqual(expected, actual)
553         black.assert_stable(source, actual, mode)
554
555     @patch("black.dump_to_file", dump_to_stderr)
556     def test_async_as_identifier(self) -> None:
557         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
558         source, expected = read_data("async_as_identifier")
559         actual = fs(source)
560         self.assertFormatEqual(expected, actual)
561         major, minor = sys.version_info[:2]
562         if major < 3 or (major <= 3 and minor < 7):
563             black.assert_equivalent(source, actual)
564         black.assert_stable(source, actual, black.FileMode())
565         # ensure black can parse this when the target is 3.6
566         self.invokeBlack([str(source_path), "--target-version", "py36"])
567         # but not on 3.7, because async/await is no longer an identifier
568         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
569
570     @patch("black.dump_to_file", dump_to_stderr)
571     def test_python37(self) -> None:
572         source_path = (THIS_DIR / "data" / "python37.py").resolve()
573         source, expected = read_data("python37")
574         actual = fs(source)
575         self.assertFormatEqual(expected, actual)
576         major, minor = sys.version_info[:2]
577         if major > 3 or (major == 3 and minor >= 7):
578             black.assert_equivalent(source, actual)
579         black.assert_stable(source, actual, black.FileMode())
580         # ensure black can parse this when the target is 3.7
581         self.invokeBlack([str(source_path), "--target-version", "py37"])
582         # but not on 3.6, because we use async as a reserved keyword
583         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
584
585     @patch("black.dump_to_file", dump_to_stderr)
586     def test_fmtonoff(self) -> None:
587         source, expected = read_data("fmtonoff")
588         actual = fs(source)
589         self.assertFormatEqual(expected, actual)
590         black.assert_equivalent(source, actual)
591         black.assert_stable(source, actual, black.FileMode())
592
593     @patch("black.dump_to_file", dump_to_stderr)
594     def test_fmtonoff2(self) -> None:
595         source, expected = read_data("fmtonoff2")
596         actual = fs(source)
597         self.assertFormatEqual(expected, actual)
598         black.assert_equivalent(source, actual)
599         black.assert_stable(source, actual, black.FileMode())
600
601     @patch("black.dump_to_file", dump_to_stderr)
602     def test_remove_empty_parentheses_after_class(self) -> None:
603         source, expected = read_data("class_blank_parentheses")
604         actual = fs(source)
605         self.assertFormatEqual(expected, actual)
606         black.assert_equivalent(source, actual)
607         black.assert_stable(source, actual, black.FileMode())
608
609     @patch("black.dump_to_file", dump_to_stderr)
610     def test_new_line_between_class_and_code(self) -> None:
611         source, expected = read_data("class_methods_new_line")
612         actual = fs(source)
613         self.assertFormatEqual(expected, actual)
614         black.assert_equivalent(source, actual)
615         black.assert_stable(source, actual, black.FileMode())
616
617     @patch("black.dump_to_file", dump_to_stderr)
618     def test_bracket_match(self) -> None:
619         source, expected = read_data("bracketmatch")
620         actual = fs(source)
621         self.assertFormatEqual(expected, actual)
622         black.assert_equivalent(source, actual)
623         black.assert_stable(source, actual, black.FileMode())
624
625     @patch("black.dump_to_file", dump_to_stderr)
626     def test_tuple_assign(self) -> None:
627         source, expected = read_data("tupleassign")
628         actual = fs(source)
629         self.assertFormatEqual(expected, actual)
630         black.assert_equivalent(source, actual)
631         black.assert_stable(source, actual, black.FileMode())
632
633     def test_tab_comment_indentation(self) -> None:
634         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
635         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
636         self.assertFormatEqual(contents_spc, fs(contents_spc))
637         self.assertFormatEqual(contents_spc, fs(contents_tab))
638
639         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
640         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
641         self.assertFormatEqual(contents_spc, fs(contents_spc))
642         self.assertFormatEqual(contents_spc, fs(contents_tab))
643
644         # mixed tabs and spaces (valid Python 2 code)
645         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
646         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
647         self.assertFormatEqual(contents_spc, fs(contents_spc))
648         self.assertFormatEqual(contents_spc, fs(contents_tab))
649
650         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
651         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
652         self.assertFormatEqual(contents_spc, fs(contents_spc))
653         self.assertFormatEqual(contents_spc, fs(contents_tab))
654
655     def test_report_verbose(self) -> None:
656         report = black.Report(verbose=True)
657         out_lines = []
658         err_lines = []
659
660         def out(msg: str, **kwargs: Any) -> None:
661             out_lines.append(msg)
662
663         def err(msg: str, **kwargs: Any) -> None:
664             err_lines.append(msg)
665
666         with patch("black.out", out), patch("black.err", err):
667             report.done(Path("f1"), black.Changed.NO)
668             self.assertEqual(len(out_lines), 1)
669             self.assertEqual(len(err_lines), 0)
670             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
671             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
672             self.assertEqual(report.return_code, 0)
673             report.done(Path("f2"), black.Changed.YES)
674             self.assertEqual(len(out_lines), 2)
675             self.assertEqual(len(err_lines), 0)
676             self.assertEqual(out_lines[-1], "reformatted f2")
677             self.assertEqual(
678                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
679             )
680             report.done(Path("f3"), black.Changed.CACHED)
681             self.assertEqual(len(out_lines), 3)
682             self.assertEqual(len(err_lines), 0)
683             self.assertEqual(
684                 out_lines[-1], "f3 wasn't modified on disk since last run."
685             )
686             self.assertEqual(
687                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
688             )
689             self.assertEqual(report.return_code, 0)
690             report.check = True
691             self.assertEqual(report.return_code, 1)
692             report.check = False
693             report.failed(Path("e1"), "boom")
694             self.assertEqual(len(out_lines), 3)
695             self.assertEqual(len(err_lines), 1)
696             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
697             self.assertEqual(
698                 unstyle(str(report)),
699                 "1 file reformatted, 2 files left unchanged, "
700                 "1 file failed to reformat.",
701             )
702             self.assertEqual(report.return_code, 123)
703             report.done(Path("f3"), black.Changed.YES)
704             self.assertEqual(len(out_lines), 4)
705             self.assertEqual(len(err_lines), 1)
706             self.assertEqual(out_lines[-1], "reformatted f3")
707             self.assertEqual(
708                 unstyle(str(report)),
709                 "2 files reformatted, 2 files left unchanged, "
710                 "1 file failed to reformat.",
711             )
712             self.assertEqual(report.return_code, 123)
713             report.failed(Path("e2"), "boom")
714             self.assertEqual(len(out_lines), 4)
715             self.assertEqual(len(err_lines), 2)
716             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
717             self.assertEqual(
718                 unstyle(str(report)),
719                 "2 files reformatted, 2 files left unchanged, "
720                 "2 files failed to reformat.",
721             )
722             self.assertEqual(report.return_code, 123)
723             report.path_ignored(Path("wat"), "no match")
724             self.assertEqual(len(out_lines), 5)
725             self.assertEqual(len(err_lines), 2)
726             self.assertEqual(out_lines[-1], "wat ignored: no match")
727             self.assertEqual(
728                 unstyle(str(report)),
729                 "2 files reformatted, 2 files left unchanged, "
730                 "2 files failed to reformat.",
731             )
732             self.assertEqual(report.return_code, 123)
733             report.done(Path("f4"), black.Changed.NO)
734             self.assertEqual(len(out_lines), 6)
735             self.assertEqual(len(err_lines), 2)
736             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
737             self.assertEqual(
738                 unstyle(str(report)),
739                 "2 files reformatted, 3 files left unchanged, "
740                 "2 files failed to reformat.",
741             )
742             self.assertEqual(report.return_code, 123)
743             report.check = True
744             self.assertEqual(
745                 unstyle(str(report)),
746                 "2 files would be reformatted, 3 files would be left unchanged, "
747                 "2 files would fail to reformat.",
748             )
749
750     def test_report_quiet(self) -> None:
751         report = black.Report(quiet=True)
752         out_lines = []
753         err_lines = []
754
755         def out(msg: str, **kwargs: Any) -> None:
756             out_lines.append(msg)
757
758         def err(msg: str, **kwargs: Any) -> None:
759             err_lines.append(msg)
760
761         with patch("black.out", out), patch("black.err", err):
762             report.done(Path("f1"), black.Changed.NO)
763             self.assertEqual(len(out_lines), 0)
764             self.assertEqual(len(err_lines), 0)
765             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
766             self.assertEqual(report.return_code, 0)
767             report.done(Path("f2"), black.Changed.YES)
768             self.assertEqual(len(out_lines), 0)
769             self.assertEqual(len(err_lines), 0)
770             self.assertEqual(
771                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
772             )
773             report.done(Path("f3"), black.Changed.CACHED)
774             self.assertEqual(len(out_lines), 0)
775             self.assertEqual(len(err_lines), 0)
776             self.assertEqual(
777                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
778             )
779             self.assertEqual(report.return_code, 0)
780             report.check = True
781             self.assertEqual(report.return_code, 1)
782             report.check = False
783             report.failed(Path("e1"), "boom")
784             self.assertEqual(len(out_lines), 0)
785             self.assertEqual(len(err_lines), 1)
786             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
787             self.assertEqual(
788                 unstyle(str(report)),
789                 "1 file reformatted, 2 files left unchanged, "
790                 "1 file failed to reformat.",
791             )
792             self.assertEqual(report.return_code, 123)
793             report.done(Path("f3"), black.Changed.YES)
794             self.assertEqual(len(out_lines), 0)
795             self.assertEqual(len(err_lines), 1)
796             self.assertEqual(
797                 unstyle(str(report)),
798                 "2 files reformatted, 2 files left unchanged, "
799                 "1 file failed to reformat.",
800             )
801             self.assertEqual(report.return_code, 123)
802             report.failed(Path("e2"), "boom")
803             self.assertEqual(len(out_lines), 0)
804             self.assertEqual(len(err_lines), 2)
805             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
806             self.assertEqual(
807                 unstyle(str(report)),
808                 "2 files reformatted, 2 files left unchanged, "
809                 "2 files failed to reformat.",
810             )
811             self.assertEqual(report.return_code, 123)
812             report.path_ignored(Path("wat"), "no match")
813             self.assertEqual(len(out_lines), 0)
814             self.assertEqual(len(err_lines), 2)
815             self.assertEqual(
816                 unstyle(str(report)),
817                 "2 files reformatted, 2 files left unchanged, "
818                 "2 files failed to reformat.",
819             )
820             self.assertEqual(report.return_code, 123)
821             report.done(Path("f4"), black.Changed.NO)
822             self.assertEqual(len(out_lines), 0)
823             self.assertEqual(len(err_lines), 2)
824             self.assertEqual(
825                 unstyle(str(report)),
826                 "2 files reformatted, 3 files left unchanged, "
827                 "2 files failed to reformat.",
828             )
829             self.assertEqual(report.return_code, 123)
830             report.check = True
831             self.assertEqual(
832                 unstyle(str(report)),
833                 "2 files would be reformatted, 3 files would be left unchanged, "
834                 "2 files would fail to reformat.",
835             )
836
837     def test_report_normal(self) -> None:
838         report = black.Report()
839         out_lines = []
840         err_lines = []
841
842         def out(msg: str, **kwargs: Any) -> None:
843             out_lines.append(msg)
844
845         def err(msg: str, **kwargs: Any) -> None:
846             err_lines.append(msg)
847
848         with patch("black.out", out), patch("black.err", err):
849             report.done(Path("f1"), black.Changed.NO)
850             self.assertEqual(len(out_lines), 0)
851             self.assertEqual(len(err_lines), 0)
852             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
853             self.assertEqual(report.return_code, 0)
854             report.done(Path("f2"), black.Changed.YES)
855             self.assertEqual(len(out_lines), 1)
856             self.assertEqual(len(err_lines), 0)
857             self.assertEqual(out_lines[-1], "reformatted f2")
858             self.assertEqual(
859                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
860             )
861             report.done(Path("f3"), black.Changed.CACHED)
862             self.assertEqual(len(out_lines), 1)
863             self.assertEqual(len(err_lines), 0)
864             self.assertEqual(out_lines[-1], "reformatted f2")
865             self.assertEqual(
866                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
867             )
868             self.assertEqual(report.return_code, 0)
869             report.check = True
870             self.assertEqual(report.return_code, 1)
871             report.check = False
872             report.failed(Path("e1"), "boom")
873             self.assertEqual(len(out_lines), 1)
874             self.assertEqual(len(err_lines), 1)
875             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
876             self.assertEqual(
877                 unstyle(str(report)),
878                 "1 file reformatted, 2 files left unchanged, "
879                 "1 file failed to reformat.",
880             )
881             self.assertEqual(report.return_code, 123)
882             report.done(Path("f3"), black.Changed.YES)
883             self.assertEqual(len(out_lines), 2)
884             self.assertEqual(len(err_lines), 1)
885             self.assertEqual(out_lines[-1], "reformatted f3")
886             self.assertEqual(
887                 unstyle(str(report)),
888                 "2 files reformatted, 2 files left unchanged, "
889                 "1 file failed to reformat.",
890             )
891             self.assertEqual(report.return_code, 123)
892             report.failed(Path("e2"), "boom")
893             self.assertEqual(len(out_lines), 2)
894             self.assertEqual(len(err_lines), 2)
895             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
896             self.assertEqual(
897                 unstyle(str(report)),
898                 "2 files reformatted, 2 files left unchanged, "
899                 "2 files failed to reformat.",
900             )
901             self.assertEqual(report.return_code, 123)
902             report.path_ignored(Path("wat"), "no match")
903             self.assertEqual(len(out_lines), 2)
904             self.assertEqual(len(err_lines), 2)
905             self.assertEqual(
906                 unstyle(str(report)),
907                 "2 files reformatted, 2 files left unchanged, "
908                 "2 files failed to reformat.",
909             )
910             self.assertEqual(report.return_code, 123)
911             report.done(Path("f4"), black.Changed.NO)
912             self.assertEqual(len(out_lines), 2)
913             self.assertEqual(len(err_lines), 2)
914             self.assertEqual(
915                 unstyle(str(report)),
916                 "2 files reformatted, 3 files left unchanged, "
917                 "2 files failed to reformat.",
918             )
919             self.assertEqual(report.return_code, 123)
920             report.check = True
921             self.assertEqual(
922                 unstyle(str(report)),
923                 "2 files would be reformatted, 3 files would be left unchanged, "
924                 "2 files would fail to reformat.",
925             )
926
927     def test_lib2to3_parse(self) -> None:
928         with self.assertRaises(black.InvalidInput):
929             black.lib2to3_parse("invalid syntax")
930
931         straddling = "x + y"
932         black.lib2to3_parse(straddling)
933         black.lib2to3_parse(straddling, {TargetVersion.PY27})
934         black.lib2to3_parse(straddling, {TargetVersion.PY36})
935         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
936
937         py2_only = "print x"
938         black.lib2to3_parse(py2_only)
939         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
940         with self.assertRaises(black.InvalidInput):
941             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
942         with self.assertRaises(black.InvalidInput):
943             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
944
945         py3_only = "exec(x, end=y)"
946         black.lib2to3_parse(py3_only)
947         with self.assertRaises(black.InvalidInput):
948             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
949         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
950         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
951
952     def test_get_features_used(self) -> None:
953         node = black.lib2to3_parse("def f(*, arg): ...\n")
954         self.assertEqual(black.get_features_used(node), set())
955         node = black.lib2to3_parse("def f(*, arg,): ...\n")
956         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
957         node = black.lib2to3_parse("f(*arg,)\n")
958         self.assertEqual(
959             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
960         )
961         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
962         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
963         node = black.lib2to3_parse("123_456\n")
964         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
965         node = black.lib2to3_parse("123456\n")
966         self.assertEqual(black.get_features_used(node), set())
967         source, expected = read_data("function")
968         node = black.lib2to3_parse(source)
969         expected_features = {
970             Feature.TRAILING_COMMA_IN_CALL,
971             Feature.TRAILING_COMMA_IN_DEF,
972             Feature.F_STRINGS,
973         }
974         self.assertEqual(black.get_features_used(node), expected_features)
975         node = black.lib2to3_parse(expected)
976         self.assertEqual(black.get_features_used(node), expected_features)
977         source, expected = read_data("expression")
978         node = black.lib2to3_parse(source)
979         self.assertEqual(black.get_features_used(node), set())
980         node = black.lib2to3_parse(expected)
981         self.assertEqual(black.get_features_used(node), set())
982
983     def test_get_future_imports(self) -> None:
984         node = black.lib2to3_parse("\n")
985         self.assertEqual(set(), black.get_future_imports(node))
986         node = black.lib2to3_parse("from __future__ import black\n")
987         self.assertEqual({"black"}, black.get_future_imports(node))
988         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
989         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
990         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
991         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
992         node = black.lib2to3_parse(
993             "from __future__ import multiple\nfrom __future__ import imports\n"
994         )
995         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
996         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
997         self.assertEqual({"black"}, black.get_future_imports(node))
998         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
999         self.assertEqual({"black"}, black.get_future_imports(node))
1000         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
1001         self.assertEqual(set(), black.get_future_imports(node))
1002         node = black.lib2to3_parse("from some.module import black\n")
1003         self.assertEqual(set(), black.get_future_imports(node))
1004         node = black.lib2to3_parse(
1005             "from __future__ import unicode_literals as _unicode_literals"
1006         )
1007         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
1008         node = black.lib2to3_parse(
1009             "from __future__ import unicode_literals as _lol, print"
1010         )
1011         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
1012
1013     def test_debug_visitor(self) -> None:
1014         source, _ = read_data("debug_visitor.py")
1015         expected, _ = read_data("debug_visitor.out")
1016         out_lines = []
1017         err_lines = []
1018
1019         def out(msg: str, **kwargs: Any) -> None:
1020             out_lines.append(msg)
1021
1022         def err(msg: str, **kwargs: Any) -> None:
1023             err_lines.append(msg)
1024
1025         with patch("black.out", out), patch("black.err", err):
1026             black.DebugVisitor.show(source)
1027         actual = "\n".join(out_lines) + "\n"
1028         log_name = ""
1029         if expected != actual:
1030             log_name = black.dump_to_file(*out_lines)
1031         self.assertEqual(
1032             expected,
1033             actual,
1034             f"AST print out is different. Actual version dumped to {log_name}",
1035         )
1036
1037     def test_format_file_contents(self) -> None:
1038         empty = ""
1039         mode = black.FileMode()
1040         with self.assertRaises(black.NothingChanged):
1041             black.format_file_contents(empty, mode=mode, fast=False)
1042         just_nl = "\n"
1043         with self.assertRaises(black.NothingChanged):
1044             black.format_file_contents(just_nl, mode=mode, fast=False)
1045         same = "l = [1, 2, 3]\n"
1046         with self.assertRaises(black.NothingChanged):
1047             black.format_file_contents(same, mode=mode, fast=False)
1048         different = "l = [1,2,3]"
1049         expected = same
1050         actual = black.format_file_contents(different, mode=mode, fast=False)
1051         self.assertEqual(expected, actual)
1052         invalid = "return if you can"
1053         with self.assertRaises(black.InvalidInput) as e:
1054             black.format_file_contents(invalid, mode=mode, fast=False)
1055         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1056
1057     def test_endmarker(self) -> None:
1058         n = black.lib2to3_parse("\n")
1059         self.assertEqual(n.type, black.syms.file_input)
1060         self.assertEqual(len(n.children), 1)
1061         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1062
1063     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1064     def test_assertFormatEqual(self) -> None:
1065         out_lines = []
1066         err_lines = []
1067
1068         def out(msg: str, **kwargs: Any) -> None:
1069             out_lines.append(msg)
1070
1071         def err(msg: str, **kwargs: Any) -> None:
1072             err_lines.append(msg)
1073
1074         with patch("black.out", out), patch("black.err", err):
1075             with self.assertRaises(AssertionError):
1076                 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
1077
1078         out_str = "".join(out_lines)
1079         self.assertTrue("Expected tree:" in out_str)
1080         self.assertTrue("Actual tree:" in out_str)
1081         self.assertEqual("".join(err_lines), "")
1082
1083     def test_cache_broken_file(self) -> None:
1084         mode = black.FileMode()
1085         with cache_dir() as workspace:
1086             cache_file = black.get_cache_file(mode)
1087             with cache_file.open("w") as fobj:
1088                 fobj.write("this is not a pickle")
1089             self.assertEqual(black.read_cache(mode), {})
1090             src = (workspace / "test.py").resolve()
1091             with src.open("w") as fobj:
1092                 fobj.write("print('hello')")
1093             self.invokeBlack([str(src)])
1094             cache = black.read_cache(mode)
1095             self.assertIn(src, cache)
1096
1097     def test_cache_single_file_already_cached(self) -> None:
1098         mode = black.FileMode()
1099         with cache_dir() as workspace:
1100             src = (workspace / "test.py").resolve()
1101             with src.open("w") as fobj:
1102                 fobj.write("print('hello')")
1103             black.write_cache({}, [src], mode)
1104             self.invokeBlack([str(src)])
1105             with src.open("r") as fobj:
1106                 self.assertEqual(fobj.read(), "print('hello')")
1107
1108     @event_loop(close=False)
1109     def test_cache_multiple_files(self) -> None:
1110         mode = black.FileMode()
1111         with cache_dir() as workspace, patch(
1112             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1113         ):
1114             one = (workspace / "one.py").resolve()
1115             with one.open("w") as fobj:
1116                 fobj.write("print('hello')")
1117             two = (workspace / "two.py").resolve()
1118             with two.open("w") as fobj:
1119                 fobj.write("print('hello')")
1120             black.write_cache({}, [one], mode)
1121             self.invokeBlack([str(workspace)])
1122             with one.open("r") as fobj:
1123                 self.assertEqual(fobj.read(), "print('hello')")
1124             with two.open("r") as fobj:
1125                 self.assertEqual(fobj.read(), 'print("hello")\n')
1126             cache = black.read_cache(mode)
1127             self.assertIn(one, cache)
1128             self.assertIn(two, cache)
1129
1130     def test_no_cache_when_writeback_diff(self) -> None:
1131         mode = black.FileMode()
1132         with cache_dir() as workspace:
1133             src = (workspace / "test.py").resolve()
1134             with src.open("w") as fobj:
1135                 fobj.write("print('hello')")
1136             self.invokeBlack([str(src), "--diff"])
1137             cache_file = black.get_cache_file(mode)
1138             self.assertFalse(cache_file.exists())
1139
1140     def test_no_cache_when_stdin(self) -> None:
1141         mode = black.FileMode()
1142         with cache_dir():
1143             result = CliRunner().invoke(
1144                 black.main, ["-"], input=BytesIO(b"print('hello')")
1145             )
1146             self.assertEqual(result.exit_code, 0)
1147             cache_file = black.get_cache_file(mode)
1148             self.assertFalse(cache_file.exists())
1149
1150     def test_read_cache_no_cachefile(self) -> None:
1151         mode = black.FileMode()
1152         with cache_dir():
1153             self.assertEqual(black.read_cache(mode), {})
1154
1155     def test_write_cache_read_cache(self) -> None:
1156         mode = black.FileMode()
1157         with cache_dir() as workspace:
1158             src = (workspace / "test.py").resolve()
1159             src.touch()
1160             black.write_cache({}, [src], mode)
1161             cache = black.read_cache(mode)
1162             self.assertIn(src, cache)
1163             self.assertEqual(cache[src], black.get_cache_info(src))
1164
1165     def test_filter_cached(self) -> None:
1166         with TemporaryDirectory() as workspace:
1167             path = Path(workspace)
1168             uncached = (path / "uncached").resolve()
1169             cached = (path / "cached").resolve()
1170             cached_but_changed = (path / "changed").resolve()
1171             uncached.touch()
1172             cached.touch()
1173             cached_but_changed.touch()
1174             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
1175             todo, done = black.filter_cached(
1176                 cache, {uncached, cached, cached_but_changed}
1177             )
1178             self.assertEqual(todo, {uncached, cached_but_changed})
1179             self.assertEqual(done, {cached})
1180
1181     def test_write_cache_creates_directory_if_needed(self) -> None:
1182         mode = black.FileMode()
1183         with cache_dir(exists=False) as workspace:
1184             self.assertFalse(workspace.exists())
1185             black.write_cache({}, [], mode)
1186             self.assertTrue(workspace.exists())
1187
1188     @event_loop(close=False)
1189     def test_failed_formatting_does_not_get_cached(self) -> None:
1190         mode = black.FileMode()
1191         with cache_dir() as workspace, patch(
1192             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1193         ):
1194             failing = (workspace / "failing.py").resolve()
1195             with failing.open("w") as fobj:
1196                 fobj.write("not actually python")
1197             clean = (workspace / "clean.py").resolve()
1198             with clean.open("w") as fobj:
1199                 fobj.write('print("hello")\n')
1200             self.invokeBlack([str(workspace)], exit_code=123)
1201             cache = black.read_cache(mode)
1202             self.assertNotIn(failing, cache)
1203             self.assertIn(clean, cache)
1204
1205     def test_write_cache_write_fail(self) -> None:
1206         mode = black.FileMode()
1207         with cache_dir(), patch.object(Path, "open") as mock:
1208             mock.side_effect = OSError
1209             black.write_cache({}, [], mode)
1210
1211     @event_loop(close=False)
1212     def test_check_diff_use_together(self) -> None:
1213         with cache_dir():
1214             # Files which will be reformatted.
1215             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1216             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1217             # Files which will not be reformatted.
1218             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1219             self.invokeBlack([str(src2), "--diff", "--check"])
1220             # Multi file command.
1221             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1222
1223     def test_no_files(self) -> None:
1224         with cache_dir():
1225             # Without an argument, black exits with error code 0.
1226             self.invokeBlack([])
1227
1228     def test_broken_symlink(self) -> None:
1229         with cache_dir() as workspace:
1230             symlink = workspace / "broken_link.py"
1231             try:
1232                 symlink.symlink_to("nonexistent.py")
1233             except OSError as e:
1234                 self.skipTest(f"Can't create symlinks: {e}")
1235             self.invokeBlack([str(workspace.resolve())])
1236
1237     def test_read_cache_line_lengths(self) -> None:
1238         mode = black.FileMode()
1239         short_mode = black.FileMode(line_length=1)
1240         with cache_dir() as workspace:
1241             path = (workspace / "file.py").resolve()
1242             path.touch()
1243             black.write_cache({}, [path], mode)
1244             one = black.read_cache(mode)
1245             self.assertIn(path, one)
1246             two = black.read_cache(short_mode)
1247             self.assertNotIn(path, two)
1248
1249     def test_single_file_force_pyi(self) -> None:
1250         reg_mode = black.FileMode()
1251         pyi_mode = black.FileMode(is_pyi=True)
1252         contents, expected = read_data("force_pyi")
1253         with cache_dir() as workspace:
1254             path = (workspace / "file.py").resolve()
1255             with open(path, "w") as fh:
1256                 fh.write(contents)
1257             self.invokeBlack([str(path), "--pyi"])
1258             with open(path, "r") as fh:
1259                 actual = fh.read()
1260             # verify cache with --pyi is separate
1261             pyi_cache = black.read_cache(pyi_mode)
1262             self.assertIn(path, pyi_cache)
1263             normal_cache = black.read_cache(reg_mode)
1264             self.assertNotIn(path, normal_cache)
1265         self.assertEqual(actual, expected)
1266
1267     @event_loop(close=False)
1268     def test_multi_file_force_pyi(self) -> None:
1269         reg_mode = black.FileMode()
1270         pyi_mode = black.FileMode(is_pyi=True)
1271         contents, expected = read_data("force_pyi")
1272         with cache_dir() as workspace:
1273             paths = [
1274                 (workspace / "file1.py").resolve(),
1275                 (workspace / "file2.py").resolve(),
1276             ]
1277             for path in paths:
1278                 with open(path, "w") as fh:
1279                     fh.write(contents)
1280             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1281             for path in paths:
1282                 with open(path, "r") as fh:
1283                     actual = fh.read()
1284                 self.assertEqual(actual, expected)
1285             # verify cache with --pyi is separate
1286             pyi_cache = black.read_cache(pyi_mode)
1287             normal_cache = black.read_cache(reg_mode)
1288             for path in paths:
1289                 self.assertIn(path, pyi_cache)
1290                 self.assertNotIn(path, normal_cache)
1291
1292     def test_pipe_force_pyi(self) -> None:
1293         source, expected = read_data("force_pyi")
1294         result = CliRunner().invoke(
1295             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1296         )
1297         self.assertEqual(result.exit_code, 0)
1298         actual = result.output
1299         self.assertFormatEqual(actual, expected)
1300
1301     def test_single_file_force_py36(self) -> None:
1302         reg_mode = black.FileMode()
1303         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1304         source, expected = read_data("force_py36")
1305         with cache_dir() as workspace:
1306             path = (workspace / "file.py").resolve()
1307             with open(path, "w") as fh:
1308                 fh.write(source)
1309             self.invokeBlack([str(path), *PY36_ARGS])
1310             with open(path, "r") as fh:
1311                 actual = fh.read()
1312             # verify cache with --target-version is separate
1313             py36_cache = black.read_cache(py36_mode)
1314             self.assertIn(path, py36_cache)
1315             normal_cache = black.read_cache(reg_mode)
1316             self.assertNotIn(path, normal_cache)
1317         self.assertEqual(actual, expected)
1318
1319     @event_loop(close=False)
1320     def test_multi_file_force_py36(self) -> None:
1321         reg_mode = black.FileMode()
1322         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1323         source, expected = read_data("force_py36")
1324         with cache_dir() as workspace:
1325             paths = [
1326                 (workspace / "file1.py").resolve(),
1327                 (workspace / "file2.py").resolve(),
1328             ]
1329             for path in paths:
1330                 with open(path, "w") as fh:
1331                     fh.write(source)
1332             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1333             for path in paths:
1334                 with open(path, "r") as fh:
1335                     actual = fh.read()
1336                 self.assertEqual(actual, expected)
1337             # verify cache with --target-version is separate
1338             pyi_cache = black.read_cache(py36_mode)
1339             normal_cache = black.read_cache(reg_mode)
1340             for path in paths:
1341                 self.assertIn(path, pyi_cache)
1342                 self.assertNotIn(path, normal_cache)
1343
1344     def test_pipe_force_py36(self) -> None:
1345         source, expected = read_data("force_py36")
1346         result = CliRunner().invoke(
1347             black.main,
1348             ["-", "-q", "--target-version=py36"],
1349             input=BytesIO(source.encode("utf8")),
1350         )
1351         self.assertEqual(result.exit_code, 0)
1352         actual = result.output
1353         self.assertFormatEqual(actual, expected)
1354
1355     def test_include_exclude(self) -> None:
1356         path = THIS_DIR / "data" / "include_exclude_tests"
1357         include = re.compile(r"\.pyi?$")
1358         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1359         report = black.Report()
1360         sources: List[Path] = []
1361         expected = [
1362             Path(path / "b/dont_exclude/a.py"),
1363             Path(path / "b/dont_exclude/a.pyi"),
1364         ]
1365         this_abs = THIS_DIR.resolve()
1366         sources.extend(
1367             black.gen_python_files_in_dir(path, this_abs, include, exclude, report)
1368         )
1369         self.assertEqual(sorted(expected), sorted(sources))
1370
1371     def test_empty_include(self) -> None:
1372         path = THIS_DIR / "data" / "include_exclude_tests"
1373         report = black.Report()
1374         empty = re.compile(r"")
1375         sources: List[Path] = []
1376         expected = [
1377             Path(path / "b/exclude/a.pie"),
1378             Path(path / "b/exclude/a.py"),
1379             Path(path / "b/exclude/a.pyi"),
1380             Path(path / "b/dont_exclude/a.pie"),
1381             Path(path / "b/dont_exclude/a.py"),
1382             Path(path / "b/dont_exclude/a.pyi"),
1383             Path(path / "b/.definitely_exclude/a.pie"),
1384             Path(path / "b/.definitely_exclude/a.py"),
1385             Path(path / "b/.definitely_exclude/a.pyi"),
1386         ]
1387         this_abs = THIS_DIR.resolve()
1388         sources.extend(
1389             black.gen_python_files_in_dir(
1390                 path, this_abs, empty, re.compile(black.DEFAULT_EXCLUDES), report
1391             )
1392         )
1393         self.assertEqual(sorted(expected), sorted(sources))
1394
1395     def test_empty_exclude(self) -> None:
1396         path = THIS_DIR / "data" / "include_exclude_tests"
1397         report = black.Report()
1398         empty = re.compile(r"")
1399         sources: List[Path] = []
1400         expected = [
1401             Path(path / "b/dont_exclude/a.py"),
1402             Path(path / "b/dont_exclude/a.pyi"),
1403             Path(path / "b/exclude/a.py"),
1404             Path(path / "b/exclude/a.pyi"),
1405             Path(path / "b/.definitely_exclude/a.py"),
1406             Path(path / "b/.definitely_exclude/a.pyi"),
1407         ]
1408         this_abs = THIS_DIR.resolve()
1409         sources.extend(
1410             black.gen_python_files_in_dir(
1411                 path, this_abs, re.compile(black.DEFAULT_INCLUDES), empty, report
1412             )
1413         )
1414         self.assertEqual(sorted(expected), sorted(sources))
1415
1416     def test_invalid_include_exclude(self) -> None:
1417         for option in ["--include", "--exclude"]:
1418             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1419
1420     def test_preserves_line_endings(self) -> None:
1421         with TemporaryDirectory() as workspace:
1422             test_file = Path(workspace) / "test.py"
1423             for nl in ["\n", "\r\n"]:
1424                 contents = nl.join(["def f(  ):", "    pass"])
1425                 test_file.write_bytes(contents.encode())
1426                 ff(test_file, write_back=black.WriteBack.YES)
1427                 updated_contents: bytes = test_file.read_bytes()
1428                 self.assertIn(nl.encode(), updated_contents)
1429                 if nl == "\n":
1430                     self.assertNotIn(b"\r\n", updated_contents)
1431
1432     def test_preserves_line_endings_via_stdin(self) -> None:
1433         for nl in ["\n", "\r\n"]:
1434             contents = nl.join(["def f(  ):", "    pass"])
1435             runner = BlackRunner()
1436             result = runner.invoke(
1437                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1438             )
1439             self.assertEqual(result.exit_code, 0)
1440             output = runner.stdout_bytes
1441             self.assertIn(nl.encode("utf8"), output)
1442             if nl == "\n":
1443                 self.assertNotIn(b"\r\n", output)
1444
1445     def test_assert_equivalent_different_asts(self) -> None:
1446         with self.assertRaises(AssertionError):
1447             black.assert_equivalent("{}", "None")
1448
1449     def test_symlink_out_of_root_directory(self) -> None:
1450         path = MagicMock()
1451         root = THIS_DIR
1452         child = MagicMock()
1453         include = re.compile(black.DEFAULT_INCLUDES)
1454         exclude = re.compile(black.DEFAULT_EXCLUDES)
1455         report = black.Report()
1456         # `child` should behave like a symlink which resolved path is clearly
1457         # outside of the `root` directory.
1458         path.iterdir.return_value = [child]
1459         child.resolve.return_value = Path("/a/b/c")
1460         child.is_symlink.return_value = True
1461         try:
1462             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1463         except ValueError as ve:
1464             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1465         path.iterdir.assert_called_once()
1466         child.resolve.assert_called_once()
1467         child.is_symlink.assert_called_once()
1468         # `child` should behave like a strange file which resolved path is clearly
1469         # outside of the `root` directory.
1470         child.is_symlink.return_value = False
1471         with self.assertRaises(ValueError):
1472             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1473         path.iterdir.assert_called()
1474         self.assertEqual(path.iterdir.call_count, 2)
1475         child.resolve.assert_called()
1476         self.assertEqual(child.resolve.call_count, 2)
1477         child.is_symlink.assert_called()
1478         self.assertEqual(child.is_symlink.call_count, 2)
1479
1480     def test_shhh_click(self) -> None:
1481         try:
1482             from click import _unicodefun  # type: ignore
1483         except ModuleNotFoundError:
1484             self.skipTest("Incompatible Click version")
1485         if not hasattr(_unicodefun, "_verify_python3_env"):
1486             self.skipTest("Incompatible Click version")
1487         # First, let's see if Click is crashing with a preferred ASCII charset.
1488         with patch("locale.getpreferredencoding") as gpe:
1489             gpe.return_value = "ASCII"
1490             with self.assertRaises(RuntimeError):
1491                 _unicodefun._verify_python3_env()
1492         # Now, let's silence Click...
1493         black.patch_click()
1494         # ...and confirm it's silent.
1495         with patch("locale.getpreferredencoding") as gpe:
1496             gpe.return_value = "ASCII"
1497             try:
1498                 _unicodefun._verify_python3_env()
1499             except RuntimeError as re:
1500                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1501
1502     def test_root_logger_not_used_directly(self) -> None:
1503         def fail(*args: Any, **kwargs: Any) -> None:
1504             self.fail("Record created with root logger")
1505
1506         with patch.multiple(
1507             logging.root,
1508             debug=fail,
1509             info=fail,
1510             warning=fail,
1511             error=fail,
1512             critical=fail,
1513             log=fail,
1514         ):
1515             ff(THIS_FILE)
1516
1517     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1518     @async_test
1519     async def test_blackd_request_needs_formatting(self) -> None:
1520         app = blackd.make_app()
1521         async with TestClient(TestServer(app)) as client:
1522             response = await client.post("/", data=b"print('hello world')")
1523             self.assertEqual(response.status, 200)
1524             self.assertEqual(response.charset, "utf8")
1525             self.assertEqual(await response.read(), b'print("hello world")\n')
1526
1527     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1528     @async_test
1529     async def test_blackd_request_no_change(self) -> None:
1530         app = blackd.make_app()
1531         async with TestClient(TestServer(app)) as client:
1532             response = await client.post("/", data=b'print("hello world")\n')
1533             self.assertEqual(response.status, 204)
1534             self.assertEqual(await response.read(), b"")
1535
1536     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1537     @async_test
1538     async def test_blackd_request_syntax_error(self) -> None:
1539         app = blackd.make_app()
1540         async with TestClient(TestServer(app)) as client:
1541             response = await client.post("/", data=b"what even ( is")
1542             self.assertEqual(response.status, 400)
1543             content = await response.text()
1544             self.assertTrue(
1545                 content.startswith("Cannot parse"),
1546                 msg=f"Expected error to start with 'Cannot parse', got {repr(content)}",
1547             )
1548
1549     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1550     @async_test
1551     async def test_blackd_unsupported_version(self) -> None:
1552         app = blackd.make_app()
1553         async with TestClient(TestServer(app)) as client:
1554             response = await client.post(
1555                 "/", data=b"what", headers={blackd.VERSION_HEADER: "2"}
1556             )
1557             self.assertEqual(response.status, 501)
1558
1559     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1560     @async_test
1561     async def test_blackd_supported_version(self) -> None:
1562         app = blackd.make_app()
1563         async with TestClient(TestServer(app)) as client:
1564             response = await client.post(
1565                 "/", data=b"what", headers={blackd.VERSION_HEADER: "1"}
1566             )
1567             self.assertEqual(response.status, 200)
1568
1569     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1570     @async_test
1571     async def test_blackd_invalid_python_variant(self) -> None:
1572         app = blackd.make_app()
1573         async with TestClient(TestServer(app)) as client:
1574
1575             async def check(header_value: str, expected_status: int = 400) -> None:
1576                 response = await client.post(
1577                     "/",
1578                     data=b"what",
1579                     headers={blackd.PYTHON_VARIANT_HEADER: header_value},
1580                 )
1581                 self.assertEqual(response.status, expected_status)
1582
1583             await check("lol")
1584             await check("ruby3.5")
1585             await check("pyi3.6")
1586             await check("py1.5")
1587             await check("2.8")
1588             await check("py2.8")
1589             await check("3.0")
1590             await check("pypy3.0")
1591             await check("jython3.4")
1592
1593     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1594     @async_test
1595     async def test_blackd_pyi(self) -> None:
1596         app = blackd.make_app()
1597         async with TestClient(TestServer(app)) as client:
1598             source, expected = read_data("stub.pyi")
1599             response = await client.post(
1600                 "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
1601             )
1602             self.assertEqual(response.status, 200)
1603             self.assertEqual(await response.text(), expected)
1604
1605     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1606     @async_test
1607     async def test_blackd_python_variant(self) -> None:
1608         app = blackd.make_app()
1609         code = (
1610             "def f(\n"
1611             "    and_has_a_bunch_of,\n"
1612             "    very_long_arguments_too,\n"
1613             "    and_lots_of_them_as_well_lol,\n"
1614             "    **and_very_long_keyword_arguments\n"
1615             "):\n"
1616             "    pass\n"
1617         )
1618         async with TestClient(TestServer(app)) as client:
1619
1620             async def check(header_value: str, expected_status: int) -> None:
1621                 response = await client.post(
1622                     "/", data=code, headers={blackd.PYTHON_VARIANT_HEADER: header_value}
1623                 )
1624                 self.assertEqual(response.status, expected_status)
1625
1626             await check("3.6", 200)
1627             await check("py3.6", 200)
1628             await check("3.6,3.7", 200)
1629             await check("3.6,py3.7", 200)
1630
1631             await check("2", 204)
1632             await check("2.7", 204)
1633             await check("py2.7", 204)
1634             await check("3.4", 204)
1635             await check("py3.4", 204)
1636
1637     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1638     @async_test
1639     async def test_blackd_line_length(self) -> None:
1640         app = blackd.make_app()
1641         async with TestClient(TestServer(app)) as client:
1642             response = await client.post(
1643                 "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "7"}
1644             )
1645             self.assertEqual(response.status, 200)
1646
1647     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1648     @async_test
1649     async def test_blackd_invalid_line_length(self) -> None:
1650         app = blackd.make_app()
1651         async with TestClient(TestServer(app)) as client:
1652             response = await client.post(
1653                 "/",
1654                 data=b'print("hello")\n',
1655                 headers={blackd.LINE_LENGTH_HEADER: "NaN"},
1656             )
1657             self.assertEqual(response.status, 400)
1658
1659     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1660     def test_blackd_main(self) -> None:
1661         with patch("blackd.web.run_app"):
1662             result = CliRunner().invoke(blackd.main, [])
1663             if result.exception is not None:
1664                 raise result.exception
1665             self.assertEqual(result.exit_code, 0)
1666
1667
if __name__ == "__main__":
    # Load the tests from the module named "test_black" rather than from
    # `__main__`. When this file runs as a script, its own directory is first
    # on sys.path, so that import resolves to this file.
    unittest.main(module="test_black")