git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

7b3a8b66e2b3bd448676989056e39dc9aed83617
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 import logging
4 from concurrent.futures import ThreadPoolExecutor
5 from contextlib import contextmanager
6 from functools import partial, wraps
7 from io import BytesIO, TextIOWrapper
8 import os
9 from pathlib import Path
10 import re
11 import sys
12 from tempfile import TemporaryDirectory
13 from typing import (
14     Any,
15     BinaryIO,
16     Callable,
17     Coroutine,
18     Generator,
19     List,
20     Tuple,
21     Iterator,
22     TypeVar,
23 )
24 import unittest
25 from unittest.mock import patch, MagicMock
26
27 from click import unstyle
28 from click.testing import CliRunner
29
30 import black
31 from black import Feature, TargetVersion
32
33 try:
34     import blackd
35     from aiohttp.test_utils import TestClient, TestServer
36 except ImportError:
37     has_blackd_deps = False
38 else:
39     has_blackd_deps = True
40
# Shared helpers and constants used by the tests below.

# black.format_file_in_place pre-bound to a default FileMode with fast=True.
ff = partial(black.format_file_in_place, mode=black.FileMode(), fast=True)
# black.format_str pre-bound to a default FileMode.
fs = partial(black.format_str, mode=black.FileMode())
# Location of this test module and of the directory that contains it.
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# Marker that read_data() strips from every fixture line.
# NOTE(review): presumably written as two concatenated literals so that this
# very line does not contain the full marker when test_self feeds this file
# through read_data() -- confirm before merging the literals.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
# One "--target-version=<name>" CLI flag per member of black.PY36_VERSIONS.
PY36_ARGS = [
    f"--target-version={version.name.lower()}" for version in black.PY36_VERSIONS
]
# Generic type variables for annotations; R is used by async_test().
T = TypeVar("T")
R = TypeVar("R")
51
52
def dump_to_stderr(*output: str) -> str:
    """Return the given strings joined by newlines, framed by a newline on each side.

    Despite the name, nothing is written anywhere by this function; the
    combined text is only returned to the caller.
    """
    joined = "\n".join(output)
    return f"\n{joined}\n"
55
56
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    `name` gets a ".py" suffix unless it already ends in .py, .pyi, .out or
    .diff.  The file is looked up in THIS_DIR / "data" when `data` is true and
    in THIS_DIR otherwise.  Lines before a "# output" marker line form the
    input, lines after it form the output; marker lines themselves are
    dropped.  Both parts are stripped and terminated with a single newline.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    base_dir = THIS_DIR / "data" if data else THIS_DIR
    with open(base_dir / name, "r", encoding="utf8") as test:
        lines = test.readlines()

    source_lines: List[str] = []
    output_lines: List[str] = []
    current = source_lines
    for raw in lines:
        cleaned = raw.replace(EMPTY_LINE, "")
        if cleaned.rstrip() != "# output":
            current.append(cleaned)
        else:
            # Everything after the marker belongs to the expected output.
            current = output_lines

    if source_lines and not output_lines:
        # If there's no output marker, treat the entire file as already pre-formatted.
        output_lines = list(source_lines)

    source_text = "".join(source_lines).strip() + "\n"
    output_text = "".join(output_lines).strip() + "\n"
    return source_text, output_text
78
79
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch ``black.CACHE_DIR`` to a path inside a fresh temporary directory.

    With `exists` true the temporary directory itself is used; otherwise a
    not-yet-created "new" child of it is used.  The chosen path is yielded.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
