git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Fix unstable format involving backslash + whitespace at beginning of file (#948)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 import logging
4 from concurrent.futures import ThreadPoolExecutor
5 from contextlib import contextmanager
6 from functools import partial, wraps
7 from io import BytesIO, TextIOWrapper
8 import os
9 from pathlib import Path
10 import re
11 import sys
12 from tempfile import TemporaryDirectory
13 from typing import (
14     Any,
15     BinaryIO,
16     Callable,
17     Coroutine,
18     Generator,
19     List,
20     Tuple,
21     Iterator,
22     TypeVar,
23 )
24 import unittest
25 from unittest.mock import patch, MagicMock
26
27 from click import unstyle
28 from click.testing import CliRunner
29
30 import black
31 from black import Feature, TargetVersion
32
# blackd and aiohttp's test utilities are optional: record in has_blackd_deps
# whether both imports succeeded instead of failing at import time.
try:
    import blackd
    from aiohttp.test_utils import TestClient, TestServer
except ImportError:
    has_blackd_deps = False
else:
    has_blackd_deps = True
40
# Shorthand: black.format_file_in_place pre-bound to a default FileMode and fast=True.
ff = partial(black.format_file_in_place, mode=black.FileMode(), fast=True)
# Shorthand: black.format_str pre-bound to a default FileMode.
fs = partial(black.format_str, mode=black.FileMode())
# Location of this test module and of the directory that contains it.
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# Marker text that read_data() removes from every fixture line it loads.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
# One "--target-version=<name>" CLI flag per entry of black.PY36_VERSIONS.
PY36_ARGS = [
    f"--target-version={version.name.lower()}" for version in black.PY36_VERSIONS
]
# Generic type variables used in annotations.
T = TypeVar("T")
R = TypeVar("R")
51
52
def dump_to_stderr(*output: str) -> str:
    """Return the given strings joined by newlines, framed by one newline on each side.

    Despite the name, nothing is written anywhere here; the text is only returned.
    """
    joined = "\n".join(output)
    return f"\n{joined}\n"
55
56
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    Load a fixture and split it at the first line that consists of ``# output``.
    ``.py`` is appended to *name* unless it already ends in .py/.pyi/.out/.diff.
    The file is looked up under ``THIS_DIR / "data"`` when *data* is true and
    directly under ``THIS_DIR`` otherwise.  Every occurrence of EMPTY_LINE is
    removed from each line.  Both returned strings are stripped and terminated
    with a single newline; without an ``# output`` marker the output equals
    the input.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    base_dir = THIS_DIR / "data" if data else THIS_DIR
    with open(base_dir / name, "r", encoding="utf8") as test:
        raw_lines = test.readlines()

    before: List[str] = []
    after: List[str] = []
    sink = before
    for raw in raw_lines:
        cleaned = raw.replace(EMPTY_LINE, "")
        if cleaned.rstrip() == "# output":
            # Everything after the marker belongs to the expected output.
            sink = after
        else:
            sink.append(cleaned)

    if before and not after:
        # If there's no output marker, treat the entire file as already pre-formatted.
        after = list(before)
    return "".join(before).strip() + "\n", "".join(after).strip() + "\n"
78
79
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch ``black.CACHE_DIR`` to a path inside a fresh temporary directory.

    With *exists* true the temporary directory itself is used; otherwise a
    not-yet-created ``new`` child of it is.  The chosen path is yielded.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
88
89
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Install a brand-new asyncio event loop for the duration of the block.

    The loop that was current on entry is put back on exit; the temporary
    loop is closed afterwards only when *close* is true.
    """
    policy = asyncio.get_event_loop_policy()
    previous = policy.get_event_loop()
    fresh = policy.new_event_loop()
    asyncio.set_event_loop(fresh)
    try:
        yield
    finally:
        # Restore first, then dispose of the temporary loop if requested.
        policy.set_event_loop(previous)
        if close:
            fresh.close()
103
104
def async_test(f: Callable[..., Coroutine[Any, None, R]]) -> Callable[..., None]:
    """Turn coroutine function *f* into a plain callable usable as a test.

    Each call runs ``f(*args, **kwargs)`` to completion on a fresh event loop
    provided by ``event_loop(close=True)``.
    """

    @wraps(f)
    def wrapper(*args: Any, **kwargs: Any) -> None:
        with event_loop(close=True):
            current = asyncio.get_event_loop()
            current.run_until_complete(f(*args, **kwargs))

    return wrapper
112
113
@contextmanager
def skip_if_exception(e: str) -> Iterator[None]:
    """Skip the running test if the wrapped block raises an exception named *e*.

    *e* is compared against the class name of the raised exception.  A match
    is converted into ``unittest.SkipTest``; any other exception propagates
    unchanged.  When the block raises nothing, this is a no-op.
    """
    try:
        yield
    except Exception as exc:
        if exc.__class__.__name__ != e:
            # Not the anticipated failure: surface it instead of swallowing it.
            raise
        # unittest.skip() merely builds a decorator and has no effect when
        # called here; raising SkipTest is what marks the test as skipped.
        raise unittest.SkipTest(
            f"Encountered expected exception {exc}, skipping"
        ) from exc
121
122
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    This is a hack that can be removed once we depend on Click 7.x"""

    def __init__(self) -> None:
        # Buffer that backs the replacement sys.stderr installed by isolation().
        self.stderrbuf = BytesIO()
        # Not referenced anywhere else in this class.
        self.stdoutbuf = BytesIO()
        # Raw captured output; overwritten each time isolation() exits.
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        # Wrap the parent's isolation and, inside it, point sys.stderr at our
        # own buffer so the two streams can be read separately afterwards.
        with super().isolation(*args, **kwargs) as output:
            try:
                hold_stderr = sys.stderr
                sys.stderr = TextIOWrapper(self.stderrbuf, encoding=self.charset)
                yield output
            finally:
                # Snapshot both streams before the previous stderr is restored.
                self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                self.stderr_bytes = sys.stderr.buffer.getvalue()  # type: ignore
                sys.stderr = hold_stderr
146
147
148 class BlackTestCase(unittest.TestCase):
149     maxDiff = None
150
151     def assertFormatEqual(self, expected: str, actual: str) -> None:
152         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
153             bdv: black.DebugVisitor[Any]
154             black.out("Expected tree:", fg="green")
155             try:
156                 exp_node = black.lib2to3_parse(expected)
157                 bdv = black.DebugVisitor()
158                 list(bdv.visit(exp_node))
159             except Exception as ve:
160                 black.err(str(ve))
161             black.out("Actual tree:", fg="red")
162             try:
163                 exp_node = black.lib2to3_parse(actual)
164                 bdv = black.DebugVisitor()
165                 list(bdv.visit(exp_node))
166             except Exception as ve:
167                 black.err(str(ve))
168         self.assertEqual(expected, actual)
169
170     def invokeBlack(
171         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
172     ) -> None:
173         runner = BlackRunner()
174         if ignore_config:
175             args = ["--config", str(THIS_DIR / "empty.toml"), *args]
176         result = runner.invoke(black.main, args)
177         self.assertEqual(result.exit_code, exit_code, msg=runner.stderr_bytes.decode())
178
179     @patch("black.dump_to_file", dump_to_stderr)
180     def test_empty(self) -> None:
181         source = expected = ""
182         actual = fs(source)
183         self.assertFormatEqual(expected, actual)
184         black.assert_equivalent(source, actual)
185         black.assert_stable(source, actual, black.FileMode())
186
187     def test_empty_ff(self) -> None:
188         expected = ""
189         tmp_file = Path(black.dump_to_file())
190         try:
191             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
192             with open(tmp_file, encoding="utf8") as f:
193                 actual = f.read()
194         finally:
195             os.unlink(tmp_file)
196         self.assertFormatEqual(expected, actual)
197
198     @patch("black.dump_to_file", dump_to_stderr)
199     def test_self(self) -> None:
200         source, expected = read_data("test_black", data=False)
201         actual = fs(source)
202         self.assertFormatEqual(expected, actual)
203         black.assert_equivalent(source, actual)
204         black.assert_stable(source, actual, black.FileMode())
205         self.assertFalse(ff(THIS_FILE))
206
207     @patch("black.dump_to_file", dump_to_stderr)
208     def test_black(self) -> None:
209         source, expected = read_data("../black", data=False)
210         actual = fs(source)
211         self.assertFormatEqual(expected, actual)
212         black.assert_equivalent(source, actual)
213         black.assert_stable(source, actual, black.FileMode())
214         self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
215
216     def test_piping(self) -> None:
217         source, expected = read_data("../black", data=False)
218         result = BlackRunner().invoke(
219             black.main,
220             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
221             input=BytesIO(source.encode("utf8")),
222         )
223         self.assertEqual(result.exit_code, 0)
224         self.assertFormatEqual(expected, result.output)
225         black.assert_equivalent(source, result.output)
226         black.assert_stable(source, result.output, black.FileMode())
227
228     def test_piping_diff(self) -> None:
229         diff_header = re.compile(
230             rf"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d "
231             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
232         )
233         source, _ = read_data("expression.py")
234         expected, _ = read_data("expression.diff")
235         config = THIS_DIR / "data" / "empty_pyproject.toml"
236         args = [
237             "-",
238             "--fast",
239             f"--line-length={black.DEFAULT_LINE_LENGTH}",
240             "--diff",
241             f"--config={config}",
242         ]
243         result = BlackRunner().invoke(
244             black.main, args, input=BytesIO(source.encode("utf8"))
245         )
246         self.assertEqual(result.exit_code, 0)
247         actual = diff_header.sub("[Deterministic header]", result.output)
248         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
249         self.assertEqual(expected, actual)
250
251     @patch("black.dump_to_file", dump_to_stderr)
252     def test_setup(self) -> None:
253         source, expected = read_data("../setup", data=False)
254         actual = fs(source)
255         self.assertFormatEqual(expected, actual)
256         black.assert_equivalent(source, actual)
257         black.assert_stable(source, actual, black.FileMode())
258         self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
259
260     @patch("black.dump_to_file", dump_to_stderr)
261     def test_function(self) -> None:
262         source, expected = read_data("function")
263         actual = fs(source)
264         self.assertFormatEqual(expected, actual)
265         black.assert_equivalent(source, actual)
266         black.assert_stable(source, actual, black.FileMode())
267
268     @patch("black.dump_to_file", dump_to_stderr)
269     def test_function2(self) -> None:
270         source, expected = read_data("function2")
271         actual = fs(source)
272         self.assertFormatEqual(expected, actual)
273         black.assert_equivalent(source, actual)
274         black.assert_stable(source, actual, black.FileMode())
275
276     @patch("black.dump_to_file", dump_to_stderr)
277     def test_function_trailing_comma(self) -> None:
278         source, expected = read_data("function_trailing_comma")
279         actual = fs(source)
280         self.assertFormatEqual(expected, actual)
281         black.assert_equivalent(source, actual)
282         black.assert_stable(source, actual, black.FileMode())
283
284     @patch("black.dump_to_file", dump_to_stderr)
285     def test_expression(self) -> None:
286         source, expected = read_data("expression")
287         actual = fs(source)
288         self.assertFormatEqual(expected, actual)
289         black.assert_equivalent(source, actual)
290         black.assert_stable(source, actual, black.FileMode())
291
292     @patch("black.dump_to_file", dump_to_stderr)
293     def test_pep_572(self) -> None:
294         source, expected = read_data("pep_572")
295         actual = fs(source)
296         self.assertFormatEqual(expected, actual)
297         black.assert_stable(source, actual, black.FileMode())
298         if sys.version_info >= (3, 8):
299             black.assert_equivalent(source, actual)
300
301     def test_pep_572_version_detection(self) -> None:
302         source, _ = read_data("pep_572")
303         root = black.lib2to3_parse(source)
304         features = black.get_features_used(root)
305         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
306         versions = black.detect_target_versions(root)
307         self.assertIn(black.TargetVersion.PY38, versions)
308
309     def test_expression_ff(self) -> None:
310         source, expected = read_data("expression")
311         tmp_file = Path(black.dump_to_file(source))
312         try:
313             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
314             with open(tmp_file, encoding="utf8") as f:
315                 actual = f.read()
316         finally:
317             os.unlink(tmp_file)
318         self.assertFormatEqual(expected, actual)
319         with patch("black.dump_to_file", dump_to_stderr):
320             black.assert_equivalent(source, actual)
321             black.assert_stable(source, actual, black.FileMode())
322
323     def test_expression_diff(self) -> None:
324         source, _ = read_data("expression.py")
325         expected, _ = read_data("expression.diff")
326         tmp_file = Path(black.dump_to_file(source))
327         diff_header = re.compile(
328             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
329             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
330         )
331         try:
332             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
333             self.assertEqual(result.exit_code, 0)
334         finally:
335             os.unlink(tmp_file)
336         actual = result.output
337         actual = diff_header.sub("[Deterministic header]", actual)
338         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
339         if expected != actual:
340             dump = black.dump_to_file(actual)
341             msg = (
342                 f"Expected diff isn't equal to the actual. If you made changes "
343                 f"to expression.py and this is an anticipated difference, "
344                 f"overwrite tests/data/expression.diff with {dump}"
345             )
346             self.assertEqual(expected, actual, msg)
347
348     @patch("black.dump_to_file", dump_to_stderr)
349     def test_fstring(self) -> None:
350         source, expected = read_data("fstring")
351         actual = fs(source)
352         self.assertFormatEqual(expected, actual)
353         black.assert_equivalent(source, actual)
354         black.assert_stable(source, actual, black.FileMode())
355
356     @patch("black.dump_to_file", dump_to_stderr)
357     def test_pep_570(self) -> None:
358         source, expected = read_data("pep_570")
359         actual = fs(source)
360         self.assertFormatEqual(expected, actual)
361         black.assert_stable(source, actual, black.FileMode())
362         if sys.version_info >= (3, 8):
363             black.assert_equivalent(source, actual)
364
365     def test_detect_pos_only_arguments(self) -> None:
366         source, _ = read_data("pep_570")
367         root = black.lib2to3_parse(source)
368         features = black.get_features_used(root)
369         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
370         versions = black.detect_target_versions(root)
371         self.assertIn(black.TargetVersion.PY38, versions)
372
373     @patch("black.dump_to_file", dump_to_stderr)
374     def test_string_quotes(self) -> None:
375         source, expected = read_data("string_quotes")
376         actual = fs(source)
377         self.assertFormatEqual(expected, actual)
378         black.assert_equivalent(source, actual)
379         black.assert_stable(source, actual, black.FileMode())
380         mode = black.FileMode(string_normalization=False)
381         not_normalized = fs(source, mode=mode)
382         self.assertFormatEqual(source, not_normalized)
383         black.assert_equivalent(source, not_normalized)
384         black.assert_stable(source, not_normalized, mode=mode)
385
386     @patch("black.dump_to_file", dump_to_stderr)
387     def test_slices(self) -> None:
388         source, expected = read_data("slices")
389         actual = fs(source)
390         self.assertFormatEqual(expected, actual)
391         black.assert_equivalent(source, actual)
392         black.assert_stable(source, actual, black.FileMode())
393
394     @patch("black.dump_to_file", dump_to_stderr)
395     def test_comments(self) -> None:
396         source, expected = read_data("comments")
397         actual = fs(source)
398         self.assertFormatEqual(expected, actual)
399         black.assert_equivalent(source, actual)
400         black.assert_stable(source, actual, black.FileMode())
401
402     @patch("black.dump_to_file", dump_to_stderr)
403     def test_comments2(self) -> None:
404         source, expected = read_data("comments2")
405         actual = fs(source)
406         self.assertFormatEqual(expected, actual)
407         black.assert_equivalent(source, actual)
408         black.assert_stable(source, actual, black.FileMode())
409
410     @patch("black.dump_to_file", dump_to_stderr)
411     def test_comments3(self) -> None:
412         source, expected = read_data("comments3")
413         actual = fs(source)
414         self.assertFormatEqual(expected, actual)
415         black.assert_equivalent(source, actual)
416         black.assert_stable(source, actual, black.FileMode())
417
418     @patch("black.dump_to_file", dump_to_stderr)
419     def test_comments4(self) -> None:
420         source, expected = read_data("comments4")
421         actual = fs(source)
422         self.assertFormatEqual(expected, actual)
423         black.assert_equivalent(source, actual)
424         black.assert_stable(source, actual, black.FileMode())
425
426     @patch("black.dump_to_file", dump_to_stderr)
427     def test_comments5(self) -> None:
428         source, expected = read_data("comments5")
429         actual = fs(source)
430         self.assertFormatEqual(expected, actual)
431         black.assert_equivalent(source, actual)
432         black.assert_stable(source, actual, black.FileMode())
433
434     @patch("black.dump_to_file", dump_to_stderr)
435     def test_comments6(self) -> None:
436         source, expected = read_data("comments6")
437         actual = fs(source)
438         self.assertFormatEqual(expected, actual)
439         black.assert_equivalent(source, actual)
440         black.assert_stable(source, actual, black.FileMode())
441
442     @patch("black.dump_to_file", dump_to_stderr)
443     def test_comments7(self) -> None:
444         source, expected = read_data("comments7")
445         actual = fs(source)
446         self.assertFormatEqual(expected, actual)
447         black.assert_equivalent(source, actual)
448         black.assert_stable(source, actual, black.FileMode())
449
450     @patch("black.dump_to_file", dump_to_stderr)
451     def test_comment_after_escaped_newline(self) -> None:
452         source, expected = read_data("comment_after_escaped_newline")
453         actual = fs(source)
454         self.assertFormatEqual(expected, actual)
455         black.assert_equivalent(source, actual)
456         black.assert_stable(source, actual, black.FileMode())
457
458     @patch("black.dump_to_file", dump_to_stderr)
459     def test_cantfit(self) -> None:
460         source, expected = read_data("cantfit")
461         actual = fs(source)
462         self.assertFormatEqual(expected, actual)
463         black.assert_equivalent(source, actual)
464         black.assert_stable(source, actual, black.FileMode())
465
466     @patch("black.dump_to_file", dump_to_stderr)
467     def test_import_spacing(self) -> None:
468         source, expected = read_data("import_spacing")
469         actual = fs(source)
470         self.assertFormatEqual(expected, actual)
471         black.assert_equivalent(source, actual)
472         black.assert_stable(source, actual, black.FileMode())
473
474     @patch("black.dump_to_file", dump_to_stderr)
475     def test_composition(self) -> None:
476         source, expected = read_data("composition")
477         actual = fs(source)
478         self.assertFormatEqual(expected, actual)
479         black.assert_equivalent(source, actual)
480         black.assert_stable(source, actual, black.FileMode())
481
482     @patch("black.dump_to_file", dump_to_stderr)
483     def test_empty_lines(self) -> None:
484         source, expected = read_data("empty_lines")
485         actual = fs(source)
486         self.assertFormatEqual(expected, actual)
487         black.assert_equivalent(source, actual)
488         black.assert_stable(source, actual, black.FileMode())
489
490     @patch("black.dump_to_file", dump_to_stderr)
491     def test_remove_parens(self) -> None:
492         source, expected = read_data("remove_parens")
493         actual = fs(source)
494         self.assertFormatEqual(expected, actual)
495         black.assert_equivalent(source, actual)
496         black.assert_stable(source, actual, black.FileMode())
497
498     @patch("black.dump_to_file", dump_to_stderr)
499     def test_string_prefixes(self) -> None:
500         source, expected = read_data("string_prefixes")
501         actual = fs(source)
502         self.assertFormatEqual(expected, actual)
503         black.assert_equivalent(source, actual)
504         black.assert_stable(source, actual, black.FileMode())
505
506     @patch("black.dump_to_file", dump_to_stderr)
507     def test_numeric_literals(self) -> None:
508         source, expected = read_data("numeric_literals")
509         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
510         actual = fs(source, mode=mode)
511         self.assertFormatEqual(expected, actual)
512         black.assert_equivalent(source, actual)
513         black.assert_stable(source, actual, mode)
514
515     @patch("black.dump_to_file", dump_to_stderr)
516     def test_numeric_literals_ignoring_underscores(self) -> None:
517         source, expected = read_data("numeric_literals_skip_underscores")
518         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
519         actual = fs(source, mode=mode)
520         self.assertFormatEqual(expected, actual)
521         black.assert_equivalent(source, actual)
522         black.assert_stable(source, actual, mode)
523
524     @patch("black.dump_to_file", dump_to_stderr)
525     def test_numeric_literals_py2(self) -> None:
526         source, expected = read_data("numeric_literals_py2")
527         actual = fs(source)
528         self.assertFormatEqual(expected, actual)
529         black.assert_stable(source, actual, black.FileMode())
530
531     @patch("black.dump_to_file", dump_to_stderr)
532     def test_python2(self) -> None:
533         source, expected = read_data("python2")
534         actual = fs(source)
535         self.assertFormatEqual(expected, actual)
536         black.assert_equivalent(source, actual)
537         black.assert_stable(source, actual, black.FileMode())
538
539     @patch("black.dump_to_file", dump_to_stderr)
540     def test_python2_print_function(self) -> None:
541         source, expected = read_data("python2_print_function")
542         mode = black.FileMode(target_versions={TargetVersion.PY27})
543         actual = fs(source, mode=mode)
544         self.assertFormatEqual(expected, actual)
545         black.assert_equivalent(source, actual)
546         black.assert_stable(source, actual, mode)
547
548     @patch("black.dump_to_file", dump_to_stderr)
549     def test_python2_unicode_literals(self) -> None:
550         source, expected = read_data("python2_unicode_literals")
551         actual = fs(source)
552         self.assertFormatEqual(expected, actual)
553         black.assert_equivalent(source, actual)
554         black.assert_stable(source, actual, black.FileMode())
555
556     @patch("black.dump_to_file", dump_to_stderr)
557     def test_stub(self) -> None:
558         mode = black.FileMode(is_pyi=True)
559         source, expected = read_data("stub.pyi")
560         actual = fs(source, mode=mode)
561         self.assertFormatEqual(expected, actual)
562         black.assert_stable(source, actual, mode)
563
564     @patch("black.dump_to_file", dump_to_stderr)
565     def test_async_as_identifier(self) -> None:
566         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
567         source, expected = read_data("async_as_identifier")
568         actual = fs(source)
569         self.assertFormatEqual(expected, actual)
570         major, minor = sys.version_info[:2]
571         if major < 3 or (major <= 3 and minor < 7):
572             black.assert_equivalent(source, actual)
573         black.assert_stable(source, actual, black.FileMode())
574         # ensure black can parse this when the target is 3.6
575         self.invokeBlack([str(source_path), "--target-version", "py36"])
576         # but not on 3.7, because async/await is no longer an identifier
577         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
578
579     @patch("black.dump_to_file", dump_to_stderr)
580     def test_python37(self) -> None:
581         source_path = (THIS_DIR / "data" / "python37.py").resolve()
582         source, expected = read_data("python37")
583         actual = fs(source)
584         self.assertFormatEqual(expected, actual)
585         major, minor = sys.version_info[:2]
586         if major > 3 or (major == 3 and minor >= 7):
587             black.assert_equivalent(source, actual)
588         black.assert_stable(source, actual, black.FileMode())
589         # ensure black can parse this when the target is 3.7
590         self.invokeBlack([str(source_path), "--target-version", "py37"])
591         # but not on 3.6, because we use async as a reserved keyword
592         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
593
594     @patch("black.dump_to_file", dump_to_stderr)
595     def test_fmtonoff(self) -> None:
596         source, expected = read_data("fmtonoff")
597         actual = fs(source)
598         self.assertFormatEqual(expected, actual)
599         black.assert_equivalent(source, actual)
600         black.assert_stable(source, actual, black.FileMode())
601
602     @patch("black.dump_to_file", dump_to_stderr)
603     def test_fmtonoff2(self) -> None:
604         source, expected = read_data("fmtonoff2")
605         actual = fs(source)
606         self.assertFormatEqual(expected, actual)
607         black.assert_equivalent(source, actual)
608         black.assert_stable(source, actual, black.FileMode())
609
610     @patch("black.dump_to_file", dump_to_stderr)
611     def test_remove_empty_parentheses_after_class(self) -> None:
612         source, expected = read_data("class_blank_parentheses")
613         actual = fs(source)
614         self.assertFormatEqual(expected, actual)
615         black.assert_equivalent(source, actual)
616         black.assert_stable(source, actual, black.FileMode())
617
618     @patch("black.dump_to_file", dump_to_stderr)
619     def test_new_line_between_class_and_code(self) -> None:
620         source, expected = read_data("class_methods_new_line")
621         actual = fs(source)
622         self.assertFormatEqual(expected, actual)
623         black.assert_equivalent(source, actual)
624         black.assert_stable(source, actual, black.FileMode())
625
626     @patch("black.dump_to_file", dump_to_stderr)
627     def test_bracket_match(self) -> None:
628         source, expected = read_data("bracketmatch")
629         actual = fs(source)
630         self.assertFormatEqual(expected, actual)
631         black.assert_equivalent(source, actual)
632         black.assert_stable(source, actual, black.FileMode())
633
634     @patch("black.dump_to_file", dump_to_stderr)
635     def test_tuple_assign(self) -> None:
636         source, expected = read_data("tupleassign")
637         actual = fs(source)
638         self.assertFormatEqual(expected, actual)
639         black.assert_equivalent(source, actual)
640         black.assert_stable(source, actual, black.FileMode())
641
642     @patch("black.dump_to_file", dump_to_stderr)
643     def test_beginning_backslash(self) -> None:
644         source, expected = read_data("beginning_backslash")
645         actual = fs(source)
646         self.assertFormatEqual(expected, actual)
647         black.assert_equivalent(source, actual)
648         black.assert_stable(source, actual, black.FileMode())
649
650     def test_tab_comment_indentation(self) -> None:
651         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
652         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
653         self.assertFormatEqual(contents_spc, fs(contents_spc))
654         self.assertFormatEqual(contents_spc, fs(contents_tab))
655
656         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
657         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
658         self.assertFormatEqual(contents_spc, fs(contents_spc))
659         self.assertFormatEqual(contents_spc, fs(contents_tab))
660
661         # mixed tabs and spaces (valid Python 2 code)
662         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
663         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
664         self.assertFormatEqual(contents_spc, fs(contents_spc))
665         self.assertFormatEqual(contents_spc, fs(contents_tab))
666
667         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
668         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
669         self.assertFormatEqual(contents_spc, fs(contents_spc))
670         self.assertFormatEqual(contents_spc, fs(contents_tab))
671
672     def test_report_verbose(self) -> None:
673         report = black.Report(verbose=True)
674         out_lines = []
675         err_lines = []
676
677         def out(msg: str, **kwargs: Any) -> None:
678             out_lines.append(msg)
679
680         def err(msg: str, **kwargs: Any) -> None:
681             err_lines.append(msg)
682
683         with patch("black.out", out), patch("black.err", err):
684             report.done(Path("f1"), black.Changed.NO)
685             self.assertEqual(len(out_lines), 1)
686             self.assertEqual(len(err_lines), 0)
687             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
688             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
689             self.assertEqual(report.return_code, 0)
690             report.done(Path("f2"), black.Changed.YES)
691             self.assertEqual(len(out_lines), 2)
692             self.assertEqual(len(err_lines), 0)
693             self.assertEqual(out_lines[-1], "reformatted f2")
694             self.assertEqual(
695                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
696             )
697             report.done(Path("f3"), black.Changed.CACHED)
698             self.assertEqual(len(out_lines), 3)
699             self.assertEqual(len(err_lines), 0)
700             self.assertEqual(
701                 out_lines[-1], "f3 wasn't modified on disk since last run."
702             )
703             self.assertEqual(
704                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
705             )
706             self.assertEqual(report.return_code, 0)
707             report.check = True
708             self.assertEqual(report.return_code, 1)
709             report.check = False
710             report.failed(Path("e1"), "boom")
711             self.assertEqual(len(out_lines), 3)
712             self.assertEqual(len(err_lines), 1)
713             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
714             self.assertEqual(
715                 unstyle(str(report)),
716                 "1 file reformatted, 2 files left unchanged, "
717                 "1 file failed to reformat.",
718             )
719             self.assertEqual(report.return_code, 123)
720             report.done(Path("f3"), black.Changed.YES)
721             self.assertEqual(len(out_lines), 4)
722             self.assertEqual(len(err_lines), 1)
723             self.assertEqual(out_lines[-1], "reformatted f3")
724             self.assertEqual(
725                 unstyle(str(report)),
726                 "2 files reformatted, 2 files left unchanged, "
727                 "1 file failed to reformat.",
728             )
729             self.assertEqual(report.return_code, 123)
730             report.failed(Path("e2"), "boom")
731             self.assertEqual(len(out_lines), 4)
732             self.assertEqual(len(err_lines), 2)
733             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
734             self.assertEqual(
735                 unstyle(str(report)),
736                 "2 files reformatted, 2 files left unchanged, "
737                 "2 files failed to reformat.",
738             )
739             self.assertEqual(report.return_code, 123)
740             report.path_ignored(Path("wat"), "no match")
741             self.assertEqual(len(out_lines), 5)
742             self.assertEqual(len(err_lines), 2)
743             self.assertEqual(out_lines[-1], "wat ignored: no match")
744             self.assertEqual(
745                 unstyle(str(report)),
746                 "2 files reformatted, 2 files left unchanged, "
747                 "2 files failed to reformat.",
748             )
749             self.assertEqual(report.return_code, 123)
750             report.done(Path("f4"), black.Changed.NO)
751             self.assertEqual(len(out_lines), 6)
752             self.assertEqual(len(err_lines), 2)
753             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
754             self.assertEqual(
755                 unstyle(str(report)),
756                 "2 files reformatted, 3 files left unchanged, "
757                 "2 files failed to reformat.",
758             )
759             self.assertEqual(report.return_code, 123)
760             report.check = True
761             self.assertEqual(
762                 unstyle(str(report)),
763                 "2 files would be reformatted, 3 files would be left unchanged, "
764                 "2 files would fail to reformat.",
765             )
766
767     def test_report_quiet(self) -> None:
768         report = black.Report(quiet=True)
769         out_lines = []
770         err_lines = []
771
772         def out(msg: str, **kwargs: Any) -> None:
773             out_lines.append(msg)
774
775         def err(msg: str, **kwargs: Any) -> None:
776             err_lines.append(msg)
777
778         with patch("black.out", out), patch("black.err", err):
779             report.done(Path("f1"), black.Changed.NO)
780             self.assertEqual(len(out_lines), 0)
781             self.assertEqual(len(err_lines), 0)
782             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
783             self.assertEqual(report.return_code, 0)
784             report.done(Path("f2"), black.Changed.YES)
785             self.assertEqual(len(out_lines), 0)
786             self.assertEqual(len(err_lines), 0)
787             self.assertEqual(
788                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
789             )
790             report.done(Path("f3"), black.Changed.CACHED)
791             self.assertEqual(len(out_lines), 0)
792             self.assertEqual(len(err_lines), 0)
793             self.assertEqual(
794                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
795             )
796             self.assertEqual(report.return_code, 0)
797             report.check = True
798             self.assertEqual(report.return_code, 1)
799             report.check = False
800             report.failed(Path("e1"), "boom")
801             self.assertEqual(len(out_lines), 0)
802             self.assertEqual(len(err_lines), 1)
803             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
804             self.assertEqual(
805                 unstyle(str(report)),
806                 "1 file reformatted, 2 files left unchanged, "
807                 "1 file failed to reformat.",
808             )
809             self.assertEqual(report.return_code, 123)
810             report.done(Path("f3"), black.Changed.YES)
811             self.assertEqual(len(out_lines), 0)
812             self.assertEqual(len(err_lines), 1)
813             self.assertEqual(
814                 unstyle(str(report)),
815                 "2 files reformatted, 2 files left unchanged, "
816                 "1 file failed to reformat.",
817             )
818             self.assertEqual(report.return_code, 123)
819             report.failed(Path("e2"), "boom")
820             self.assertEqual(len(out_lines), 0)
821             self.assertEqual(len(err_lines), 2)
822             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
823             self.assertEqual(
824                 unstyle(str(report)),
825                 "2 files reformatted, 2 files left unchanged, "
826                 "2 files failed to reformat.",
827             )
828             self.assertEqual(report.return_code, 123)
829             report.path_ignored(Path("wat"), "no match")
830             self.assertEqual(len(out_lines), 0)
831             self.assertEqual(len(err_lines), 2)
832             self.assertEqual(
833                 unstyle(str(report)),
834                 "2 files reformatted, 2 files left unchanged, "
835                 "2 files failed to reformat.",
836             )
837             self.assertEqual(report.return_code, 123)
838             report.done(Path("f4"), black.Changed.NO)
839             self.assertEqual(len(out_lines), 0)
840             self.assertEqual(len(err_lines), 2)
841             self.assertEqual(
842                 unstyle(str(report)),
843                 "2 files reformatted, 3 files left unchanged, "
844                 "2 files failed to reformat.",
845             )
846             self.assertEqual(report.return_code, 123)
847             report.check = True
848             self.assertEqual(
849                 unstyle(str(report)),
850                 "2 files would be reformatted, 3 files would be left unchanged, "
851                 "2 files would fail to reformat.",
852             )
853
854     def test_report_normal(self) -> None:
855         report = black.Report()
856         out_lines = []
857         err_lines = []
858
859         def out(msg: str, **kwargs: Any) -> None:
860             out_lines.append(msg)
861
862         def err(msg: str, **kwargs: Any) -> None:
863             err_lines.append(msg)
864
865         with patch("black.out", out), patch("black.err", err):
866             report.done(Path("f1"), black.Changed.NO)
867             self.assertEqual(len(out_lines), 0)
868             self.assertEqual(len(err_lines), 0)
869             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
870             self.assertEqual(report.return_code, 0)
871             report.done(Path("f2"), black.Changed.YES)
872             self.assertEqual(len(out_lines), 1)
873             self.assertEqual(len(err_lines), 0)
874             self.assertEqual(out_lines[-1], "reformatted f2")
875             self.assertEqual(
876                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
877             )
878             report.done(Path("f3"), black.Changed.CACHED)
879             self.assertEqual(len(out_lines), 1)
880             self.assertEqual(len(err_lines), 0)
881             self.assertEqual(out_lines[-1], "reformatted f2")
882             self.assertEqual(
883                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
884             )
885             self.assertEqual(report.return_code, 0)
886             report.check = True
887             self.assertEqual(report.return_code, 1)
888             report.check = False
889             report.failed(Path("e1"), "boom")
890             self.assertEqual(len(out_lines), 1)
891             self.assertEqual(len(err_lines), 1)
892             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
893             self.assertEqual(
894                 unstyle(str(report)),
895                 "1 file reformatted, 2 files left unchanged, "
896                 "1 file failed to reformat.",
897             )
898             self.assertEqual(report.return_code, 123)
899             report.done(Path("f3"), black.Changed.YES)
900             self.assertEqual(len(out_lines), 2)
901             self.assertEqual(len(err_lines), 1)
902             self.assertEqual(out_lines[-1], "reformatted f3")
903             self.assertEqual(
904                 unstyle(str(report)),
905                 "2 files reformatted, 2 files left unchanged, "
906                 "1 file failed to reformat.",
907             )
908             self.assertEqual(report.return_code, 123)
909             report.failed(Path("e2"), "boom")
910             self.assertEqual(len(out_lines), 2)
911             self.assertEqual(len(err_lines), 2)
912             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
913             self.assertEqual(
914                 unstyle(str(report)),
915                 "2 files reformatted, 2 files left unchanged, "
916                 "2 files failed to reformat.",
917             )
918             self.assertEqual(report.return_code, 123)
919             report.path_ignored(Path("wat"), "no match")
920             self.assertEqual(len(out_lines), 2)
921             self.assertEqual(len(err_lines), 2)
922             self.assertEqual(
923                 unstyle(str(report)),
924                 "2 files reformatted, 2 files left unchanged, "
925                 "2 files failed to reformat.",
926             )
927             self.assertEqual(report.return_code, 123)
928             report.done(Path("f4"), black.Changed.NO)
929             self.assertEqual(len(out_lines), 2)
930             self.assertEqual(len(err_lines), 2)
931             self.assertEqual(
932                 unstyle(str(report)),
933                 "2 files reformatted, 3 files left unchanged, "
934                 "2 files failed to reformat.",
935             )
936             self.assertEqual(report.return_code, 123)
937             report.check = True
938             self.assertEqual(
939                 unstyle(str(report)),
940                 "2 files would be reformatted, 3 files would be left unchanged, "
941                 "2 files would fail to reformat.",
942             )
943
944     def test_lib2to3_parse(self) -> None:
945         with self.assertRaises(black.InvalidInput):
946             black.lib2to3_parse("invalid syntax")
947
948         straddling = "x + y"
949         black.lib2to3_parse(straddling)
950         black.lib2to3_parse(straddling, {TargetVersion.PY27})
951         black.lib2to3_parse(straddling, {TargetVersion.PY36})
952         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
953
954         py2_only = "print x"
955         black.lib2to3_parse(py2_only)
956         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
957         with self.assertRaises(black.InvalidInput):
958             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
959         with self.assertRaises(black.InvalidInput):
960             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
961
962         py3_only = "exec(x, end=y)"
963         black.lib2to3_parse(py3_only)
964         with self.assertRaises(black.InvalidInput):
965             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
966         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
967         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
968
969     def test_get_features_used(self) -> None:
970         node = black.lib2to3_parse("def f(*, arg): ...\n")
971         self.assertEqual(black.get_features_used(node), set())
972         node = black.lib2to3_parse("def f(*, arg,): ...\n")
973         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
974         node = black.lib2to3_parse("f(*arg,)\n")
975         self.assertEqual(
976             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
977         )
978         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
979         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
980         node = black.lib2to3_parse("123_456\n")
981         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
982         node = black.lib2to3_parse("123456\n")
983         self.assertEqual(black.get_features_used(node), set())
984         source, expected = read_data("function")
985         node = black.lib2to3_parse(source)
986         expected_features = {
987             Feature.TRAILING_COMMA_IN_CALL,
988             Feature.TRAILING_COMMA_IN_DEF,
989             Feature.F_STRINGS,
990         }
991         self.assertEqual(black.get_features_used(node), expected_features)
992         node = black.lib2to3_parse(expected)
993         self.assertEqual(black.get_features_used(node), expected_features)
994         source, expected = read_data("expression")
995         node = black.lib2to3_parse(source)
996         self.assertEqual(black.get_features_used(node), set())
997         node = black.lib2to3_parse(expected)
998         self.assertEqual(black.get_features_used(node), set())
999
1000     def test_get_future_imports(self) -> None:
1001         node = black.lib2to3_parse("\n")
1002         self.assertEqual(set(), black.get_future_imports(node))
1003         node = black.lib2to3_parse("from __future__ import black\n")
1004         self.assertEqual({"black"}, black.get_future_imports(node))
1005         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
1006         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
1007         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
1008         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
1009         node = black.lib2to3_parse(
1010             "from __future__ import multiple\nfrom __future__ import imports\n"
1011         )
1012         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
1013         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
1014         self.assertEqual({"black"}, black.get_future_imports(node))
1015         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
1016         self.assertEqual({"black"}, black.get_future_imports(node))
1017         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
1018         self.assertEqual(set(), black.get_future_imports(node))
1019         node = black.lib2to3_parse("from some.module import black\n")
1020         self.assertEqual(set(), black.get_future_imports(node))
1021         node = black.lib2to3_parse(
1022             "from __future__ import unicode_literals as _unicode_literals"
1023         )
1024         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
1025         node = black.lib2to3_parse(
1026             "from __future__ import unicode_literals as _lol, print"
1027         )
1028         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
1029
1030     def test_debug_visitor(self) -> None:
1031         source, _ = read_data("debug_visitor.py")
1032         expected, _ = read_data("debug_visitor.out")
1033         out_lines = []
1034         err_lines = []
1035
1036         def out(msg: str, **kwargs: Any) -> None:
1037             out_lines.append(msg)
1038
1039         def err(msg: str, **kwargs: Any) -> None:
1040             err_lines.append(msg)
1041
1042         with patch("black.out", out), patch("black.err", err):
1043             black.DebugVisitor.show(source)
1044         actual = "\n".join(out_lines) + "\n"
1045         log_name = ""
1046         if expected != actual:
1047             log_name = black.dump_to_file(*out_lines)
1048         self.assertEqual(
1049             expected,
1050             actual,
1051             f"AST print out is different. Actual version dumped to {log_name}",
1052         )
1053
1054     def test_format_file_contents(self) -> None:
1055         empty = ""
1056         mode = black.FileMode()
1057         with self.assertRaises(black.NothingChanged):
1058             black.format_file_contents(empty, mode=mode, fast=False)
1059         just_nl = "\n"
1060         with self.assertRaises(black.NothingChanged):
1061             black.format_file_contents(just_nl, mode=mode, fast=False)
1062         same = "l = [1, 2, 3]\n"
1063         with self.assertRaises(black.NothingChanged):
1064             black.format_file_contents(same, mode=mode, fast=False)
1065         different = "l = [1,2,3]"
1066         expected = same
1067         actual = black.format_file_contents(different, mode=mode, fast=False)
1068         self.assertEqual(expected, actual)
1069         invalid = "return if you can"
1070         with self.assertRaises(black.InvalidInput) as e:
1071             black.format_file_contents(invalid, mode=mode, fast=False)
1072         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1073
1074     def test_endmarker(self) -> None:
1075         n = black.lib2to3_parse("\n")
1076         self.assertEqual(n.type, black.syms.file_input)
1077         self.assertEqual(len(n.children), 1)
1078         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1079
1080     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1081     def test_assertFormatEqual(self) -> None:
1082         out_lines = []
1083         err_lines = []
1084
1085         def out(msg: str, **kwargs: Any) -> None:
1086             out_lines.append(msg)
1087
1088         def err(msg: str, **kwargs: Any) -> None:
1089             err_lines.append(msg)
1090
1091         with patch("black.out", out), patch("black.err", err):
1092             with self.assertRaises(AssertionError):
1093                 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
1094
1095         out_str = "".join(out_lines)
1096         self.assertTrue("Expected tree:" in out_str)
1097         self.assertTrue("Actual tree:" in out_str)
1098         self.assertEqual("".join(err_lines), "")
1099
1100     def test_cache_broken_file(self) -> None:
1101         mode = black.FileMode()
1102         with cache_dir() as workspace:
1103             cache_file = black.get_cache_file(mode)
1104             with cache_file.open("w") as fobj:
1105                 fobj.write("this is not a pickle")
1106             self.assertEqual(black.read_cache(mode), {})
1107             src = (workspace / "test.py").resolve()
1108             with src.open("w") as fobj:
1109                 fobj.write("print('hello')")
1110             self.invokeBlack([str(src)])
1111             cache = black.read_cache(mode)
1112             self.assertIn(src, cache)
1113
1114     def test_cache_single_file_already_cached(self) -> None:
1115         mode = black.FileMode()
1116         with cache_dir() as workspace:
1117             src = (workspace / "test.py").resolve()
1118             with src.open("w") as fobj:
1119                 fobj.write("print('hello')")
1120             black.write_cache({}, [src], mode)
1121             self.invokeBlack([str(src)])
1122             with src.open("r") as fobj:
1123                 self.assertEqual(fobj.read(), "print('hello')")
1124
1125     @event_loop(close=False)
1126     def test_cache_multiple_files(self) -> None:
1127         mode = black.FileMode()
1128         with cache_dir() as workspace, patch(
1129             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1130         ):
1131             one = (workspace / "one.py").resolve()
1132             with one.open("w") as fobj:
1133                 fobj.write("print('hello')")
1134             two = (workspace / "two.py").resolve()
1135             with two.open("w") as fobj:
1136                 fobj.write("print('hello')")
1137             black.write_cache({}, [one], mode)
1138             self.invokeBlack([str(workspace)])
1139             with one.open("r") as fobj:
1140                 self.assertEqual(fobj.read(), "print('hello')")
1141             with two.open("r") as fobj:
1142                 self.assertEqual(fobj.read(), 'print("hello")\n')
1143             cache = black.read_cache(mode)
1144             self.assertIn(one, cache)
1145             self.assertIn(two, cache)
1146
1147     def test_no_cache_when_writeback_diff(self) -> None:
1148         mode = black.FileMode()
1149         with cache_dir() as workspace:
1150             src = (workspace / "test.py").resolve()
1151             with src.open("w") as fobj:
1152                 fobj.write("print('hello')")
1153             self.invokeBlack([str(src), "--diff"])
1154             cache_file = black.get_cache_file(mode)
1155             self.assertFalse(cache_file.exists())
1156
1157     def test_no_cache_when_stdin(self) -> None:
1158         mode = black.FileMode()
1159         with cache_dir():
1160             result = CliRunner().invoke(
1161                 black.main, ["-"], input=BytesIO(b"print('hello')")
1162             )
1163             self.assertEqual(result.exit_code, 0)
1164             cache_file = black.get_cache_file(mode)
1165             self.assertFalse(cache_file.exists())
1166
1167     def test_read_cache_no_cachefile(self) -> None:
1168         mode = black.FileMode()
1169         with cache_dir():
1170             self.assertEqual(black.read_cache(mode), {})
1171
1172     def test_write_cache_read_cache(self) -> None:
1173         mode = black.FileMode()
1174         with cache_dir() as workspace:
1175             src = (workspace / "test.py").resolve()
1176             src.touch()
1177             black.write_cache({}, [src], mode)
1178             cache = black.read_cache(mode)
1179             self.assertIn(src, cache)
1180             self.assertEqual(cache[src], black.get_cache_info(src))
1181
1182     def test_filter_cached(self) -> None:
1183         with TemporaryDirectory() as workspace:
1184             path = Path(workspace)
1185             uncached = (path / "uncached").resolve()
1186             cached = (path / "cached").resolve()
1187             cached_but_changed = (path / "changed").resolve()
1188             uncached.touch()
1189             cached.touch()
1190             cached_but_changed.touch()
1191             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
1192             todo, done = black.filter_cached(
1193                 cache, {uncached, cached, cached_but_changed}
1194             )
1195             self.assertEqual(todo, {uncached, cached_but_changed})
1196             self.assertEqual(done, {cached})
1197
1198     def test_write_cache_creates_directory_if_needed(self) -> None:
1199         mode = black.FileMode()
1200         with cache_dir(exists=False) as workspace:
1201             self.assertFalse(workspace.exists())
1202             black.write_cache({}, [], mode)
1203             self.assertTrue(workspace.exists())
1204
1205     @event_loop(close=False)
1206     def test_failed_formatting_does_not_get_cached(self) -> None:
1207         mode = black.FileMode()
1208         with cache_dir() as workspace, patch(
1209             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1210         ):
1211             failing = (workspace / "failing.py").resolve()
1212             with failing.open("w") as fobj:
1213                 fobj.write("not actually python")
1214             clean = (workspace / "clean.py").resolve()
1215             with clean.open("w") as fobj:
1216                 fobj.write('print("hello")\n')
1217             self.invokeBlack([str(workspace)], exit_code=123)
1218             cache = black.read_cache(mode)
1219             self.assertNotIn(failing, cache)
1220             self.assertIn(clean, cache)
1221
1222     def test_write_cache_write_fail(self) -> None:
1223         mode = black.FileMode()
1224         with cache_dir(), patch.object(Path, "open") as mock:
1225             mock.side_effect = OSError
1226             black.write_cache({}, [], mode)
1227
1228     @event_loop(close=False)
1229     def test_check_diff_use_together(self) -> None:
1230         with cache_dir():
1231             # Files which will be reformatted.
1232             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1233             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1234             # Files which will not be reformatted.
1235             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1236             self.invokeBlack([str(src2), "--diff", "--check"])
1237             # Multi file command.
1238             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1239
1240     def test_no_files(self) -> None:
1241         with cache_dir():
1242             # Without an argument, black exits with error code 0.
1243             self.invokeBlack([])
1244
1245     def test_broken_symlink(self) -> None:
1246         with cache_dir() as workspace:
1247             symlink = workspace / "broken_link.py"
1248             try:
1249                 symlink.symlink_to("nonexistent.py")
1250             except OSError as e:
1251                 self.skipTest(f"Can't create symlinks: {e}")
1252             self.invokeBlack([str(workspace.resolve())])
1253
1254     def test_read_cache_line_lengths(self) -> None:
1255         mode = black.FileMode()
1256         short_mode = black.FileMode(line_length=1)
1257         with cache_dir() as workspace:
1258             path = (workspace / "file.py").resolve()
1259             path.touch()
1260             black.write_cache({}, [path], mode)
1261             one = black.read_cache(mode)
1262             self.assertIn(path, one)
1263             two = black.read_cache(short_mode)
1264             self.assertNotIn(path, two)
1265
1266     def test_single_file_force_pyi(self) -> None:
1267         reg_mode = black.FileMode()
1268         pyi_mode = black.FileMode(is_pyi=True)
1269         contents, expected = read_data("force_pyi")
1270         with cache_dir() as workspace:
1271             path = (workspace / "file.py").resolve()
1272             with open(path, "w") as fh:
1273                 fh.write(contents)
1274             self.invokeBlack([str(path), "--pyi"])
1275             with open(path, "r") as fh:
1276                 actual = fh.read()
1277             # verify cache with --pyi is separate
1278             pyi_cache = black.read_cache(pyi_mode)
1279             self.assertIn(path, pyi_cache)
1280             normal_cache = black.read_cache(reg_mode)
1281             self.assertNotIn(path, normal_cache)
1282         self.assertEqual(actual, expected)
1283
1284     @event_loop(close=False)
1285     def test_multi_file_force_pyi(self) -> None:
1286         reg_mode = black.FileMode()
1287         pyi_mode = black.FileMode(is_pyi=True)
1288         contents, expected = read_data("force_pyi")
1289         with cache_dir() as workspace:
1290             paths = [
1291                 (workspace / "file1.py").resolve(),
1292                 (workspace / "file2.py").resolve(),
1293             ]
1294             for path in paths:
1295                 with open(path, "w") as fh:
1296                     fh.write(contents)
1297             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1298             for path in paths:
1299                 with open(path, "r") as fh:
1300                     actual = fh.read()
1301                 self.assertEqual(actual, expected)
1302             # verify cache with --pyi is separate
1303             pyi_cache = black.read_cache(pyi_mode)
1304             normal_cache = black.read_cache(reg_mode)
1305             for path in paths:
1306                 self.assertIn(path, pyi_cache)
1307                 self.assertNotIn(path, normal_cache)
1308
1309     def test_pipe_force_pyi(self) -> None:
1310         source, expected = read_data("force_pyi")
1311         result = CliRunner().invoke(
1312             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1313         )
1314         self.assertEqual(result.exit_code, 0)
1315         actual = result.output
1316         self.assertFormatEqual(actual, expected)
1317
1318     def test_single_file_force_py36(self) -> None:
1319         reg_mode = black.FileMode()
1320         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1321         source, expected = read_data("force_py36")
1322         with cache_dir() as workspace:
1323             path = (workspace / "file.py").resolve()
1324             with open(path, "w") as fh:
1325                 fh.write(source)
1326             self.invokeBlack([str(path), *PY36_ARGS])
1327             with open(path, "r") as fh:
1328                 actual = fh.read()
1329             # verify cache with --target-version is separate
1330             py36_cache = black.read_cache(py36_mode)
1331             self.assertIn(path, py36_cache)
1332             normal_cache = black.read_cache(reg_mode)
1333             self.assertNotIn(path, normal_cache)
1334         self.assertEqual(actual, expected)
1335
1336     @event_loop(close=False)
1337     def test_multi_file_force_py36(self) -> None:
1338         reg_mode = black.FileMode()
1339         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1340         source, expected = read_data("force_py36")
1341         with cache_dir() as workspace:
1342             paths = [
1343                 (workspace / "file1.py").resolve(),
1344                 (workspace / "file2.py").resolve(),
1345             ]
1346             for path in paths:
1347                 with open(path, "w") as fh:
1348                     fh.write(source)
1349             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1350             for path in paths:
1351                 with open(path, "r") as fh:
1352                     actual = fh.read()
1353                 self.assertEqual(actual, expected)
1354             # verify cache with --target-version is separate
1355             pyi_cache = black.read_cache(py36_mode)
1356             normal_cache = black.read_cache(reg_mode)
1357             for path in paths:
1358                 self.assertIn(path, pyi_cache)
1359                 self.assertNotIn(path, normal_cache)
1360
1361     def test_pipe_force_py36(self) -> None:
1362         source, expected = read_data("force_py36")
1363         result = CliRunner().invoke(
1364             black.main,
1365             ["-", "-q", "--target-version=py36"],
1366             input=BytesIO(source.encode("utf8")),
1367         )
1368         self.assertEqual(result.exit_code, 0)
1369         actual = result.output
1370         self.assertFormatEqual(actual, expected)
1371
1372     def test_include_exclude(self) -> None:
1373         path = THIS_DIR / "data" / "include_exclude_tests"
1374         include = re.compile(r"\.pyi?$")
1375         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1376         report = black.Report()
1377         sources: List[Path] = []
1378         expected = [
1379             Path(path / "b/dont_exclude/a.py"),
1380             Path(path / "b/dont_exclude/a.pyi"),
1381         ]
1382         this_abs = THIS_DIR.resolve()
1383         sources.extend(
1384             black.gen_python_files_in_dir(path, this_abs, include, exclude, report)
1385         )
1386         self.assertEqual(sorted(expected), sorted(sources))
1387
1388     def test_empty_include(self) -> None:
1389         path = THIS_DIR / "data" / "include_exclude_tests"
1390         report = black.Report()
1391         empty = re.compile(r"")
1392         sources: List[Path] = []
1393         expected = [
1394             Path(path / "b/exclude/a.pie"),
1395             Path(path / "b/exclude/a.py"),
1396             Path(path / "b/exclude/a.pyi"),
1397             Path(path / "b/dont_exclude/a.pie"),
1398             Path(path / "b/dont_exclude/a.py"),
1399             Path(path / "b/dont_exclude/a.pyi"),
1400             Path(path / "b/.definitely_exclude/a.pie"),
1401             Path(path / "b/.definitely_exclude/a.py"),
1402             Path(path / "b/.definitely_exclude/a.pyi"),
1403         ]
1404         this_abs = THIS_DIR.resolve()
1405         sources.extend(
1406             black.gen_python_files_in_dir(
1407                 path, this_abs, empty, re.compile(black.DEFAULT_EXCLUDES), report
1408             )
1409         )
1410         self.assertEqual(sorted(expected), sorted(sources))
1411
1412     def test_empty_exclude(self) -> None:
1413         path = THIS_DIR / "data" / "include_exclude_tests"
1414         report = black.Report()
1415         empty = re.compile(r"")
1416         sources: List[Path] = []
1417         expected = [
1418             Path(path / "b/dont_exclude/a.py"),
1419             Path(path / "b/dont_exclude/a.pyi"),
1420             Path(path / "b/exclude/a.py"),
1421             Path(path / "b/exclude/a.pyi"),
1422             Path(path / "b/.definitely_exclude/a.py"),
1423             Path(path / "b/.definitely_exclude/a.pyi"),
1424         ]
1425         this_abs = THIS_DIR.resolve()
1426         sources.extend(
1427             black.gen_python_files_in_dir(
1428                 path, this_abs, re.compile(black.DEFAULT_INCLUDES), empty, report
1429             )
1430         )
1431         self.assertEqual(sorted(expected), sorted(sources))
1432
1433     def test_invalid_include_exclude(self) -> None:
1434         for option in ["--include", "--exclude"]:
1435             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1436
1437     def test_preserves_line_endings(self) -> None:
1438         with TemporaryDirectory() as workspace:
1439             test_file = Path(workspace) / "test.py"
1440             for nl in ["\n", "\r\n"]:
1441                 contents = nl.join(["def f(  ):", "    pass"])
1442                 test_file.write_bytes(contents.encode())
1443                 ff(test_file, write_back=black.WriteBack.YES)
1444                 updated_contents: bytes = test_file.read_bytes()
1445                 self.assertIn(nl.encode(), updated_contents)
1446                 if nl == "\n":
1447                     self.assertNotIn(b"\r\n", updated_contents)
1448
1449     def test_preserves_line_endings_via_stdin(self) -> None:
1450         for nl in ["\n", "\r\n"]:
1451             contents = nl.join(["def f(  ):", "    pass"])
1452             runner = BlackRunner()
1453             result = runner.invoke(
1454                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1455             )
1456             self.assertEqual(result.exit_code, 0)
1457             output = runner.stdout_bytes
1458             self.assertIn(nl.encode("utf8"), output)
1459             if nl == "\n":
1460                 self.assertNotIn(b"\r\n", output)
1461
1462     def test_assert_equivalent_different_asts(self) -> None:
1463         with self.assertRaises(AssertionError):
1464             black.assert_equivalent("{}", "None")
1465
1466     def test_symlink_out_of_root_directory(self) -> None:
1467         path = MagicMock()
1468         root = THIS_DIR
1469         child = MagicMock()
1470         include = re.compile(black.DEFAULT_INCLUDES)
1471         exclude = re.compile(black.DEFAULT_EXCLUDES)
1472         report = black.Report()
1473         # `child` should behave like a symlink which resolved path is clearly
1474         # outside of the `root` directory.
1475         path.iterdir.return_value = [child]
1476         child.resolve.return_value = Path("/a/b/c")
1477         child.is_symlink.return_value = True
1478         try:
1479             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1480         except ValueError as ve:
1481             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1482         path.iterdir.assert_called_once()
1483         child.resolve.assert_called_once()
1484         child.is_symlink.assert_called_once()
1485         # `child` should behave like a strange file which resolved path is clearly
1486         # outside of the `root` directory.
1487         child.is_symlink.return_value = False
1488         with self.assertRaises(ValueError):
1489             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1490         path.iterdir.assert_called()
1491         self.assertEqual(path.iterdir.call_count, 2)
1492         child.resolve.assert_called()
1493         self.assertEqual(child.resolve.call_count, 2)
1494         child.is_symlink.assert_called()
1495         self.assertEqual(child.is_symlink.call_count, 2)
1496
1497     def test_shhh_click(self) -> None:
1498         try:
1499             from click import _unicodefun  # type: ignore
1500         except ModuleNotFoundError:
1501             self.skipTest("Incompatible Click version")
1502         if not hasattr(_unicodefun, "_verify_python3_env"):
1503             self.skipTest("Incompatible Click version")
1504         # First, let's see if Click is crashing with a preferred ASCII charset.
1505         with patch("locale.getpreferredencoding") as gpe:
1506             gpe.return_value = "ASCII"
1507             with self.assertRaises(RuntimeError):
1508                 _unicodefun._verify_python3_env()
1509         # Now, let's silence Click...
1510         black.patch_click()
1511         # ...and confirm it's silent.
1512         with patch("locale.getpreferredencoding") as gpe:
1513             gpe.return_value = "ASCII"
1514             try:
1515                 _unicodefun._verify_python3_env()
1516             except RuntimeError as re:
1517                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1518
1519     def test_root_logger_not_used_directly(self) -> None:
1520         def fail(*args: Any, **kwargs: Any) -> None:
1521             self.fail("Record created with root logger")
1522
1523         with patch.multiple(
1524             logging.root,
1525             debug=fail,
1526             info=fail,
1527             warning=fail,
1528             error=fail,
1529             critical=fail,
1530             log=fail,
1531         ):
1532             ff(THIS_FILE)
1533
1534     # TODO: remove these decorators once the below is released
1535     # https://github.com/aio-libs/aiohttp/pull/3727
1536     @skip_if_exception("ClientOSError")
1537     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1538     @async_test
1539     async def test_blackd_request_needs_formatting(self) -> None:
1540         app = blackd.make_app()
1541         async with TestClient(TestServer(app)) as client:
1542             response = await client.post("/", data=b"print('hello world')")
1543             self.assertEqual(response.status, 200)
1544             self.assertEqual(response.charset, "utf8")
1545             self.assertEqual(await response.read(), b'print("hello world")\n')
1546
1547     @skip_if_exception("ClientOSError")
1548     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1549     @async_test
1550     async def test_blackd_request_no_change(self) -> None:
1551         app = blackd.make_app()
1552         async with TestClient(TestServer(app)) as client:
1553             response = await client.post("/", data=b'print("hello world")\n')
1554             self.assertEqual(response.status, 204)
1555             self.assertEqual(await response.read(), b"")
1556
1557     @skip_if_exception("ClientOSError")
1558     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1559     @async_test
1560     async def test_blackd_request_syntax_error(self) -> None:
1561         app = blackd.make_app()
1562         async with TestClient(TestServer(app)) as client:
1563             response = await client.post("/", data=b"what even ( is")
1564             self.assertEqual(response.status, 400)
1565             content = await response.text()
1566             self.assertTrue(
1567                 content.startswith("Cannot parse"),
1568                 msg=f"Expected error to start with 'Cannot parse', got {repr(content)}",
1569             )
1570
1571     @skip_if_exception("ClientOSError")
1572     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1573     @async_test
1574     async def test_blackd_unsupported_version(self) -> None:
1575         app = blackd.make_app()
1576         async with TestClient(TestServer(app)) as client:
1577             response = await client.post(
1578                 "/", data=b"what", headers={blackd.VERSION_HEADER: "2"}
1579             )
1580             self.assertEqual(response.status, 501)
1581
1582     @skip_if_exception("ClientOSError")
1583     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1584     @async_test
1585     async def test_blackd_supported_version(self) -> None:
1586         app = blackd.make_app()
1587         async with TestClient(TestServer(app)) as client:
1588             response = await client.post(
1589                 "/", data=b"what", headers={blackd.VERSION_HEADER: "1"}
1590             )
1591             self.assertEqual(response.status, 200)
1592
1593     @skip_if_exception("ClientOSError")
1594     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1595     @async_test
1596     async def test_blackd_invalid_python_variant(self) -> None:
1597         app = blackd.make_app()
1598         async with TestClient(TestServer(app)) as client:
1599
1600             async def check(header_value: str, expected_status: int = 400) -> None:
1601                 response = await client.post(
1602                     "/",
1603                     data=b"what",
1604                     headers={blackd.PYTHON_VARIANT_HEADER: header_value},
1605                 )
1606                 self.assertEqual(response.status, expected_status)
1607
1608             await check("lol")
1609             await check("ruby3.5")
1610             await check("pyi3.6")
1611             await check("py1.5")
1612             await check("2.8")
1613             await check("py2.8")
1614             await check("3.0")
1615             await check("pypy3.0")
1616             await check("jython3.4")
1617
1618     @skip_if_exception("ClientOSError")
1619     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1620     @async_test
1621     async def test_blackd_pyi(self) -> None:
1622         app = blackd.make_app()
1623         async with TestClient(TestServer(app)) as client:
1624             source, expected = read_data("stub.pyi")
1625             response = await client.post(
1626                 "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
1627             )
1628             self.assertEqual(response.status, 200)
1629             self.assertEqual(await response.text(), expected)
1630
1631     @skip_if_exception("ClientOSError")
1632     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1633     @async_test
1634     async def test_blackd_python_variant(self) -> None:
1635         app = blackd.make_app()
1636         code = (
1637             "def f(\n"
1638             "    and_has_a_bunch_of,\n"
1639             "    very_long_arguments_too,\n"
1640             "    and_lots_of_them_as_well_lol,\n"
1641             "    **and_very_long_keyword_arguments\n"
1642             "):\n"
1643             "    pass\n"
1644         )
1645         async with TestClient(TestServer(app)) as client:
1646
1647             async def check(header_value: str, expected_status: int) -> None:
1648                 response = await client.post(
1649                     "/", data=code, headers={blackd.PYTHON_VARIANT_HEADER: header_value}
1650                 )
1651                 self.assertEqual(response.status, expected_status)
1652
1653             await check("3.6", 200)
1654             await check("py3.6", 200)
1655             await check("3.6,3.7", 200)
1656             await check("3.6,py3.7", 200)
1657
1658             await check("2", 204)
1659             await check("2.7", 204)
1660             await check("py2.7", 204)
1661             await check("3.4", 204)
1662             await check("py3.4", 204)
1663
1664     @skip_if_exception("ClientOSError")
1665     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1666     @async_test
1667     async def test_blackd_line_length(self) -> None:
1668         app = blackd.make_app()
1669         async with TestClient(TestServer(app)) as client:
1670             response = await client.post(
1671                 "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "7"}
1672             )
1673             self.assertEqual(response.status, 200)
1674
1675     @skip_if_exception("ClientOSError")
1676     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1677     @async_test
1678     async def test_blackd_invalid_line_length(self) -> None:
1679         app = blackd.make_app()
1680         async with TestClient(TestServer(app)) as client:
1681             response = await client.post(
1682                 "/",
1683                 data=b'print("hello")\n',
1684                 headers={blackd.LINE_LENGTH_HEADER: "NaN"},
1685             )
1686             self.assertEqual(response.status, 400)
1687
1688     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1689     def test_blackd_main(self) -> None:
1690         with patch("blackd.web.run_app"):
1691             result = CliRunner().invoke(blackd.main, [])
1692             if result.exception is not None:
1693                 raise result.exception
1694             self.assertEqual(result.exit_code, 0)
1695
1696
if __name__ == "__main__":
    # When executed as a script, load the tests from the module named
    # "test_black" rather than from "__main__".
    unittest.main(module="test_black")