88
89
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Install a brand-new asyncio event loop for the duration of the block.

    The loop that was current on entry is put back on exit; the temporary
    loop is closed afterwards only when `close` is true.
    """
    loop_policy = asyncio.get_event_loop_policy()
    previous_loop = loop_policy.get_event_loop()
    temporary_loop = loop_policy.new_event_loop()
    asyncio.set_event_loop(temporary_loop)
    try:
        yield
    finally:
        # Restore first, then optionally dispose of the temporary loop.
        loop_policy.set_event_loop(previous_loop)
        if close:
            temporary_loop.close()
103
104
def async_test(f: Callable[..., Coroutine[Any, None, R]]) -> Callable[..., None]:
    """Turn coroutine function `f` into a plain callable usable as a test.

    The returned callable runs ``f(*args, **kwargs)`` to completion on the
    current event loop, inside ``event_loop(close=True)``.
    """

    @wraps(f)
    def wrapper(*args: Any, **kwargs: Any) -> None:
        coroutine = f(*args, **kwargs)
        asyncio.get_event_loop().run_until_complete(coroutine)

    # Same effect as stacking @event_loop(close=True) on top of @wraps(f).
    return event_loop(close=True)(wrapper)
112
113
@contextmanager
def skip_if_exception(e: str) -> Iterator[None]:
    """Skip the running test when the block raises an exception named `e`.

    `e` is compared against the class name of the raised exception.  On a
    match ``unittest.SkipTest`` is raised so the test is reported as skipped;
    any other exception propagates unchanged.

    Calling ``unittest.skip(...)`` here would only build a decorator and
    discard it, which neither skips the test nor re-raises the error.
    """
    try:
        yield
    except Exception as exc:
        if exc.__class__.__name__ != e:
            # Not the anticipated failure: let the test fail normally.
            raise
        raise unittest.SkipTest(
            f"Encountered expected exception {exc}, skipping"
        ) from exc
121
122
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    This is a hack that can be removed once we depend on Click 7.x"""

    def __init__(self) -> None:
        # Buffers and captured byte strings are set up before the base
        # initialiser runs.
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Run the base isolation with sys.stderr redirected to `stderrbuf`.

        On exit the bytes behind sys.stdout and sys.stderr are stored in
        `stdout_bytes` / `stderr_bytes` and the previous sys.stderr is restored.
        """
        with super().isolation(*args, **kwargs) as stream:
            saved_stderr = sys.stderr
            try:
                sys.stderr = TextIOWrapper(self.stderrbuf, encoding=self.charset)
                yield stream
            finally:
                self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                self.stderr_bytes = sys.stderr.buffer.getvalue()  # type: ignore
                sys.stderr = saved_stderr
146
147
148 class BlackTestCase(unittest.TestCase):
149     maxDiff = None
150
151     def assertFormatEqual(self, expected: str, actual: str) -> None:
152         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
153             bdv: black.DebugVisitor[Any]
154             black.out("Expected tree:", fg="green")
155             try:
156                 exp_node = black.lib2to3_parse(expected)
157                 bdv = black.DebugVisitor()
158                 list(bdv.visit(exp_node))
159             except Exception as ve:
160                 black.err(str(ve))
161             black.out("Actual tree:", fg="red")
162             try:
163                 exp_node = black.lib2to3_parse(actual)
164                 bdv = black.DebugVisitor()
165                 list(bdv.visit(exp_node))
166             except Exception as ve:
167                 black.err(str(ve))
168         self.assertEqual(expected, actual)
169
170     def invokeBlack(
171         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
172     ) -> None:
173         runner = BlackRunner()
174         if ignore_config:
175             args = ["--config", str(THIS_DIR / "empty.toml"), *args]
176         result = runner.invoke(black.main, args)
177         self.assertEqual(result.exit_code, exit_code, msg=runner.stderr_bytes.decode())
178
179     @patch("black.dump_to_file", dump_to_stderr)
180     def test_empty(self) -> None:
181         source = expected = ""
182         actual = fs(source)
183         self.assertFormatEqual(expected, actual)
184         black.assert_equivalent(source, actual)
185         black.assert_stable(source, actual, black.FileMode())
186
187     def test_empty_ff(self) -> None:
188         expected = ""
189         tmp_file = Path(black.dump_to_file())
190         try:
191             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
192             with open(tmp_file, encoding="utf8") as f:
193                 actual = f.read()
194         finally:
195             os.unlink(tmp_file)
196         self.assertFormatEqual(expected, actual)
197
198     @patch("black.dump_to_file", dump_to_stderr)
199     def test_self(self) -> None:
200         source, expected = read_data("test_black", data=False)
201         actual = fs(source)
202         self.assertFormatEqual(expected, actual)
203         black.assert_equivalent(source, actual)
204         black.assert_stable(source, actual, black.FileMode())
205         self.assertFalse(ff(THIS_FILE))
206
207     @patch("black.dump_to_file", dump_to_stderr)
208     def test_black(self) -> None:
209         source, expected = read_data("../black", data=False)
210         actual = fs(source)
211         self.assertFormatEqual(expected, actual)
212         black.assert_equivalent(source, actual)
213         black.assert_stable(source, actual, black.FileMode())
214         self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
215
216     def test_piping(self) -> None:
217         source, expected = read_data("../black", data=False)
218         result = BlackRunner().invoke(
219             black.main,
220             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
221             input=BytesIO(source.encode("utf8")),
222         )
223         self.assertEqual(result.exit_code, 0)
224         self.assertFormatEqual(expected, result.output)
225         black.assert_equivalent(source, result.output)
226         black.assert_stable(source, result.output, black.FileMode())
227
228     def test_piping_diff(self) -> None:
229         diff_header = re.compile(
230             rf"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d "
231             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
232         )
233         source, _ = read_data("expression.py")
234         expected, _ = read_data("expression.diff")
235         config = THIS_DIR / "data" / "empty_pyproject.toml"
236         args = [
237             "-",
238             "--fast",
239             f"--line-length={black.DEFAULT_LINE_LENGTH}",
240             "--diff",
241             f"--config={config}",
242         ]
243         result = BlackRunner().invoke(
244             black.main, args, input=BytesIO(source.encode("utf8"))
245         )
246         self.assertEqual(result.exit_code, 0)
247         actual = diff_header.sub("[Deterministic header]", result.output)
248         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
249         self.assertEqual(expected, actual)
250
251     @patch("black.dump_to_file", dump_to_stderr)
252     def test_setup(self) -> None:
253         source, expected = read_data("../setup", data=False)
254         actual = fs(source)
255         self.assertFormatEqual(expected, actual)
256         black.assert_equivalent(source, actual)
257         black.assert_stable(source, actual, black.FileMode())
258         self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
259
260     @patch("black.dump_to_file", dump_to_stderr)
261     def test_function(self) -> None:
262         source, expected = read_data("function")
263         actual = fs(source)
264         self.assertFormatEqual(expected, actual)
265         black.assert_equivalent(source, actual)
266         black.assert_stable(source, actual, black.FileMode())
267
268     @patch("black.dump_to_file", dump_to_stderr)
269     def test_function2(self) -> None:
270         source, expected = read_data("function2")
271         actual = fs(source)
272         self.assertFormatEqual(expected, actual)
273         black.assert_equivalent(source, actual)
274         black.assert_stable(source, actual, black.FileMode())
275
276     @patch("black.dump_to_file", dump_to_stderr)
277     def test_function_trailing_comma(self) -> None:
278         source, expected = read_data("function_trailing_comma")
279         actual = fs(source)
280         self.assertFormatEqual(expected, actual)
281         black.assert_equivalent(source, actual)
282         black.assert_stable(source, actual, black.FileMode())
283
284     @patch("black.dump_to_file", dump_to_stderr)
285     def test_expression(self) -> None:
286         source, expected = read_data("expression")
287         actual = fs(source)
288         self.assertFormatEqual(expected, actual)
289         black.assert_equivalent(source, actual)
290         black.assert_stable(source, actual, black.FileMode())
291
292     @patch("black.dump_to_file", dump_to_stderr)
293     def test_pep_572(self) -> None:
294         source, expected = read_data("pep_572")
295         actual = fs(source)
296         self.assertFormatEqual(expected, actual)
297         black.assert_stable(source, actual, black.FileMode())
298         if sys.version_info >= (3, 8):
299             black.assert_equivalent(source, actual)
300
301     def test_pep_572_version_detection(self) -> None:
302         source, _ = read_data("pep_572")
303         root = black.lib2to3_parse(source)
304         features = black.get_features_used(root)
305         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
306         versions = black.detect_target_versions(root)
307         self.assertIn(black.TargetVersion.PY38, versions)
308
309     def test_expression_ff(self) -> None:
310         source, expected = read_data("expression")
311         tmp_file = Path(black.dump_to_file(source))
312         try:
313             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
314             with open(tmp_file, encoding="utf8") as f:
315                 actual = f.read()
316         finally:
317             os.unlink(tmp_file)
318         self.assertFormatEqual(expected, actual)
319         with patch("black.dump_to_file", dump_to_stderr):
320             black.assert_equivalent(source, actual)
321             black.assert_stable(source, actual, black.FileMode())
322
323     def test_expression_diff(self) -> None:
324         source, _ = read_data("expression.py")
325         expected, _ = read_data("expression.diff")
326         tmp_file = Path(black.dump_to_file(source))
327         diff_header = re.compile(
328             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
329             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
330         )
331         try:
332             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
333             self.assertEqual(result.exit_code, 0)
334         finally:
335             os.unlink(tmp_file)
336         actual = result.output
337         actual = diff_header.sub("[Deterministic header]", actual)
338         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
339         if expected != actual:
340             dump = black.dump_to_file(actual)
341             msg = (
342                 f"Expected diff isn't equal to the actual. If you made changes "
343                 f"to expression.py and this is an anticipated difference, "
344                 f"overwrite tests/data/expression.diff with {dump}"
345             )
346             self.assertEqual(expected, actual, msg)
347
348     @patch("black.dump_to_file", dump_to_stderr)
349     def test_fstring(self) -> None:
350         source, expected = read_data("fstring")
351         actual = fs(source)
352         self.assertFormatEqual(expected, actual)
353         black.assert_equivalent(source, actual)
354         black.assert_stable(source, actual, black.FileMode())
355
356     @patch("black.dump_to_file", dump_to_stderr)
357     def test_pep_570(self) -> None:
358         source, expected = read_data("pep_570")
359         actual = fs(source)
360         self.assertFormatEqual(expected, actual)
361         black.assert_stable(source, actual, black.FileMode())
362         if sys.version_info >= (3, 8):
363             black.assert_equivalent(source, actual)
364
365     def test_detect_pos_only_arguments(self) -> None:
366         source, _ = read_data("pep_570")
367         root = black.lib2to3_parse(source)
368         features = black.get_features_used(root)
369         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
370         versions = black.detect_target_versions(root)
371         self.assertIn(black.TargetVersion.PY38, versions)
372
373     @patch("black.dump_to_file", dump_to_stderr)
374     def test_string_quotes(self) -> None:
375         source, expected = read_data("string_quotes")
376         actual = fs(source)
377         self.assertFormatEqual(expected, actual)
378         black.assert_equivalent(source, actual)
379         black.assert_stable(source, actual, black.FileMode())
380         mode = black.FileMode(string_normalization=False)
381         not_normalized = fs(source, mode=mode)
382         self.assertFormatEqual(source, not_normalized)
383         black.assert_equivalent(source, not_normalized)
384         black.assert_stable(source, not_normalized, mode=mode)
385
386     @patch("black.dump_to_file", dump_to_stderr)
387     def test_slices(self) -> None:
388         source, expected = read_data("slices")
389         actual = fs(source)
390         self.assertFormatEqual(expected, actual)
391         black.assert_equivalent(source, actual)
392         black.assert_stable(source, actual, black.FileMode())
393
394     @patch("black.dump_to_file", dump_to_stderr)
395     def test_comments(self) -> None:
396         source, expected = read_data("comments")
397         actual = fs(source)
398         self.assertFormatEqual(expected, actual)
399         black.assert_equivalent(source, actual)
400         black.assert_stable(source, actual, black.FileMode())
401
402     @patch("black.dump_to_file", dump_to_stderr)
403     def test_comments2(self) -> None:
404         source, expected = read_data("comments2")
405         actual = fs(source)
406         self.assertFormatEqual(expected, actual)
407         black.assert_equivalent(source, actual)
408         black.assert_stable(source, actual, black.FileMode())
409
410     @patch("black.dump_to_file", dump_to_stderr)
411     def test_comments3(self) -> None:
412         source, expected = read_data("comments3")
413         actual = fs(source)
414         self.assertFormatEqual(expected, actual)
415         black.assert_equivalent(source, actual)
416         black.assert_stable(source, actual, black.FileMode())
417
418     @patch("black.dump_to_file", dump_to_stderr)
419     def test_comments4(self) -> None:
420         source, expected = read_data("comments4")
421         actual = fs(source)
422         self.assertFormatEqual(expected, actual)
423         black.assert_equivalent(source, actual)
424         black.assert_stable(source, actual, black.FileMode())
425
426     @patch("black.dump_to_file", dump_to_stderr)
427     def test_comments5(self) -> None:
428         source, expected = read_data("comments5")
429         actual = fs(source)
430         self.assertFormatEqual(expected, actual)
431         black.assert_equivalent(source, actual)
432         black.assert_stable(source, actual, black.FileMode())
433
434     @patch("black.dump_to_file", dump_to_stderr)
435     def test_comments6(self) -> None:
436         source, expected = read_data("comments6")
437         actual = fs(source)
438         self.assertFormatEqual(expected, actual)
439         black.assert_equivalent(source, actual)
440         black.assert_stable(source, actual, black.FileMode())
441
442     @patch("black.dump_to_file", dump_to_stderr)
443     def test_comments7(self) -> None:
444         source, expected = read_data("comments7")
445         actual = fs(source)
446         self.assertFormatEqual(expected, actual)
447         black.assert_equivalent(source, actual)
448         black.assert_stable(source, actual, black.FileMode())
449
450     @patch("black.dump_to_file", dump_to_stderr)
451     def test_comment_after_escaped_newline(self) -> None:
452         source, expected = read_data("comment_after_escaped_newline")
453         actual = fs(source)
454         self.assertFormatEqual(expected, actual)
455         black.assert_equivalent(source, actual)
456         black.assert_stable(source, actual, black.FileMode())
457
458     @patch("black.dump_to_file", dump_to_stderr)
459     def test_cantfit(self) -> None:
460         source, expected = read_data("cantfit")
461         actual = fs(source)
462         self.assertFormatEqual(expected, actual)
463         black.assert_equivalent(source, actual)
464         black.assert_stable(source, actual, black.FileMode())
465
466     @patch("black.dump_to_file", dump_to_stderr)
467     def test_import_spacing(self) -> None:
468         source, expected = read_data("import_spacing")
469         actual = fs(source)
470         self.assertFormatEqual(expected, actual)
471         black.assert_equivalent(source, actual)
472         black.assert_stable(source, actual, black.FileMode())
473
474     @patch("black.dump_to_file", dump_to_stderr)
475     def test_composition(self) -> None:
476         source, expected = read_data("composition")
477         actual = fs(source)
478         self.assertFormatEqual(expected, actual)
479         black.assert_equivalent(source, actual)
480         black.assert_stable(source, actual, black.FileMode())
481
482     @patch("black.dump_to_file", dump_to_stderr)
483     def test_empty_lines(self) -> None:
484         source, expected = read_data("empty_lines")
485         actual = fs(source)
486         self.assertFormatEqual(expected, actual)
487         black.assert_equivalent(source, actual)
488         black.assert_stable(source, actual, black.FileMode())
489
490     @patch("black.dump_to_file", dump_to_stderr)
491     def test_remove_parens(self) -> None:
492         source, expected = read_data("remove_parens")
493         actual = fs(source)
494         self.assertFormatEqual(expected, actual)
495         black.assert_equivalent(source, actual)
496         black.assert_stable(source, actual, black.FileMode())
497
498     @patch("black.dump_to_file", dump_to_stderr)
499     def test_string_prefixes(self) -> None:
500         source, expected = read_data("string_prefixes")
501         actual = fs(source)
502         self.assertFormatEqual(expected, actual)
503         black.assert_equivalent(source, actual)
504         black.assert_stable(source, actual, black.FileMode())
505
506     @patch("black.dump_to_file", dump_to_stderr)
507     def test_numeric_literals(self) -> None:
508         source, expected = read_data("numeric_literals")
509         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
510         actual = fs(source, mode=mode)
511         self.assertFormatEqual(expected, actual)
512         black.assert_equivalent(source, actual)
513         black.assert_stable(source, actual, mode)
514
515     @patch("black.dump_to_file", dump_to_stderr)
516     def test_numeric_literals_ignoring_underscores(self) -> None:
517         source, expected = read_data("numeric_literals_skip_underscores")
518         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
519         actual = fs(source, mode=mode)
520         self.assertFormatEqual(expected, actual)
521         black.assert_equivalent(source, actual)
522         black.assert_stable(source, actual, mode)
523
524     @patch("black.dump_to_file", dump_to_stderr)
525     def test_numeric_literals_py2(self) -> None:
526         source, expected = read_data("numeric_literals_py2")
527         actual = fs(source)
528         self.assertFormatEqual(expected, actual)
529         black.assert_stable(source, actual, black.FileMode())
530
531     @patch("black.dump_to_file", dump_to_stderr)
532     def test_python2(self) -> None:
533         source, expected = read_data("python2")
534         actual = fs(source)
535         self.assertFormatEqual(expected, actual)
536         black.assert_equivalent(source, actual)
537         black.assert_stable(source, actual, black.FileMode())
538
539     @patch("black.dump_to_file", dump_to_stderr)
540     def test_python2_print_function(self) -> None:
541         source, expected = read_data("python2_print_function")
542         mode = black.FileMode(target_versions={TargetVersion.PY27})
543         actual = fs(source, mode=mode)
544         self.assertFormatEqual(expected, actual)
545         black.assert_equivalent(source, actual)
546         black.assert_stable(source, actual, mode)
547
548     @patch("black.dump_to_file", dump_to_stderr)
549     def test_python2_unicode_literals(self) -> None:
550         source, expected = read_data("python2_unicode_literals")
551         actual = fs(source)
552         self.assertFormatEqual(expected, actual)
553         black.assert_equivalent(source, actual)
554         black.assert_stable(source, actual, black.FileMode())
555
556     @patch("black.dump_to_file", dump_to_stderr)
557     def test_stub(self) -> None:
558         mode = black.FileMode(is_pyi=True)
559         source, expected = read_data("stub.pyi")
560         actual = fs(source, mode=mode)
561         self.assertFormatEqual(expected, actual)
562         black.assert_stable(source, actual, mode)
563
564     @patch("black.dump_to_file", dump_to_stderr)
565     def test_async_as_identifier(self) -> None:
566         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
567         source, expected = read_data("async_as_identifier")
568         actual = fs(source)
569         self.assertFormatEqual(expected, actual)
570         major, minor = sys.version_info[:2]
571         if major < 3 or (major <= 3 and minor < 7):
572             black.assert_equivalent(source, actual)
573         black.assert_stable(source, actual, black.FileMode())
574         # ensure black can parse this when the target is 3.6
575         self.invokeBlack([str(source_path), "--target-version", "py36"])
576         # but not on 3.7, because async/await is no longer an identifier
577         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
578
579     @patch("black.dump_to_file", dump_to_stderr)
580     def test_python37(self) -> None:
581         source_path = (THIS_DIR / "data" / "python37.py").resolve()
582         source, expected = read_data("python37")
583         actual = fs(source)
584         self.assertFormatEqual(expected, actual)
585         major, minor = sys.version_info[:2]
586         if major > 3 or (major == 3 and minor >= 7):
587             black.assert_equivalent(source, actual)
588         black.assert_stable(source, actual, black.FileMode())
589         # ensure black can parse this when the target is 3.7
590         self.invokeBlack([str(source_path), "--target-version", "py37"])
591         # but not on 3.6, because we use async as a reserved keyword
592         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
593
594     @patch("black.dump_to_file", dump_to_stderr)
595     def test_fmtonoff(self) -> None:
596         source, expected = read_data("fmtonoff")
597         actual = fs(source)
598         self.assertFormatEqual(expected, actual)
599         black.assert_equivalent(source, actual)
600         black.assert_stable(source, actual, black.FileMode())
601
602     @patch("black.dump_to_file", dump_to_stderr)
603     def test_fmtonoff2(self) -> None:
604         source, expected = read_data("fmtonoff2")
605         actual = fs(source)
606         self.assertFormatEqual(expected, actual)
607         black.assert_equivalent(source, actual)
608         black.assert_stable(source, actual, black.FileMode())
609
610     @patch("black.dump_to_file", dump_to_stderr)
611     def test_remove_empty_parentheses_after_class(self) -> None:
612         source, expected = read_data("class_blank_parentheses")
613         actual = fs(source)
614         self.assertFormatEqual(expected, actual)
615         black.assert_equivalent(source, actual)
616         black.assert_stable(source, actual, black.FileMode())
617
618     @patch("black.dump_to_file", dump_to_stderr)
619     def test_new_line_between_class_and_code(self) -> None:
620         source, expected = read_data("class_methods_new_line")
621         actual = fs(source)
622         self.assertFormatEqual(expected, actual)
623         black.assert_equivalent(source, actual)
624         black.assert_stable(source, actual, black.FileMode())
625
626     @patch("black.dump_to_file", dump_to_stderr)
627     def test_bracket_match(self) -> None:
628         source, expected = read_data("bracketmatch")
629         actual = fs(source)
630         self.assertFormatEqual(expected, actual)
631         black.assert_equivalent(source, actual)
632         black.assert_stable(source, actual, black.FileMode())
633
634     @patch("black.dump_to_file", dump_to_stderr)
635     def test_tuple_assign(self) -> None:
636         source, expected = read_data("tupleassign")
637         actual = fs(source)
638         self.assertFormatEqual(expected, actual)
639         black.assert_equivalent(source, actual)
640         black.assert_stable(source, actual, black.FileMode())
641
642     def test_tab_comment_indentation(self) -> None:
643         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
644         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
645         self.assertFormatEqual(contents_spc, fs(contents_spc))
646         self.assertFormatEqual(contents_spc, fs(contents_tab))
647
648         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
649         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
650         self.assertFormatEqual(contents_spc, fs(contents_spc))
651         self.assertFormatEqual(contents_spc, fs(contents_tab))
652
653         # mixed tabs and spaces (valid Python 2 code)
654         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
655         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
656         self.assertFormatEqual(contents_spc, fs(contents_spc))
657         self.assertFormatEqual(contents_spc, fs(contents_tab))
658
659         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
660         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
661         self.assertFormatEqual(contents_spc, fs(contents_spc))
662         self.assertFormatEqual(contents_spc, fs(contents_tab))
663
664     def test_report_verbose(self) -> None:
665         report = black.Report(verbose=True)
666         out_lines = []
667         err_lines = []
668
669         def out(msg: str, **kwargs: Any) -> None:
670             out_lines.append(msg)
671
672         def err(msg: str, **kwargs: Any) -> None:
673             err_lines.append(msg)
674
675         with patch("black.out", out), patch("black.err", err):
676             report.done(Path("f1"), black.Changed.NO)
677             self.assertEqual(len(out_lines), 1)
678             self.assertEqual(len(err_lines), 0)
679             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
680             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
681             self.assertEqual(report.return_code, 0)
682             report.done(Path("f2"), black.Changed.YES)
683             self.assertEqual(len(out_lines), 2)
684             self.assertEqual(len(err_lines), 0)
685             self.assertEqual(out_lines[-1], "reformatted f2")
686             self.assertEqual(
687                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
688             )
689             report.done(Path("f3"), black.Changed.CACHED)
690             self.assertEqual(len(out_lines), 3)
691             self.assertEqual(len(err_lines), 0)
692             self.assertEqual(
693                 out_lines[-1], "f3 wasn't modified on disk since last run."
694             )
695             self.assertEqual(
696                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
697             )
698             self.assertEqual(report.return_code, 0)
699             report.check = True
700             self.assertEqual(report.return_code, 1)
701             report.check = False
702             report.failed(Path("e1"), "boom")
703             self.assertEqual(len(out_lines), 3)
704             self.assertEqual(len(err_lines), 1)
705             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
706             self.assertEqual(
707                 unstyle(str(report)),
708                 "1 file reformatted, 2 files left unchanged, "
709                 "1 file failed to reformat.",
710             )
711             self.assertEqual(report.return_code, 123)
712             report.done(Path("f3"), black.Changed.YES)
713             self.assertEqual(len(out_lines), 4)
714             self.assertEqual(len(err_lines), 1)
715             self.assertEqual(out_lines[-1], "reformatted f3")
716             self.assertEqual(
717                 unstyle(str(report)),
718                 "2 files reformatted, 2 files left unchanged, "
719                 "1 file failed to reformat.",
720             )
721             self.assertEqual(report.return_code, 123)
722             report.failed(Path("e2"), "boom")
723             self.assertEqual(len(out_lines), 4)
724             self.assertEqual(len(err_lines), 2)
725             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
726             self.assertEqual(
727                 unstyle(str(report)),
728                 "2 files reformatted, 2 files left unchanged, "
729                 "2 files failed to reformat.",
730             )
731             self.assertEqual(report.return_code, 123)
732             report.path_ignored(Path("wat"), "no match")
733             self.assertEqual(len(out_lines), 5)
734             self.assertEqual(len(err_lines), 2)
735             self.assertEqual(out_lines[-1], "wat ignored: no match")
736             self.assertEqual(
737                 unstyle(str(report)),
738                 "2 files reformatted, 2 files left unchanged, "
739                 "2 files failed to reformat.",
740             )
741             self.assertEqual(report.return_code, 123)
742             report.done(Path("f4"), black.Changed.NO)
743             self.assertEqual(len(out_lines), 6)
744             self.assertEqual(len(err_lines), 2)
745             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
746             self.assertEqual(
747                 unstyle(str(report)),
748                 "2 files reformatted, 3 files left unchanged, "
749                 "2 files failed to reformat.",
750             )
751             self.assertEqual(report.return_code, 123)
752             report.check = True
753             self.assertEqual(
754                 unstyle(str(report)),
755                 "2 files would be reformatted, 3 files would be left unchanged, "
756                 "2 files would fail to reformat.",
757             )
758
    def test_report_quiet(self) -> None:
        # Drive a quiet Report through a fixed sequence of done() / failed() /
        # path_ignored() calls and, after every step, check the captured
        # messages, the summary string and the return code. In this mode
        # nothing is expected on `out`; only failures show up, through `err`.
        report = black.Report(quiet=True)
        out_lines = []
        err_lines = []

        # Stand-ins for black.out / black.err that record each message and
        # discard any styling keyword arguments.
        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            # Unchanged file: counted in the summary, nothing printed.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # Reformatted file: counted, still nothing printed.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # Cached file: tallied together with the unchanged ones.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled, the one reformatted file yields return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure is reported through err and turns the return code
            # into 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Reporting f3 again, now as changed, bumps the reformatted count;
            # the return code stays at 123.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Second failure.
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # An ignored path produces no output and leaves the summary as is.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # One more unchanged file.
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # In check mode the summary is phrased conditionally ("would ...").
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )
845
    def test_report_normal(self) -> None:
        # Drive a default Report through a fixed sequence of done() / failed() /
        # path_ignored() calls and, after every step, check the captured
        # messages, the summary string and the return code. In this mode only
        # "reformatted ..." lines reach `out`; failures go to `err`.
        report = black.Report()
        out_lines = []
        err_lines = []

        # Stand-ins for black.out / black.err that record each message and
        # discard any styling keyword arguments.
        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            # Unchanged file: counted in the summary, nothing printed.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # Reformatted file: announced on out.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # Cached file: tallied with the unchanged ones, no new output
            # (the last out line is still the one for f2).
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled, the one reformatted file yields return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # A failure is reported through err and turns the return code
            # into 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Reporting f3 again, now as changed, prints a line and bumps the
            # reformatted count; the return code stays at 123.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Second failure.
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # An ignored path produces no output and leaves the summary as is.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # One more unchanged file: counted, not printed.
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # In check mode the summary is phrased conditionally ("would ...").
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )
935
936     def test_lib2to3_parse(self) -> None:
937         with self.assertRaises(black.InvalidInput):
938             black.lib2to3_parse("invalid syntax")
939
940         straddling = "x + y"
941         black.lib2to3_parse(straddling)
942         black.lib2to3_parse(straddling, {TargetVersion.PY27})
943         black.lib2to3_parse(straddling, {TargetVersion.PY36})
944         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
945
946         py2_only = "print x"
947         black.lib2to3_parse(py2_only)
948         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
949         with self.assertRaises(black.InvalidInput):
950             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
951         with self.assertRaises(black.InvalidInput):
952             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
953
954         py3_only = "exec(x, end=y)"
955         black.lib2to3_parse(py3_only)
956         with self.assertRaises(black.InvalidInput):
957             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
958         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
959         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
960
961     def test_get_features_used(self) -> None:
962         node = black.lib2to3_parse("def f(*, arg): ...\n")
963         self.assertEqual(black.get_features_used(node), set())
964         node = black.lib2to3_parse("def f(*, arg,): ...\n")
965         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
966         node = black.lib2to3_parse("f(*arg,)\n")
967         self.assertEqual(
968             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
969         )
970         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
971         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
972         node = black.lib2to3_parse("123_456\n")
973         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
974         node = black.lib2to3_parse("123456\n")
975         self.assertEqual(black.get_features_used(node), set())
976         source, expected = read_data("function")
977         node = black.lib2to3_parse(source)
978         expected_features = {
979             Feature.TRAILING_COMMA_IN_CALL,
980             Feature.TRAILING_COMMA_IN_DEF,
981             Feature.F_STRINGS,
982         }
983         self.assertEqual(black.get_features_used(node), expected_features)
984         node = black.lib2to3_parse(expected)
985         self.assertEqual(black.get_features_used(node), expected_features)
986         source, expected = read_data("expression")
987         node = black.lib2to3_parse(source)
988         self.assertEqual(black.get_features_used(node), set())
989         node = black.lib2to3_parse(expected)
990         self.assertEqual(black.get_features_used(node), set())
991
992     def test_get_future_imports(self) -> None:
993         node = black.lib2to3_parse("\n")
994         self.assertEqual(set(), black.get_future_imports(node))
995         node = black.lib2to3_parse("from __future__ import black\n")
996         self.assertEqual({"black"}, black.get_future_imports(node))
997         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
998         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
999         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
1000         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
1001         node = black.lib2to3_parse(
1002             "from __future__ import multiple\nfrom __future__ import imports\n"
1003         )
1004         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
1005         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
1006         self.assertEqual({"black"}, black.get_future_imports(node))
1007         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
1008         self.assertEqual({"black"}, black.get_future_imports(node))
1009         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
1010         self.assertEqual(set(), black.get_future_imports(node))
1011         node = black.lib2to3_parse("from some.module import black\n")
1012         self.assertEqual(set(), black.get_future_imports(node))
1013         node = black.lib2to3_parse(
1014             "from __future__ import unicode_literals as _unicode_literals"
1015         )
1016         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
1017         node = black.lib2to3_parse(
1018             "from __future__ import unicode_literals as _lol, print"
1019         )
1020         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
1021
1022     def test_debug_visitor(self) -> None:
1023         source, _ = read_data("debug_visitor.py")
1024         expected, _ = read_data("debug_visitor.out")
1025         out_lines = []
1026         err_lines = []
1027
1028         def out(msg: str, **kwargs: Any) -> None:
1029             out_lines.append(msg)
1030
1031         def err(msg: str, **kwargs: Any) -> None:
1032             err_lines.append(msg)
1033
1034         with patch("black.out", out), patch("black.err", err):
1035             black.DebugVisitor.show(source)
1036         actual = "\n".join(out_lines) + "\n"
1037         log_name = ""
1038         if expected != actual:
1039             log_name = black.dump_to_file(*out_lines)
1040         self.assertEqual(
1041             expected,
1042             actual,
1043             f"AST print out is different. Actual version dumped to {log_name}",
1044         )
1045
1046     def test_format_file_contents(self) -> None:
1047         empty = ""
1048         mode = black.FileMode()
1049         with self.assertRaises(black.NothingChanged):
1050             black.format_file_contents(empty, mode=mode, fast=False)
1051         just_nl = "\n"
1052         with self.assertRaises(black.NothingChanged):
1053             black.format_file_contents(just_nl, mode=mode, fast=False)
1054         same = "l = [1, 2, 3]\n"
1055         with self.assertRaises(black.NothingChanged):
1056             black.format_file_contents(same, mode=mode, fast=False)
1057         different = "l = [1,2,3]"
1058         expected = same
1059         actual = black.format_file_contents(different, mode=mode, fast=False)
1060         self.assertEqual(expected, actual)
1061         invalid = "return if you can"
1062         with self.assertRaises(black.InvalidInput) as e:
1063             black.format_file_contents(invalid, mode=mode, fast=False)
1064         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1065
1066     def test_endmarker(self) -> None:
1067         n = black.lib2to3_parse("\n")
1068         self.assertEqual(n.type, black.syms.file_input)
1069         self.assertEqual(len(n.children), 1)
1070         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1071
1072     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1073     def test_assertFormatEqual(self) -> None:
1074         out_lines = []
1075         err_lines = []
1076
1077         def out(msg: str, **kwargs: Any) -> None:
1078             out_lines.append(msg)
1079
1080         def err(msg: str, **kwargs: Any) -> None:
1081             err_lines.append(msg)
1082
1083         with patch("black.out", out), patch("black.err", err):
1084             with self.assertRaises(AssertionError):
1085                 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
1086
1087         out_str = "".join(out_lines)
1088         self.assertTrue("Expected tree:" in out_str)
1089         self.assertTrue("Actual tree:" in out_str)
1090         self.assertEqual("".join(err_lines), "")
1091
1092     def test_cache_broken_file(self) -> None:
1093         mode = black.FileMode()
1094         with cache_dir() as workspace:
1095             cache_file = black.get_cache_file(mode)
1096             with cache_file.open("w") as fobj:
1097                 fobj.write("this is not a pickle")
1098             self.assertEqual(black.read_cache(mode), {})
1099             src = (workspace / "test.py").resolve()
1100             with src.open("w") as fobj:
1101                 fobj.write("print('hello')")
1102             self.invokeBlack([str(src)])
1103             cache = black.read_cache(mode)
1104             self.assertIn(src, cache)
1105
1106     def test_cache_single_file_already_cached(self) -> None:
1107         mode = black.FileMode()
1108         with cache_dir() as workspace:
1109             src = (workspace / "test.py").resolve()
1110             with src.open("w") as fobj:
1111                 fobj.write("print('hello')")
1112             black.write_cache({}, [src], mode)
1113             self.invokeBlack([str(src)])
1114             with src.open("r") as fobj:
1115                 self.assertEqual(fobj.read(), "print('hello')")
1116
1117     @event_loop(close=False)
1118     def test_cache_multiple_files(self) -> None:
1119         mode = black.FileMode()
1120         with cache_dir() as workspace, patch(
1121             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1122         ):
1123             one = (workspace / "one.py").resolve()
1124             with one.open("w") as fobj:
1125                 fobj.write("print('hello')")
1126             two = (workspace / "two.py").resolve()
1127             with two.open("w") as fobj:
1128                 fobj.write("print('hello')")
1129             black.write_cache({}, [one], mode)
1130             self.invokeBlack([str(workspace)])
1131             with one.open("r") as fobj:
1132                 self.assertEqual(fobj.read(), "print('hello')")
1133             with two.open("r") as fobj:
1134                 self.assertEqual(fobj.read(), 'print("hello")\n')
1135             cache = black.read_cache(mode)
1136             self.assertIn(one, cache)
1137             self.assertIn(two, cache)
1138
1139     def test_no_cache_when_writeback_diff(self) -> None:
1140         mode = black.FileMode()
1141         with cache_dir() as workspace:
1142             src = (workspace / "test.py").resolve()
1143             with src.open("w") as fobj:
1144                 fobj.write("print('hello')")
1145             self.invokeBlack([str(src), "--diff"])
1146             cache_file = black.get_cache_file(mode)
1147             self.assertFalse(cache_file.exists())
1148
1149     def test_no_cache_when_stdin(self) -> None:
1150         mode = black.FileMode()
1151         with cache_dir():
1152             result = CliRunner().invoke(
1153                 black.main, ["-"], input=BytesIO(b"print('hello')")
1154             )
1155             self.assertEqual(result.exit_code, 0)
1156             cache_file = black.get_cache_file(mode)
1157             self.assertFalse(cache_file.exists())
1158
1159     def test_read_cache_no_cachefile(self) -> None:
1160         mode = black.FileMode()
1161         with cache_dir():
1162             self.assertEqual(black.read_cache(mode), {})
1163
1164     def test_write_cache_read_cache(self) -> None:
1165         mode = black.FileMode()
1166         with cache_dir() as workspace:
1167             src = (workspace / "test.py").resolve()
1168             src.touch()
1169             black.write_cache({}, [src], mode)
1170             cache = black.read_cache(mode)
1171             self.assertIn(src, cache)
1172             self.assertEqual(cache[src], black.get_cache_info(src))
1173
1174     def test_filter_cached(self) -> None:
1175         with TemporaryDirectory() as workspace:
1176             path = Path(workspace)
1177             uncached = (path / "uncached").resolve()
1178             cached = (path / "cached").resolve()
1179             cached_but_changed = (path / "changed").resolve()
1180             uncached.touch()
1181             cached.touch()
1182             cached_but_changed.touch()
1183             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
1184             todo, done = black.filter_cached(
1185                 cache, {uncached, cached, cached_but_changed}
1186             )
1187             self.assertEqual(todo, {uncached, cached_but_changed})
1188             self.assertEqual(done, {cached})
1189
1190     def test_write_cache_creates_directory_if_needed(self) -> None:
1191         mode = black.FileMode()
1192         with cache_dir(exists=False) as workspace:
1193             self.assertFalse(workspace.exists())
1194             black.write_cache({}, [], mode)
1195             self.assertTrue(workspace.exists())
1196
1197     @event_loop(close=False)
1198     def test_failed_formatting_does_not_get_cached(self) -> None:
1199         mode = black.FileMode()
1200         with cache_dir() as workspace, patch(
1201             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1202         ):
1203             failing = (workspace / "failing.py").resolve()
1204             with failing.open("w") as fobj:
1205                 fobj.write("not actually python")
1206             clean = (workspace / "clean.py").resolve()
1207             with clean.open("w") as fobj:
1208                 fobj.write('print("hello")\n')
1209             self.invokeBlack([str(workspace)], exit_code=123)
1210             cache = black.read_cache(mode)
1211             self.assertNotIn(failing, cache)
1212             self.assertIn(clean, cache)
1213
1214     def test_write_cache_write_fail(self) -> None:
1215         mode = black.FileMode()
1216         with cache_dir(), patch.object(Path, "open") as mock:
1217             mock.side_effect = OSError
1218             black.write_cache({}, [], mode)
1219
1220     @event_loop(close=False)
1221     def test_check_diff_use_together(self) -> None:
1222         with cache_dir():
1223             # Files which will be reformatted.
1224             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1225             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1226             # Files which will not be reformatted.
1227             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1228             self.invokeBlack([str(src2), "--diff", "--check"])
1229             # Multi file command.
1230             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1231
1232     def test_no_files(self) -> None:
1233         with cache_dir():
1234             # Without an argument, black exits with error code 0.
1235             self.invokeBlack([])
1236
1237     def test_broken_symlink(self) -> None:
1238         with cache_dir() as workspace:
1239             symlink = workspace / "broken_link.py"
1240             try:
1241                 symlink.symlink_to("nonexistent.py")
1242             except OSError as e:
1243                 self.skipTest(f"Can't create symlinks: {e}")
1244             self.invokeBlack([str(workspace.resolve())])
1245
1246     def test_read_cache_line_lengths(self) -> None:
1247         mode = black.FileMode()
1248         short_mode = black.FileMode(line_length=1)
1249         with cache_dir() as workspace:
1250             path = (workspace / "file.py").resolve()
1251             path.touch()
1252             black.write_cache({}, [path], mode)
1253             one = black.read_cache(mode)
1254             self.assertIn(path, one)
1255             two = black.read_cache(short_mode)
1256             self.assertNotIn(path, two)
1257
1258     def test_single_file_force_pyi(self) -> None:
1259         reg_mode = black.FileMode()
1260         pyi_mode = black.FileMode(is_pyi=True)
1261         contents, expected = read_data("force_pyi")
1262         with cache_dir() as workspace:
1263             path = (workspace / "file.py").resolve()
1264             with open(path, "w") as fh:
1265                 fh.write(contents)
1266             self.invokeBlack([str(path), "--pyi"])
1267             with open(path, "r") as fh:
1268                 actual = fh.read()
1269             # verify cache with --pyi is separate
1270             pyi_cache = black.read_cache(pyi_mode)
1271             self.assertIn(path, pyi_cache)
1272             normal_cache = black.read_cache(reg_mode)
1273             self.assertNotIn(path, normal_cache)
1274         self.assertEqual(actual, expected)
1275
1276     @event_loop(close=False)
1277     def test_multi_file_force_pyi(self) -> None:
1278         reg_mode = black.FileMode()
1279         pyi_mode = black.FileMode(is_pyi=True)
1280         contents, expected = read_data("force_pyi")
1281         with cache_dir() as workspace:
1282             paths = [
1283                 (workspace / "file1.py").resolve(),
1284                 (workspace / "file2.py").resolve(),
1285             ]
1286             for path in paths:
1287                 with open(path, "w") as fh:
1288                     fh.write(contents)
1289             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1290             for path in paths:
1291                 with open(path, "r") as fh:
1292                     actual = fh.read()
1293                 self.assertEqual(actual, expected)
1294             # verify cache with --pyi is separate
1295             pyi_cache = black.read_cache(pyi_mode)
1296             normal_cache = black.read_cache(reg_mode)
1297             for path in paths:
1298                 self.assertIn(path, pyi_cache)
1299                 self.assertNotIn(path, normal_cache)
1300
1301     def test_pipe_force_pyi(self) -> None:
1302         source, expected = read_data("force_pyi")
1303         result = CliRunner().invoke(
1304             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1305         )
1306         self.assertEqual(result.exit_code, 0)
1307         actual = result.output
1308         self.assertFormatEqual(actual, expected)
1309
1310     def test_single_file_force_py36(self) -> None:
1311         reg_mode = black.FileMode()
1312         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1313         source, expected = read_data("force_py36")
1314         with cache_dir() as workspace:
1315             path = (workspace / "file.py").resolve()
1316             with open(path, "w") as fh:
1317                 fh.write(source)
1318             self.invokeBlack([str(path), *PY36_ARGS])
1319             with open(path, "r") as fh:
1320                 actual = fh.read()
1321             # verify cache with --target-version is separate
1322             py36_cache = black.read_cache(py36_mode)
1323             self.assertIn(path, py36_cache)
1324             normal_cache = black.read_cache(reg_mode)
1325             self.assertNotIn(path, normal_cache)
1326         self.assertEqual(actual, expected)
1327
1328     @event_loop(close=False)
1329     def test_multi_file_force_py36(self) -> None:
1330         reg_mode = black.FileMode()
1331         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1332         source, expected = read_data("force_py36")
1333         with cache_dir() as workspace:
1334             paths = [
1335                 (workspace / "file1.py").resolve(),
1336                 (workspace / "file2.py").resolve(),
1337             ]
1338             for path in paths:
1339                 with open(path, "w") as fh:
1340                     fh.write(source)
1341             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1342             for path in paths:
1343                 with open(path, "r") as fh:
1344                     actual = fh.read()
1345                 self.assertEqual(actual, expected)
1346             # verify cache with --target-version is separate
1347             pyi_cache = black.read_cache(py36_mode)
1348             normal_cache = black.read_cache(reg_mode)
1349             for path in paths:
1350                 self.assertIn(path, pyi_cache)
1351                 self.assertNotIn(path, normal_cache)
1352
1353     def test_pipe_force_py36(self) -> None:
1354         source, expected = read_data("force_py36")
1355         result = CliRunner().invoke(
1356             black.main,
1357             ["-", "-q", "--target-version=py36"],
1358             input=BytesIO(source.encode("utf8")),
1359         )
1360         self.assertEqual(result.exit_code, 0)
1361         actual = result.output
1362         self.assertFormatEqual(actual, expected)
1363
1364     def test_include_exclude(self) -> None:
1365         path = THIS_DIR / "data" / "include_exclude_tests"
1366         include = re.compile(r"\.pyi?$")
1367         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1368         report = black.Report()
1369         sources: List[Path] = []
1370         expected = [
1371             Path(path / "b/dont_exclude/a.py"),
1372             Path(path / "b/dont_exclude/a.pyi"),
1373         ]
1374         this_abs = THIS_DIR.resolve()
1375         sources.extend(
1376             black.gen_python_files_in_dir(path, this_abs, include, exclude, report)
1377         )
1378         self.assertEqual(sorted(expected), sorted(sources))
1379
1380     def test_empty_include(self) -> None:
1381         path = THIS_DIR / "data" / "include_exclude_tests"
1382         report = black.Report()
1383         empty = re.compile(r"")
1384         sources: List[Path] = []
1385         expected = [
1386             Path(path / "b/exclude/a.pie"),
1387             Path(path / "b/exclude/a.py"),
1388             Path(path / "b/exclude/a.pyi"),
1389             Path(path / "b/dont_exclude/a.pie"),
1390             Path(path / "b/dont_exclude/a.py"),
1391             Path(path / "b/dont_exclude/a.pyi"),
1392             Path(path / "b/.definitely_exclude/a.pie"),
1393             Path(path / "b/.definitely_exclude/a.py"),
1394             Path(path / "b/.definitely_exclude/a.pyi"),
1395         ]
1396         this_abs = THIS_DIR.resolve()
1397         sources.extend(
1398             black.gen_python_files_in_dir(
1399                 path, this_abs, empty, re.compile(black.DEFAULT_EXCLUDES), report
1400             )
1401         )
1402         self.assertEqual(sorted(expected), sorted(sources))
1403
1404     def test_empty_exclude(self) -> None:
1405         path = THIS_DIR / "data" / "include_exclude_tests"
1406         report = black.Report()
1407         empty = re.compile(r"")
1408         sources: List[Path] = []
1409         expected = [
1410             Path(path / "b/dont_exclude/a.py"),
1411             Path(path / "b/dont_exclude/a.pyi"),
1412             Path(path / "b/exclude/a.py"),
1413             Path(path / "b/exclude/a.pyi"),
1414             Path(path / "b/.definitely_exclude/a.py"),
1415             Path(path / "b/.definitely_exclude/a.pyi"),
1416         ]
1417         this_abs = THIS_DIR.resolve()
1418         sources.extend(
1419             black.gen_python_files_in_dir(
1420                 path, this_abs, re.compile(black.DEFAULT_INCLUDES), empty, report
1421             )
1422         )
1423         self.assertEqual(sorted(expected), sorted(sources))
1424
1425     def test_invalid_include_exclude(self) -> None:
1426         for option in ["--include", "--exclude"]:
1427             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1428
1429     def test_preserves_line_endings(self) -> None:
1430         with TemporaryDirectory() as workspace:
1431             test_file = Path(workspace) / "test.py"
1432             for nl in ["\n", "\r\n"]:
1433                 contents = nl.join(["def f(  ):", "    pass"])
1434                 test_file.write_bytes(contents.encode())
1435                 ff(test_file, write_back=black.WriteBack.YES)
1436                 updated_contents: bytes = test_file.read_bytes()
1437                 self.assertIn(nl.encode(), updated_contents)
1438                 if nl == "\n":
1439                     self.assertNotIn(b"\r\n", updated_contents)
1440
1441     def test_preserves_line_endings_via_stdin(self) -> None:
1442         for nl in ["\n", "\r\n"]:
1443             contents = nl.join(["def f(  ):", "    pass"])
1444             runner = BlackRunner()
1445             result = runner.invoke(
1446                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1447             )
1448             self.assertEqual(result.exit_code, 0)
1449             output = runner.stdout_bytes
1450             self.assertIn(nl.encode("utf8"), output)
1451             if nl == "\n":
1452                 self.assertNotIn(b"\r\n", output)
1453
1454     def test_assert_equivalent_different_asts(self) -> None:
1455         with self.assertRaises(AssertionError):
1456             black.assert_equivalent("{}", "None")
1457
1458     def test_symlink_out_of_root_directory(self) -> None:
1459         path = MagicMock()
1460         root = THIS_DIR
1461         child = MagicMock()
1462         include = re.compile(black.DEFAULT_INCLUDES)
1463         exclude = re.compile(black.DEFAULT_EXCLUDES)
1464         report = black.Report()
1465         # `child` should behave like a symlink which resolved path is clearly
1466         # outside of the `root` directory.
1467         path.iterdir.return_value = [child]
1468         child.resolve.return_value = Path("/a/b/c")
1469         child.is_symlink.return_value = True
1470         try:
1471             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1472         except ValueError as ve:
1473             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1474         path.iterdir.assert_called_once()
1475         child.resolve.assert_called_once()
1476         child.is_symlink.assert_called_once()
1477         # `child` should behave like a strange file which resolved path is clearly
1478         # outside of the `root` directory.
1479         child.is_symlink.return_value = False
1480         with self.assertRaises(ValueError):
1481             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1482         path.iterdir.assert_called()
1483         self.assertEqual(path.iterdir.call_count, 2)
1484         child.resolve.assert_called()
1485         self.assertEqual(child.resolve.call_count, 2)
1486         child.is_symlink.assert_called()
1487         self.assertEqual(child.is_symlink.call_count, 2)
1488
1489     def test_shhh_click(self) -> None:
1490         try:
1491             from click import _unicodefun  # type: ignore
1492         except ModuleNotFoundError:
1493             self.skipTest("Incompatible Click version")
1494         if not hasattr(_unicodefun, "_verify_python3_env"):
1495             self.skipTest("Incompatible Click version")
1496         # First, let's see if Click is crashing with a preferred ASCII charset.
1497         with patch("locale.getpreferredencoding") as gpe:
1498             gpe.return_value = "ASCII"
1499             with self.assertRaises(RuntimeError):
1500                 _unicodefun._verify_python3_env()
1501         # Now, let's silence Click...
1502         black.patch_click()
1503         # ...and confirm it's silent.
1504         with patch("locale.getpreferredencoding") as gpe:
1505             gpe.return_value = "ASCII"
1506             try:
1507                 _unicodefun._verify_python3_env()
1508             except RuntimeError as re:
1509                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1510
1511     def test_root_logger_not_used_directly(self) -> None:
1512         def fail(*args: Any, **kwargs: Any) -> None:
1513             self.fail("Record created with root logger")
1514
1515         with patch.multiple(
1516             logging.root,
1517             debug=fail,
1518             info=fail,
1519             warning=fail,
1520             error=fail,
1521             critical=fail,
1522             log=fail,
1523         ):
1524             ff(THIS_FILE)
1525
1526     # TODO: remove these decorators once the below is released
1527     # https://github.com/aio-libs/aiohttp/pull/3727
1528     @skip_if_exception("ClientOSError")
1529     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1530     @async_test
1531     async def test_blackd_request_needs_formatting(self) -> None:
1532         app = blackd.make_app()
1533         async with TestClient(TestServer(app)) as client:
1534             response = await client.post("/", data=b"print('hello world')")
1535             self.assertEqual(response.status, 200)
1536             self.assertEqual(response.charset, "utf8")
1537             self.assertEqual(await response.read(), b'print("hello world")\n')
1538
1539     @skip_if_exception("ClientOSError")
1540     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1541     @async_test
1542     async def test_blackd_request_no_change(self) -> None:
1543         app = blackd.make_app()
1544         async with TestClient(TestServer(app)) as client:
1545             response = await client.post("/", data=b'print("hello world")\n')
1546             self.assertEqual(response.status, 204)
1547             self.assertEqual(await response.read(), b"")
1548
1549     @skip_if_exception("ClientOSError")
1550     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1551     @async_test
1552     async def test_blackd_request_syntax_error(self) -> None:
1553         app = blackd.make_app()
1554         async with TestClient(TestServer(app)) as client:
1555             response = await client.post("/", data=b"what even ( is")
1556             self.assertEqual(response.status, 400)
1557             content = await response.text()
1558             self.assertTrue(
1559                 content.startswith("Cannot parse"),
1560                 msg=f"Expected error to start with 'Cannot parse', got {repr(content)}",
1561             )
1562
1563     @skip_if_exception("ClientOSError")
1564     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1565     @async_test
1566     async def test_blackd_unsupported_version(self) -> None:
1567         app = blackd.make_app()
1568         async with TestClient(TestServer(app)) as client:
1569             response = await client.post(
1570                 "/", data=b"what", headers={blackd.VERSION_HEADER: "2"}
1571             )
1572             self.assertEqual(response.status, 501)
1573
1574     @skip_if_exception("ClientOSError")
1575     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1576     @async_test
1577     async def test_blackd_supported_version(self) -> None:
1578         app = blackd.make_app()
1579         async with TestClient(TestServer(app)) as client:
1580             response = await client.post(
1581                 "/", data=b"what", headers={blackd.VERSION_HEADER: "1"}
1582             )
1583             self.assertEqual(response.status, 200)
1584
1585     @skip_if_exception("ClientOSError")
1586     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1587     @async_test
1588     async def test_blackd_invalid_python_variant(self) -> None:
1589         app = blackd.make_app()
1590         async with TestClient(TestServer(app)) as client:
1591
1592             async def check(header_value: str, expected_status: int = 400) -> None:
1593                 response = await client.post(
1594                     "/",
1595                     data=b"what",
1596                     headers={blackd.PYTHON_VARIANT_HEADER: header_value},
1597                 )
1598                 self.assertEqual(response.status, expected_status)
1599
1600             await check("lol")
1601             await check("ruby3.5")
1602             await check("pyi3.6")
1603             await check("py1.5")
1604             await check("2.8")
1605             await check("py2.8")
1606             await check("3.0")
1607             await check("pypy3.0")
1608             await check("jython3.4")
1609
1610     @skip_if_exception("ClientOSError")
1611     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1612     @async_test
1613     async def test_blackd_pyi(self) -> None:
1614         app = blackd.make_app()
1615         async with TestClient(TestServer(app)) as client:
1616             source, expected = read_data("stub.pyi")
1617             response = await client.post(
1618                 "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
1619             )
1620             self.assertEqual(response.status, 200)
1621             self.assertEqual(await response.text(), expected)
1622
1623     @skip_if_exception("ClientOSError")
1624     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1625     @async_test
1626     async def test_blackd_python_variant(self) -> None:
1627         app = blackd.make_app()
1628         code = (
1629             "def f(\n"
1630             "    and_has_a_bunch_of,\n"
1631             "    very_long_arguments_too,\n"
1632             "    and_lots_of_them_as_well_lol,\n"
1633             "    **and_very_long_keyword_arguments\n"
1634             "):\n"
1635             "    pass\n"
1636         )
1637         async with TestClient(TestServer(app)) as client:
1638
1639             async def check(header_value: str, expected_status: int) -> None:
1640                 response = await client.post(
1641                     "/", data=code, headers={blackd.PYTHON_VARIANT_HEADER: header_value}
1642                 )
1643                 self.assertEqual(response.status, expected_status)
1644
1645             await check("3.6", 200)
1646             await check("py3.6", 200)
1647             await check("3.6,3.7", 200)
1648             await check("3.6,py3.7", 200)
1649
1650             await check("2", 204)
1651             await check("2.7", 204)
1652             await check("py2.7", 204)
1653             await check("3.4", 204)
1654             await check("py3.4", 204)
1655
1656     @skip_if_exception("ClientOSError")
1657     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1658     @async_test
1659     async def test_blackd_line_length(self) -> None:
1660         app = blackd.make_app()
1661         async with TestClient(TestServer(app)) as client:
1662             response = await client.post(
1663                 "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "7"}
1664             )
1665             self.assertEqual(response.status, 200)
1666
1667     @skip_if_exception("ClientOSError")
1668     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1669     @async_test
1670     async def test_blackd_invalid_line_length(self) -> None:
1671         app = blackd.make_app()
1672         async with TestClient(TestServer(app)) as client:
1673             response = await client.post(
1674                 "/",
1675                 data=b'print("hello")\n',
1676                 headers={blackd.LINE_LENGTH_HEADER: "NaN"},
1677             )
1678             self.assertEqual(response.status, 400)
1679
1680     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1681     def test_blackd_main(self) -> None:
1682         with patch("blackd.web.run_app"):
1683             result = CliRunner().invoke(blackd.main, [])
1684             if result.exception is not None:
1685                 raise result.exception
1686             self.assertEqual(result.exit_code, 0)
1687
1688
if __name__ == "__main__":
    # Run the suite when executed as a script; the first positional argument of
    # `unittest.main` is the module to load tests from.
    unittest.main("test_black")