git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Set correct return statement for `is_type_comment` function (#929)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 import logging
4 from concurrent.futures import ThreadPoolExecutor
5 from contextlib import contextmanager
6 from functools import partial
7 from io import BytesIO, TextIOWrapper
8 import os
9 from pathlib import Path
10 import regex as re
11 import sys
12 from tempfile import TemporaryDirectory
13 from typing import Any, BinaryIO, Generator, List, Tuple, Iterator, TypeVar
14 import unittest
15 from unittest.mock import patch, MagicMock
16
17 from click import unstyle
18 from click.testing import CliRunner
19
20 import black
21 from black import Feature, TargetVersion
22
# blackd and aiohttp are imported on a best-effort basis; ``has_blackd_deps``
# records whether all of these imports succeeded.
try:
    import blackd
    from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop
    from aiohttp import web
except ImportError:
    has_blackd_deps = False
else:
    has_blackd_deps = True

# Shorthands for black's entry points bound to a default ``FileMode``;
# ``ff`` additionally passes ``fast=True``.
ff = partial(black.format_file_in_place, mode=black.FileMode(), fast=True)
fs = partial(black.format_str, mode=black.FileMode())
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# Marker text that ``read_data`` deletes from every line it reads. It is spelled
# as two concatenated pieces -- presumably so the complete marker never occurs
# on a single line of this file, which ``test_self`` itself runs through
# ``read_data``. Keep the split.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
# One ``--target-version=<name>`` flag per member of ``black.PY36_VERSIONS``.
PY36_ARGS = [
    f"--target-version={version.name.lower()}" for version in black.PY36_VERSIONS
]
T = TypeVar("T")
R = TypeVar("R")
42
43
def dump_to_stderr(*output: str) -> str:
    """Return the *output* strings joined by newlines, framed by newlines.

    Despite the name, nothing is written anywhere; the text is only returned.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
46
47
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    *name* gets a ``.py`` suffix unless it already ends in ``.py``, ``.pyi``,
    ``.out`` or ``.diff``. The file is looked up under ``THIS_DIR / "data"``
    when *data* is true, otherwise directly under ``THIS_DIR``.

    Lines before a line consisting of the output marker form the input, lines
    after it form the output. When nothing follows (or there is no marker) the
    output is a copy of the input. Every occurrence of ``EMPTY_LINE`` is removed
    from each line; both results are stripped and end in exactly one newline.
    """
    known_suffixes = (".py", ".pyi", ".out", ".diff")
    if not name.endswith(known_suffixes):
        name = f"{name}.py"
    root = THIS_DIR / "data" if data else THIS_DIR
    with open(root / name, "r", encoding="utf8") as handle:
        raw_lines = handle.readlines()

    before_marker: List[str] = []
    after_marker: List[str] = []
    seen_marker = False
    for raw in raw_lines:
        cleaned = raw.replace(EMPTY_LINE, "")
        if cleaned.rstrip() == "# output":
            seen_marker = True
            continue

        (after_marker if seen_marker else before_marker).append(cleaned)

    if before_marker and not after_marker:
        # If there's no output marker, treat the entire file as already pre-formatted.
        after_marker = list(before_marker)
    input_text = "".join(before_marker).strip() + "\n"
    output_text = "".join(after_marker).strip() + "\n"
    return input_text, output_text
69
70
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch ``black.CACHE_DIR`` to a path inside a fresh temporary directory.

    With *exists* true the temporary directory itself is used; otherwise the
    patched path is its ``new`` child, which is not created here. The patched
    path is yielded to the caller.
    """
    with TemporaryDirectory() as workspace:
        target = Path(workspace) if exists else Path(workspace) / "new"
        with patch("black.CACHE_DIR", target):
            yield target
79
80
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Install a brand-new asyncio event loop as the current one for the block.

    The loop is closed on exit only when *close* is true; the previously
    current loop is not reinstated.
    """
    fresh_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield

    finally:
        if close:
            fresh_loop.close()
92
93
@contextmanager
def skip_if_exception(e: str) -> Iterator[None]:
    """Skip the running test when the block raises an exception class named *e*.

    *e* is compared against ``type(exc).__name__`` only (no module, no
    subclass matching). Any other ``Exception`` propagates unchanged.

    Raises:
        unittest.SkipTest: when the block raised an exception whose class name
            equals *e*; the original exception is chained as the cause.
    """
    try:
        yield
    except Exception as exc:
        if exc.__class__.__name__ != e:
            raise

        # ``unittest.skip()`` merely builds a decorator; calling it here and
        # discarding the result swallowed the exception and let the test carry
        # on as a pass. Raising ``SkipTest`` is what actually records a skip.
        reason = f"Encountered expected exception {exc}, skipping"
        raise unittest.SkipTest(reason) from exc
103
104
class BlackRunner(CliRunner):
    """A ``CliRunner`` variant that keeps stderr apart from stdout.

    This is a hack that can be removed once we depend on Click 7.x"""

    def __init__(self) -> None:
        # Buffers and captured byte strings are set up before the base class
        # initializer runs, exactly as the runner has always done.
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Wrap the base isolation, redirecting ``sys.stderr`` into ``stderrbuf``.

        On exit the bytes behind the current ``sys.stdout`` and ``sys.stderr``
        are stored in ``stdout_bytes`` / ``stderr_bytes`` and the previous
        ``sys.stderr`` is put back.
        """
        with super().isolation(*args, **kwargs) as stream:
            saved_stderr = sys.stderr
            try:
                sys.stderr = TextIOWrapper(self.stderrbuf, encoding=self.charset)
                yield stream
            finally:
                self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                self.stderr_bytes = sys.stderr.buffer.getvalue()  # type: ignore
                sys.stderr = saved_stderr
128
129
130 class BlackTestCase(unittest.TestCase):
131     maxDiff = None
132
133     def assertFormatEqual(self, expected: str, actual: str) -> None:
134         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
135             bdv: black.DebugVisitor[Any]
136             black.out("Expected tree:", fg="green")
137             try:
138                 exp_node = black.lib2to3_parse(expected)
139                 bdv = black.DebugVisitor()
140                 list(bdv.visit(exp_node))
141             except Exception as ve:
142                 black.err(str(ve))
143             black.out("Actual tree:", fg="red")
144             try:
145                 exp_node = black.lib2to3_parse(actual)
146                 bdv = black.DebugVisitor()
147                 list(bdv.visit(exp_node))
148             except Exception as ve:
149                 black.err(str(ve))
150         self.assertEqual(expected, actual)
151
152     def invokeBlack(
153         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
154     ) -> None:
155         runner = BlackRunner()
156         if ignore_config:
157             args = ["--config", str(THIS_DIR / "empty.toml"), *args]
158         result = runner.invoke(black.main, args)
159         self.assertEqual(result.exit_code, exit_code, msg=runner.stderr_bytes.decode())
160
161     @patch("black.dump_to_file", dump_to_stderr)
162     def checkSourceFile(self, name: str) -> None:
163         path = THIS_DIR.parent / name
164         source, expected = read_data(str(path), data=False)
165         actual = fs(source)
166         self.assertFormatEqual(expected, actual)
167         black.assert_equivalent(source, actual)
168         black.assert_stable(source, actual, black.FileMode())
169         self.assertFalse(ff(path))
170
171     @patch("black.dump_to_file", dump_to_stderr)
172     def test_empty(self) -> None:
173         source = expected = ""
174         actual = fs(source)
175         self.assertFormatEqual(expected, actual)
176         black.assert_equivalent(source, actual)
177         black.assert_stable(source, actual, black.FileMode())
178
179     def test_empty_ff(self) -> None:
180         expected = ""
181         tmp_file = Path(black.dump_to_file())
182         try:
183             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
184             with open(tmp_file, encoding="utf8") as f:
185                 actual = f.read()
186         finally:
187             os.unlink(tmp_file)
188         self.assertFormatEqual(expected, actual)
189
190     def test_self(self) -> None:
191         self.checkSourceFile("tests/test_black.py")
192
193     def test_black(self) -> None:
194         self.checkSourceFile("black.py")
195
196     def test_pygram(self) -> None:
197         self.checkSourceFile("blib2to3/pygram.py")
198
199     def test_pytree(self) -> None:
200         self.checkSourceFile("blib2to3/pytree.py")
201
202     def test_conv(self) -> None:
203         self.checkSourceFile("blib2to3/pgen2/conv.py")
204
205     def test_driver(self) -> None:
206         self.checkSourceFile("blib2to3/pgen2/driver.py")
207
208     def test_grammar(self) -> None:
209         self.checkSourceFile("blib2to3/pgen2/grammar.py")
210
211     def test_literals(self) -> None:
212         self.checkSourceFile("blib2to3/pgen2/literals.py")
213
214     def test_parse(self) -> None:
215         self.checkSourceFile("blib2to3/pgen2/parse.py")
216
217     def test_pgen(self) -> None:
218         self.checkSourceFile("blib2to3/pgen2/pgen.py")
219
220     def test_tokenize(self) -> None:
221         self.checkSourceFile("blib2to3/pgen2/tokenize.py")
222
223     def test_token(self) -> None:
224         self.checkSourceFile("blib2to3/pgen2/token.py")
225
226     def test_setup(self) -> None:
227         self.checkSourceFile("setup.py")
228
229     def test_piping(self) -> None:
230         source, expected = read_data("../black", data=False)
231         result = BlackRunner().invoke(
232             black.main,
233             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
234             input=BytesIO(source.encode("utf8")),
235         )
236         self.assertEqual(result.exit_code, 0)
237         self.assertFormatEqual(expected, result.output)
238         black.assert_equivalent(source, result.output)
239         black.assert_stable(source, result.output, black.FileMode())
240
241     def test_piping_diff(self) -> None:
242         diff_header = re.compile(
243             rf"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d "
244             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
245         )
246         source, _ = read_data("expression.py")
247         expected, _ = read_data("expression.diff")
248         config = THIS_DIR / "data" / "empty_pyproject.toml"
249         args = [
250             "-",
251             "--fast",
252             f"--line-length={black.DEFAULT_LINE_LENGTH}",
253             "--diff",
254             f"--config={config}",
255         ]
256         result = BlackRunner().invoke(
257             black.main, args, input=BytesIO(source.encode("utf8"))
258         )
259         self.assertEqual(result.exit_code, 0)
260         actual = diff_header.sub("[Deterministic header]", result.output)
261         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
262         self.assertEqual(expected, actual)
263
264     @patch("black.dump_to_file", dump_to_stderr)
265     def test_function(self) -> None:
266         source, expected = read_data("function")
267         actual = fs(source)
268         self.assertFormatEqual(expected, actual)
269         black.assert_equivalent(source, actual)
270         black.assert_stable(source, actual, black.FileMode())
271
272     @patch("black.dump_to_file", dump_to_stderr)
273     def test_function2(self) -> None:
274         source, expected = read_data("function2")
275         actual = fs(source)
276         self.assertFormatEqual(expected, actual)
277         black.assert_equivalent(source, actual)
278         black.assert_stable(source, actual, black.FileMode())
279
280     @patch("black.dump_to_file", dump_to_stderr)
281     def test_function_trailing_comma(self) -> None:
282         source, expected = read_data("function_trailing_comma")
283         actual = fs(source)
284         self.assertFormatEqual(expected, actual)
285         black.assert_equivalent(source, actual)
286         black.assert_stable(source, actual, black.FileMode())
287
288     @patch("black.dump_to_file", dump_to_stderr)
289     def test_expression(self) -> None:
290         source, expected = read_data("expression")
291         actual = fs(source)
292         self.assertFormatEqual(expected, actual)
293         black.assert_equivalent(source, actual)
294         black.assert_stable(source, actual, black.FileMode())
295
296     @patch("black.dump_to_file", dump_to_stderr)
297     def test_pep_572(self) -> None:
298         source, expected = read_data("pep_572")
299         actual = fs(source)
300         self.assertFormatEqual(expected, actual)
301         black.assert_stable(source, actual, black.FileMode())
302         if sys.version_info >= (3, 8):
303             black.assert_equivalent(source, actual)
304
305     def test_pep_572_version_detection(self) -> None:
306         source, _ = read_data("pep_572")
307         root = black.lib2to3_parse(source)
308         features = black.get_features_used(root)
309         self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
310         versions = black.detect_target_versions(root)
311         self.assertIn(black.TargetVersion.PY38, versions)
312
313     def test_expression_ff(self) -> None:
314         source, expected = read_data("expression")
315         tmp_file = Path(black.dump_to_file(source))
316         try:
317             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
318             with open(tmp_file, encoding="utf8") as f:
319                 actual = f.read()
320         finally:
321             os.unlink(tmp_file)
322         self.assertFormatEqual(expected, actual)
323         with patch("black.dump_to_file", dump_to_stderr):
324             black.assert_equivalent(source, actual)
325             black.assert_stable(source, actual, black.FileMode())
326
327     def test_expression_diff(self) -> None:
328         source, _ = read_data("expression.py")
329         expected, _ = read_data("expression.diff")
330         tmp_file = Path(black.dump_to_file(source))
331         diff_header = re.compile(
332             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
333             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
334         )
335         try:
336             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
337             self.assertEqual(result.exit_code, 0)
338         finally:
339             os.unlink(tmp_file)
340         actual = result.output
341         actual = diff_header.sub("[Deterministic header]", actual)
342         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
343         if expected != actual:
344             dump = black.dump_to_file(actual)
345             msg = (
346                 f"Expected diff isn't equal to the actual. If you made changes "
347                 f"to expression.py and this is an anticipated difference, "
348                 f"overwrite tests/data/expression.diff with {dump}"
349             )
350             self.assertEqual(expected, actual, msg)
351
352     @patch("black.dump_to_file", dump_to_stderr)
353     def test_fstring(self) -> None:
354         source, expected = read_data("fstring")
355         actual = fs(source)
356         self.assertFormatEqual(expected, actual)
357         black.assert_equivalent(source, actual)
358         black.assert_stable(source, actual, black.FileMode())
359
360     @patch("black.dump_to_file", dump_to_stderr)
361     def test_pep_570(self) -> None:
362         source, expected = read_data("pep_570")
363         actual = fs(source)
364         self.assertFormatEqual(expected, actual)
365         black.assert_stable(source, actual, black.FileMode())
366         if sys.version_info >= (3, 8):
367             black.assert_equivalent(source, actual)
368
369     def test_detect_pos_only_arguments(self) -> None:
370         source, _ = read_data("pep_570")
371         root = black.lib2to3_parse(source)
372         features = black.get_features_used(root)
373         self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features)
374         versions = black.detect_target_versions(root)
375         self.assertIn(black.TargetVersion.PY38, versions)
376
377     @patch("black.dump_to_file", dump_to_stderr)
378     def test_string_quotes(self) -> None:
379         source, expected = read_data("string_quotes")
380         actual = fs(source)
381         self.assertFormatEqual(expected, actual)
382         black.assert_equivalent(source, actual)
383         black.assert_stable(source, actual, black.FileMode())
384         mode = black.FileMode(string_normalization=False)
385         not_normalized = fs(source, mode=mode)
386         self.assertFormatEqual(source, not_normalized)
387         black.assert_equivalent(source, not_normalized)
388         black.assert_stable(source, not_normalized, mode=mode)
389
390     @patch("black.dump_to_file", dump_to_stderr)
391     def test_slices(self) -> None:
392         source, expected = read_data("slices")
393         actual = fs(source)
394         self.assertFormatEqual(expected, actual)
395         black.assert_equivalent(source, actual)
396         black.assert_stable(source, actual, black.FileMode())
397
398     @patch("black.dump_to_file", dump_to_stderr)
399     def test_comments(self) -> None:
400         source, expected = read_data("comments")
401         actual = fs(source)
402         self.assertFormatEqual(expected, actual)
403         black.assert_equivalent(source, actual)
404         black.assert_stable(source, actual, black.FileMode())
405
406     @patch("black.dump_to_file", dump_to_stderr)
407     def test_comments2(self) -> None:
408         source, expected = read_data("comments2")
409         actual = fs(source)
410         self.assertFormatEqual(expected, actual)
411         black.assert_equivalent(source, actual)
412         black.assert_stable(source, actual, black.FileMode())
413
414     @patch("black.dump_to_file", dump_to_stderr)
415     def test_comments3(self) -> None:
416         source, expected = read_data("comments3")
417         actual = fs(source)
418         self.assertFormatEqual(expected, actual)
419         black.assert_equivalent(source, actual)
420         black.assert_stable(source, actual, black.FileMode())
421
422     @patch("black.dump_to_file", dump_to_stderr)
423     def test_comments4(self) -> None:
424         source, expected = read_data("comments4")
425         actual = fs(source)
426         self.assertFormatEqual(expected, actual)
427         black.assert_equivalent(source, actual)
428         black.assert_stable(source, actual, black.FileMode())
429
430     @patch("black.dump_to_file", dump_to_stderr)
431     def test_comments5(self) -> None:
432         source, expected = read_data("comments5")
433         actual = fs(source)
434         self.assertFormatEqual(expected, actual)
435         black.assert_equivalent(source, actual)
436         black.assert_stable(source, actual, black.FileMode())
437
438     @patch("black.dump_to_file", dump_to_stderr)
439     def test_comments6(self) -> None:
440         source, expected = read_data("comments6")
441         actual = fs(source)
442         self.assertFormatEqual(expected, actual)
443         black.assert_equivalent(source, actual)
444         black.assert_stable(source, actual, black.FileMode())
445
446     @patch("black.dump_to_file", dump_to_stderr)
447     def test_comments7(self) -> None:
448         source, expected = read_data("comments7")
449         actual = fs(source)
450         self.assertFormatEqual(expected, actual)
451         black.assert_equivalent(source, actual)
452         black.assert_stable(source, actual, black.FileMode())
453
454     @patch("black.dump_to_file", dump_to_stderr)
455     def test_comment_after_escaped_newline(self) -> None:
456         source, expected = read_data("comment_after_escaped_newline")
457         actual = fs(source)
458         self.assertFormatEqual(expected, actual)
459         black.assert_equivalent(source, actual)
460         black.assert_stable(source, actual, black.FileMode())
461
462     @patch("black.dump_to_file", dump_to_stderr)
463     def test_cantfit(self) -> None:
464         source, expected = read_data("cantfit")
465         actual = fs(source)
466         self.assertFormatEqual(expected, actual)
467         black.assert_equivalent(source, actual)
468         black.assert_stable(source, actual, black.FileMode())
469
470     @patch("black.dump_to_file", dump_to_stderr)
471     def test_import_spacing(self) -> None:
472         source, expected = read_data("import_spacing")
473         actual = fs(source)
474         self.assertFormatEqual(expected, actual)
475         black.assert_equivalent(source, actual)
476         black.assert_stable(source, actual, black.FileMode())
477
478     @patch("black.dump_to_file", dump_to_stderr)
479     def test_composition(self) -> None:
480         source, expected = read_data("composition")
481         actual = fs(source)
482         self.assertFormatEqual(expected, actual)
483         black.assert_equivalent(source, actual)
484         black.assert_stable(source, actual, black.FileMode())
485
486     @patch("black.dump_to_file", dump_to_stderr)
487     def test_empty_lines(self) -> None:
488         source, expected = read_data("empty_lines")
489         actual = fs(source)
490         self.assertFormatEqual(expected, actual)
491         black.assert_equivalent(source, actual)
492         black.assert_stable(source, actual, black.FileMode())
493
494     @patch("black.dump_to_file", dump_to_stderr)
495     def test_remove_parens(self) -> None:
496         source, expected = read_data("remove_parens")
497         actual = fs(source)
498         self.assertFormatEqual(expected, actual)
499         black.assert_equivalent(source, actual)
500         black.assert_stable(source, actual, black.FileMode())
501
502     @patch("black.dump_to_file", dump_to_stderr)
503     def test_string_prefixes(self) -> None:
504         source, expected = read_data("string_prefixes")
505         actual = fs(source)
506         self.assertFormatEqual(expected, actual)
507         black.assert_equivalent(source, actual)
508         black.assert_stable(source, actual, black.FileMode())
509
510     @patch("black.dump_to_file", dump_to_stderr)
511     def test_numeric_literals(self) -> None:
512         source, expected = read_data("numeric_literals")
513         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
514         actual = fs(source, mode=mode)
515         self.assertFormatEqual(expected, actual)
516         black.assert_equivalent(source, actual)
517         black.assert_stable(source, actual, mode)
518
519     @patch("black.dump_to_file", dump_to_stderr)
520     def test_numeric_literals_ignoring_underscores(self) -> None:
521         source, expected = read_data("numeric_literals_skip_underscores")
522         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
523         actual = fs(source, mode=mode)
524         self.assertFormatEqual(expected, actual)
525         black.assert_equivalent(source, actual)
526         black.assert_stable(source, actual, mode)
527
528     @patch("black.dump_to_file", dump_to_stderr)
529     def test_numeric_literals_py2(self) -> None:
530         source, expected = read_data("numeric_literals_py2")
531         actual = fs(source)
532         self.assertFormatEqual(expected, actual)
533         black.assert_stable(source, actual, black.FileMode())
534
535     @patch("black.dump_to_file", dump_to_stderr)
536     def test_python2(self) -> None:
537         source, expected = read_data("python2")
538         actual = fs(source)
539         self.assertFormatEqual(expected, actual)
540         black.assert_equivalent(source, actual)
541         black.assert_stable(source, actual, black.FileMode())
542
543     @patch("black.dump_to_file", dump_to_stderr)
544     def test_python2_print_function(self) -> None:
545         source, expected = read_data("python2_print_function")
546         mode = black.FileMode(target_versions={TargetVersion.PY27})
547         actual = fs(source, mode=mode)
548         self.assertFormatEqual(expected, actual)
549         black.assert_equivalent(source, actual)
550         black.assert_stable(source, actual, mode)
551
552     @patch("black.dump_to_file", dump_to_stderr)
553     def test_python2_unicode_literals(self) -> None:
554         source, expected = read_data("python2_unicode_literals")
555         actual = fs(source)
556         self.assertFormatEqual(expected, actual)
557         black.assert_equivalent(source, actual)
558         black.assert_stable(source, actual, black.FileMode())
559
560     @patch("black.dump_to_file", dump_to_stderr)
561     def test_stub(self) -> None:
562         mode = black.FileMode(is_pyi=True)
563         source, expected = read_data("stub.pyi")
564         actual = fs(source, mode=mode)
565         self.assertFormatEqual(expected, actual)
566         black.assert_stable(source, actual, mode)
567
568     @patch("black.dump_to_file", dump_to_stderr)
569     def test_async_as_identifier(self) -> None:
570         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
571         source, expected = read_data("async_as_identifier")
572         actual = fs(source)
573         self.assertFormatEqual(expected, actual)
574         major, minor = sys.version_info[:2]
575         if major < 3 or (major <= 3 and minor < 7):
576             black.assert_equivalent(source, actual)
577         black.assert_stable(source, actual, black.FileMode())
578         # ensure black can parse this when the target is 3.6
579         self.invokeBlack([str(source_path), "--target-version", "py36"])
580         # but not on 3.7, because async/await is no longer an identifier
581         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
582
583     @patch("black.dump_to_file", dump_to_stderr)
584     def test_python37(self) -> None:
585         source_path = (THIS_DIR / "data" / "python37.py").resolve()
586         source, expected = read_data("python37")
587         actual = fs(source)
588         self.assertFormatEqual(expected, actual)
589         major, minor = sys.version_info[:2]
590         if major > 3 or (major == 3 and minor >= 7):
591             black.assert_equivalent(source, actual)
592         black.assert_stable(source, actual, black.FileMode())
593         # ensure black can parse this when the target is 3.7
594         self.invokeBlack([str(source_path), "--target-version", "py37"])
595         # but not on 3.6, because we use async as a reserved keyword
596         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
597
598     @patch("black.dump_to_file", dump_to_stderr)
599     def test_fmtonoff(self) -> None:
600         source, expected = read_data("fmtonoff")
601         actual = fs(source)
602         self.assertFormatEqual(expected, actual)
603         black.assert_equivalent(source, actual)
604         black.assert_stable(source, actual, black.FileMode())
605
606     @patch("black.dump_to_file", dump_to_stderr)
607     def test_fmtonoff2(self) -> None:
608         source, expected = read_data("fmtonoff2")
609         actual = fs(source)
610         self.assertFormatEqual(expected, actual)
611         black.assert_equivalent(source, actual)
612         black.assert_stable(source, actual, black.FileMode())
613
614     @patch("black.dump_to_file", dump_to_stderr)
615     def test_remove_empty_parentheses_after_class(self) -> None:
616         source, expected = read_data("class_blank_parentheses")
617         actual = fs(source)
618         self.assertFormatEqual(expected, actual)
619         black.assert_equivalent(source, actual)
620         black.assert_stable(source, actual, black.FileMode())
621
622     @patch("black.dump_to_file", dump_to_stderr)
623     def test_new_line_between_class_and_code(self) -> None:
624         source, expected = read_data("class_methods_new_line")
625         actual = fs(source)
626         self.assertFormatEqual(expected, actual)
627         black.assert_equivalent(source, actual)
628         black.assert_stable(source, actual, black.FileMode())
629
630     @patch("black.dump_to_file", dump_to_stderr)
631     def test_bracket_match(self) -> None:
632         source, expected = read_data("bracketmatch")
633         actual = fs(source)
634         self.assertFormatEqual(expected, actual)
635         black.assert_equivalent(source, actual)
636         black.assert_stable(source, actual, black.FileMode())
637
638     @patch("black.dump_to_file", dump_to_stderr)
639     def test_tuple_assign(self) -> None:
640         source, expected = read_data("tupleassign")
641         actual = fs(source)
642         self.assertFormatEqual(expected, actual)
643         black.assert_equivalent(source, actual)
644         black.assert_stable(source, actual, black.FileMode())
645
646     @patch("black.dump_to_file", dump_to_stderr)
647     def test_beginning_backslash(self) -> None:
648         source, expected = read_data("beginning_backslash")
649         actual = fs(source)
650         self.assertFormatEqual(expected, actual)
651         black.assert_equivalent(source, actual)
652         black.assert_stable(source, actual, black.FileMode())
653
654     def test_tab_comment_indentation(self) -> None:
655         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
656         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
657         self.assertFormatEqual(contents_spc, fs(contents_spc))
658         self.assertFormatEqual(contents_spc, fs(contents_tab))
659
660         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
661         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
662         self.assertFormatEqual(contents_spc, fs(contents_spc))
663         self.assertFormatEqual(contents_spc, fs(contents_tab))
664
665         # mixed tabs and spaces (valid Python 2 code)
666         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
667         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
668         self.assertFormatEqual(contents_spc, fs(contents_spc))
669         self.assertFormatEqual(contents_spc, fs(contents_tab))
670
671         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
672         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
673         self.assertFormatEqual(contents_spc, fs(contents_spc))
674         self.assertFormatEqual(contents_spc, fs(contents_tab))
675
676     def test_report_verbose(self) -> None:
677         report = black.Report(verbose=True)
678         out_lines = []
679         err_lines = []
680
681         def out(msg: str, **kwargs: Any) -> None:
682             out_lines.append(msg)
683
684         def err(msg: str, **kwargs: Any) -> None:
685             err_lines.append(msg)
686
687         with patch("black.out", out), patch("black.err", err):
688             report.done(Path("f1"), black.Changed.NO)
689             self.assertEqual(len(out_lines), 1)
690             self.assertEqual(len(err_lines), 0)
691             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
692             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
693             self.assertEqual(report.return_code, 0)
694             report.done(Path("f2"), black.Changed.YES)
695             self.assertEqual(len(out_lines), 2)
696             self.assertEqual(len(err_lines), 0)
697             self.assertEqual(out_lines[-1], "reformatted f2")
698             self.assertEqual(
699                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
700             )
701             report.done(Path("f3"), black.Changed.CACHED)
702             self.assertEqual(len(out_lines), 3)
703             self.assertEqual(len(err_lines), 0)
704             self.assertEqual(
705                 out_lines[-1], "f3 wasn't modified on disk since last run."
706             )
707             self.assertEqual(
708                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
709             )
710             self.assertEqual(report.return_code, 0)
711             report.check = True
712             self.assertEqual(report.return_code, 1)
713             report.check = False
714             report.failed(Path("e1"), "boom")
715             self.assertEqual(len(out_lines), 3)
716             self.assertEqual(len(err_lines), 1)
717             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
718             self.assertEqual(
719                 unstyle(str(report)),
720                 "1 file reformatted, 2 files left unchanged, "
721                 "1 file failed to reformat.",
722             )
723             self.assertEqual(report.return_code, 123)
724             report.done(Path("f3"), black.Changed.YES)
725             self.assertEqual(len(out_lines), 4)
726             self.assertEqual(len(err_lines), 1)
727             self.assertEqual(out_lines[-1], "reformatted f3")
728             self.assertEqual(
729                 unstyle(str(report)),
730                 "2 files reformatted, 2 files left unchanged, "
731                 "1 file failed to reformat.",
732             )
733             self.assertEqual(report.return_code, 123)
734             report.failed(Path("e2"), "boom")
735             self.assertEqual(len(out_lines), 4)
736             self.assertEqual(len(err_lines), 2)
737             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
738             self.assertEqual(
739                 unstyle(str(report)),
740                 "2 files reformatted, 2 files left unchanged, "
741                 "2 files failed to reformat.",
742             )
743             self.assertEqual(report.return_code, 123)
744             report.path_ignored(Path("wat"), "no match")
745             self.assertEqual(len(out_lines), 5)
746             self.assertEqual(len(err_lines), 2)
747             self.assertEqual(out_lines[-1], "wat ignored: no match")
748             self.assertEqual(
749                 unstyle(str(report)),
750                 "2 files reformatted, 2 files left unchanged, "
751                 "2 files failed to reformat.",
752             )
753             self.assertEqual(report.return_code, 123)
754             report.done(Path("f4"), black.Changed.NO)
755             self.assertEqual(len(out_lines), 6)
756             self.assertEqual(len(err_lines), 2)
757             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
758             self.assertEqual(
759                 unstyle(str(report)),
760                 "2 files reformatted, 3 files left unchanged, "
761                 "2 files failed to reformat.",
762             )
763             self.assertEqual(report.return_code, 123)
764             report.check = True
765             self.assertEqual(
766                 unstyle(str(report)),
767                 "2 files would be reformatted, 3 files would be left unchanged, "
768                 "2 files would fail to reformat.",
769             )
770
771     def test_report_quiet(self) -> None:
772         report = black.Report(quiet=True)
773         out_lines = []
774         err_lines = []
775
776         def out(msg: str, **kwargs: Any) -> None:
777             out_lines.append(msg)
778
779         def err(msg: str, **kwargs: Any) -> None:
780             err_lines.append(msg)
781
782         with patch("black.out", out), patch("black.err", err):
783             report.done(Path("f1"), black.Changed.NO)
784             self.assertEqual(len(out_lines), 0)
785             self.assertEqual(len(err_lines), 0)
786             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
787             self.assertEqual(report.return_code, 0)
788             report.done(Path("f2"), black.Changed.YES)
789             self.assertEqual(len(out_lines), 0)
790             self.assertEqual(len(err_lines), 0)
791             self.assertEqual(
792                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
793             )
794             report.done(Path("f3"), black.Changed.CACHED)
795             self.assertEqual(len(out_lines), 0)
796             self.assertEqual(len(err_lines), 0)
797             self.assertEqual(
798                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
799             )
800             self.assertEqual(report.return_code, 0)
801             report.check = True
802             self.assertEqual(report.return_code, 1)
803             report.check = False
804             report.failed(Path("e1"), "boom")
805             self.assertEqual(len(out_lines), 0)
806             self.assertEqual(len(err_lines), 1)
807             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
808             self.assertEqual(
809                 unstyle(str(report)),
810                 "1 file reformatted, 2 files left unchanged, "
811                 "1 file failed to reformat.",
812             )
813             self.assertEqual(report.return_code, 123)
814             report.done(Path("f3"), black.Changed.YES)
815             self.assertEqual(len(out_lines), 0)
816             self.assertEqual(len(err_lines), 1)
817             self.assertEqual(
818                 unstyle(str(report)),
819                 "2 files reformatted, 2 files left unchanged, "
820                 "1 file failed to reformat.",
821             )
822             self.assertEqual(report.return_code, 123)
823             report.failed(Path("e2"), "boom")
824             self.assertEqual(len(out_lines), 0)
825             self.assertEqual(len(err_lines), 2)
826             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
827             self.assertEqual(
828                 unstyle(str(report)),
829                 "2 files reformatted, 2 files left unchanged, "
830                 "2 files failed to reformat.",
831             )
832             self.assertEqual(report.return_code, 123)
833             report.path_ignored(Path("wat"), "no match")
834             self.assertEqual(len(out_lines), 0)
835             self.assertEqual(len(err_lines), 2)
836             self.assertEqual(
837                 unstyle(str(report)),
838                 "2 files reformatted, 2 files left unchanged, "
839                 "2 files failed to reformat.",
840             )
841             self.assertEqual(report.return_code, 123)
842             report.done(Path("f4"), black.Changed.NO)
843             self.assertEqual(len(out_lines), 0)
844             self.assertEqual(len(err_lines), 2)
845             self.assertEqual(
846                 unstyle(str(report)),
847                 "2 files reformatted, 3 files left unchanged, "
848                 "2 files failed to reformat.",
849             )
850             self.assertEqual(report.return_code, 123)
851             report.check = True
852             self.assertEqual(
853                 unstyle(str(report)),
854                 "2 files would be reformatted, 3 files would be left unchanged, "
855                 "2 files would fail to reformat.",
856             )
857
858     def test_report_normal(self) -> None:
859         report = black.Report()
860         out_lines = []
861         err_lines = []
862
863         def out(msg: str, **kwargs: Any) -> None:
864             out_lines.append(msg)
865
866         def err(msg: str, **kwargs: Any) -> None:
867             err_lines.append(msg)
868
869         with patch("black.out", out), patch("black.err", err):
870             report.done(Path("f1"), black.Changed.NO)
871             self.assertEqual(len(out_lines), 0)
872             self.assertEqual(len(err_lines), 0)
873             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
874             self.assertEqual(report.return_code, 0)
875             report.done(Path("f2"), black.Changed.YES)
876             self.assertEqual(len(out_lines), 1)
877             self.assertEqual(len(err_lines), 0)
878             self.assertEqual(out_lines[-1], "reformatted f2")
879             self.assertEqual(
880                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
881             )
882             report.done(Path("f3"), black.Changed.CACHED)
883             self.assertEqual(len(out_lines), 1)
884             self.assertEqual(len(err_lines), 0)
885             self.assertEqual(out_lines[-1], "reformatted f2")
886             self.assertEqual(
887                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
888             )
889             self.assertEqual(report.return_code, 0)
890             report.check = True
891             self.assertEqual(report.return_code, 1)
892             report.check = False
893             report.failed(Path("e1"), "boom")
894             self.assertEqual(len(out_lines), 1)
895             self.assertEqual(len(err_lines), 1)
896             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
897             self.assertEqual(
898                 unstyle(str(report)),
899                 "1 file reformatted, 2 files left unchanged, "
900                 "1 file failed to reformat.",
901             )
902             self.assertEqual(report.return_code, 123)
903             report.done(Path("f3"), black.Changed.YES)
904             self.assertEqual(len(out_lines), 2)
905             self.assertEqual(len(err_lines), 1)
906             self.assertEqual(out_lines[-1], "reformatted f3")
907             self.assertEqual(
908                 unstyle(str(report)),
909                 "2 files reformatted, 2 files left unchanged, "
910                 "1 file failed to reformat.",
911             )
912             self.assertEqual(report.return_code, 123)
913             report.failed(Path("e2"), "boom")
914             self.assertEqual(len(out_lines), 2)
915             self.assertEqual(len(err_lines), 2)
916             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
917             self.assertEqual(
918                 unstyle(str(report)),
919                 "2 files reformatted, 2 files left unchanged, "
920                 "2 files failed to reformat.",
921             )
922             self.assertEqual(report.return_code, 123)
923             report.path_ignored(Path("wat"), "no match")
924             self.assertEqual(len(out_lines), 2)
925             self.assertEqual(len(err_lines), 2)
926             self.assertEqual(
927                 unstyle(str(report)),
928                 "2 files reformatted, 2 files left unchanged, "
929                 "2 files failed to reformat.",
930             )
931             self.assertEqual(report.return_code, 123)
932             report.done(Path("f4"), black.Changed.NO)
933             self.assertEqual(len(out_lines), 2)
934             self.assertEqual(len(err_lines), 2)
935             self.assertEqual(
936                 unstyle(str(report)),
937                 "2 files reformatted, 3 files left unchanged, "
938                 "2 files failed to reformat.",
939             )
940             self.assertEqual(report.return_code, 123)
941             report.check = True
942             self.assertEqual(
943                 unstyle(str(report)),
944                 "2 files would be reformatted, 3 files would be left unchanged, "
945                 "2 files would fail to reformat.",
946             )
947
948     def test_lib2to3_parse(self) -> None:
949         with self.assertRaises(black.InvalidInput):
950             black.lib2to3_parse("invalid syntax")
951
952         straddling = "x + y"
953         black.lib2to3_parse(straddling)
954         black.lib2to3_parse(straddling, {TargetVersion.PY27})
955         black.lib2to3_parse(straddling, {TargetVersion.PY36})
956         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
957
958         py2_only = "print x"
959         black.lib2to3_parse(py2_only)
960         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
961         with self.assertRaises(black.InvalidInput):
962             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
963         with self.assertRaises(black.InvalidInput):
964             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
965
966         py3_only = "exec(x, end=y)"
967         black.lib2to3_parse(py3_only)
968         with self.assertRaises(black.InvalidInput):
969             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
970         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
971         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
972
973     def test_get_features_used(self) -> None:
974         node = black.lib2to3_parse("def f(*, arg): ...\n")
975         self.assertEqual(black.get_features_used(node), set())
976         node = black.lib2to3_parse("def f(*, arg,): ...\n")
977         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
978         node = black.lib2to3_parse("f(*arg,)\n")
979         self.assertEqual(
980             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
981         )
982         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
983         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
984         node = black.lib2to3_parse("123_456\n")
985         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
986         node = black.lib2to3_parse("123456\n")
987         self.assertEqual(black.get_features_used(node), set())
988         source, expected = read_data("function")
989         node = black.lib2to3_parse(source)
990         expected_features = {
991             Feature.TRAILING_COMMA_IN_CALL,
992             Feature.TRAILING_COMMA_IN_DEF,
993             Feature.F_STRINGS,
994         }
995         self.assertEqual(black.get_features_used(node), expected_features)
996         node = black.lib2to3_parse(expected)
997         self.assertEqual(black.get_features_used(node), expected_features)
998         source, expected = read_data("expression")
999         node = black.lib2to3_parse(source)
1000         self.assertEqual(black.get_features_used(node), set())
1001         node = black.lib2to3_parse(expected)
1002         self.assertEqual(black.get_features_used(node), set())
1003
1004     def test_get_future_imports(self) -> None:
1005         node = black.lib2to3_parse("\n")
1006         self.assertEqual(set(), black.get_future_imports(node))
1007         node = black.lib2to3_parse("from __future__ import black\n")
1008         self.assertEqual({"black"}, black.get_future_imports(node))
1009         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
1010         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
1011         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
1012         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
1013         node = black.lib2to3_parse(
1014             "from __future__ import multiple\nfrom __future__ import imports\n"
1015         )
1016         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
1017         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
1018         self.assertEqual({"black"}, black.get_future_imports(node))
1019         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
1020         self.assertEqual({"black"}, black.get_future_imports(node))
1021         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
1022         self.assertEqual(set(), black.get_future_imports(node))
1023         node = black.lib2to3_parse("from some.module import black\n")
1024         self.assertEqual(set(), black.get_future_imports(node))
1025         node = black.lib2to3_parse(
1026             "from __future__ import unicode_literals as _unicode_literals"
1027         )
1028         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
1029         node = black.lib2to3_parse(
1030             "from __future__ import unicode_literals as _lol, print"
1031         )
1032         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
1033
1034     def test_debug_visitor(self) -> None:
1035         source, _ = read_data("debug_visitor.py")
1036         expected, _ = read_data("debug_visitor.out")
1037         out_lines = []
1038         err_lines = []
1039
1040         def out(msg: str, **kwargs: Any) -> None:
1041             out_lines.append(msg)
1042
1043         def err(msg: str, **kwargs: Any) -> None:
1044             err_lines.append(msg)
1045
1046         with patch("black.out", out), patch("black.err", err):
1047             black.DebugVisitor.show(source)
1048         actual = "\n".join(out_lines) + "\n"
1049         log_name = ""
1050         if expected != actual:
1051             log_name = black.dump_to_file(*out_lines)
1052         self.assertEqual(
1053             expected,
1054             actual,
1055             f"AST print out is different. Actual version dumped to {log_name}",
1056         )
1057
1058     def test_format_file_contents(self) -> None:
1059         empty = ""
1060         mode = black.FileMode()
1061         with self.assertRaises(black.NothingChanged):
1062             black.format_file_contents(empty, mode=mode, fast=False)
1063         just_nl = "\n"
1064         with self.assertRaises(black.NothingChanged):
1065             black.format_file_contents(just_nl, mode=mode, fast=False)
1066         same = "j = [1, 2, 3]\n"
1067         with self.assertRaises(black.NothingChanged):
1068             black.format_file_contents(same, mode=mode, fast=False)
1069         different = "j = [1,2,3]"
1070         expected = same
1071         actual = black.format_file_contents(different, mode=mode, fast=False)
1072         self.assertEqual(expected, actual)
1073         invalid = "return if you can"
1074         with self.assertRaises(black.InvalidInput) as e:
1075             black.format_file_contents(invalid, mode=mode, fast=False)
1076         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1077
1078     def test_endmarker(self) -> None:
1079         n = black.lib2to3_parse("\n")
1080         self.assertEqual(n.type, black.syms.file_input)
1081         self.assertEqual(len(n.children), 1)
1082         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1083
1084     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1085     def test_assertFormatEqual(self) -> None:
1086         out_lines = []
1087         err_lines = []
1088
1089         def out(msg: str, **kwargs: Any) -> None:
1090             out_lines.append(msg)
1091
1092         def err(msg: str, **kwargs: Any) -> None:
1093             err_lines.append(msg)
1094
1095         with patch("black.out", out), patch("black.err", err):
1096             with self.assertRaises(AssertionError):
1097                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
1098
1099         out_str = "".join(out_lines)
1100         self.assertTrue("Expected tree:" in out_str)
1101         self.assertTrue("Actual tree:" in out_str)
1102         self.assertEqual("".join(err_lines), "")
1103
1104     def test_cache_broken_file(self) -> None:
1105         mode = black.FileMode()
1106         with cache_dir() as workspace:
1107             cache_file = black.get_cache_file(mode)
1108             with cache_file.open("w") as fobj:
1109                 fobj.write("this is not a pickle")
1110             self.assertEqual(black.read_cache(mode), {})
1111             src = (workspace / "test.py").resolve()
1112             with src.open("w") as fobj:
1113                 fobj.write("print('hello')")
1114             self.invokeBlack([str(src)])
1115             cache = black.read_cache(mode)
1116             self.assertIn(src, cache)
1117
1118     def test_cache_single_file_already_cached(self) -> None:
1119         mode = black.FileMode()
1120         with cache_dir() as workspace:
1121             src = (workspace / "test.py").resolve()
1122             with src.open("w") as fobj:
1123                 fobj.write("print('hello')")
1124             black.write_cache({}, [src], mode)
1125             self.invokeBlack([str(src)])
1126             with src.open("r") as fobj:
1127                 self.assertEqual(fobj.read(), "print('hello')")
1128
1129     @event_loop(close=False)
1130     def test_cache_multiple_files(self) -> None:
1131         mode = black.FileMode()
1132         with cache_dir() as workspace, patch(
1133             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1134         ):
1135             one = (workspace / "one.py").resolve()
1136             with one.open("w") as fobj:
1137                 fobj.write("print('hello')")
1138             two = (workspace / "two.py").resolve()
1139             with two.open("w") as fobj:
1140                 fobj.write("print('hello')")
1141             black.write_cache({}, [one], mode)
1142             self.invokeBlack([str(workspace)])
1143             with one.open("r") as fobj:
1144                 self.assertEqual(fobj.read(), "print('hello')")
1145             with two.open("r") as fobj:
1146                 self.assertEqual(fobj.read(), 'print("hello")\n')
1147             cache = black.read_cache(mode)
1148             self.assertIn(one, cache)
1149             self.assertIn(two, cache)
1150
1151     def test_no_cache_when_writeback_diff(self) -> None:
1152         mode = black.FileMode()
1153         with cache_dir() as workspace:
1154             src = (workspace / "test.py").resolve()
1155             with src.open("w") as fobj:
1156                 fobj.write("print('hello')")
1157             self.invokeBlack([str(src), "--diff"])
1158             cache_file = black.get_cache_file(mode)
1159             self.assertFalse(cache_file.exists())
1160
1161     def test_no_cache_when_stdin(self) -> None:
1162         mode = black.FileMode()
1163         with cache_dir():
1164             result = CliRunner().invoke(
1165                 black.main, ["-"], input=BytesIO(b"print('hello')")
1166             )
1167             self.assertEqual(result.exit_code, 0)
1168             cache_file = black.get_cache_file(mode)
1169             self.assertFalse(cache_file.exists())
1170
1171     def test_read_cache_no_cachefile(self) -> None:
1172         mode = black.FileMode()
1173         with cache_dir():
1174             self.assertEqual(black.read_cache(mode), {})
1175
1176     def test_write_cache_read_cache(self) -> None:
1177         mode = black.FileMode()
1178         with cache_dir() as workspace:
1179             src = (workspace / "test.py").resolve()
1180             src.touch()
1181             black.write_cache({}, [src], mode)
1182             cache = black.read_cache(mode)
1183             self.assertIn(src, cache)
1184             self.assertEqual(cache[src], black.get_cache_info(src))
1185
1186     def test_filter_cached(self) -> None:
1187         with TemporaryDirectory() as workspace:
1188             path = Path(workspace)
1189             uncached = (path / "uncached").resolve()
1190             cached = (path / "cached").resolve()
1191             cached_but_changed = (path / "changed").resolve()
1192             uncached.touch()
1193             cached.touch()
1194             cached_but_changed.touch()
1195             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
1196             todo, done = black.filter_cached(
1197                 cache, {uncached, cached, cached_but_changed}
1198             )
1199             self.assertEqual(todo, {uncached, cached_but_changed})
1200             self.assertEqual(done, {cached})
1201
1202     def test_write_cache_creates_directory_if_needed(self) -> None:
1203         mode = black.FileMode()
1204         with cache_dir(exists=False) as workspace:
1205             self.assertFalse(workspace.exists())
1206             black.write_cache({}, [], mode)
1207             self.assertTrue(workspace.exists())
1208
1209     @event_loop(close=False)
1210     def test_failed_formatting_does_not_get_cached(self) -> None:
1211         mode = black.FileMode()
1212         with cache_dir() as workspace, patch(
1213             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1214         ):
1215             failing = (workspace / "failing.py").resolve()
1216             with failing.open("w") as fobj:
1217                 fobj.write("not actually python")
1218             clean = (workspace / "clean.py").resolve()
1219             with clean.open("w") as fobj:
1220                 fobj.write('print("hello")\n')
1221             self.invokeBlack([str(workspace)], exit_code=123)
1222             cache = black.read_cache(mode)
1223             self.assertNotIn(failing, cache)
1224             self.assertIn(clean, cache)
1225
1226     def test_write_cache_write_fail(self) -> None:
1227         mode = black.FileMode()
1228         with cache_dir(), patch.object(Path, "open") as mock:
1229             mock.side_effect = OSError
1230             black.write_cache({}, [], mode)
1231
1232     @event_loop(close=False)
1233     def test_check_diff_use_together(self) -> None:
1234         with cache_dir():
1235             # Files which will be reformatted.
1236             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1237             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1238             # Files which will not be reformatted.
1239             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1240             self.invokeBlack([str(src2), "--diff", "--check"])
1241             # Multi file command.
1242             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1243
1244     def test_no_files(self) -> None:
1245         with cache_dir():
1246             # Without an argument, black exits with error code 0.
1247             self.invokeBlack([])
1248
1249     def test_broken_symlink(self) -> None:
1250         with cache_dir() as workspace:
1251             symlink = workspace / "broken_link.py"
1252             try:
1253                 symlink.symlink_to("nonexistent.py")
1254             except OSError as e:
1255                 self.skipTest(f"Can't create symlinks: {e}")
1256             self.invokeBlack([str(workspace.resolve())])
1257
1258     def test_read_cache_line_lengths(self) -> None:
1259         mode = black.FileMode()
1260         short_mode = black.FileMode(line_length=1)
1261         with cache_dir() as workspace:
1262             path = (workspace / "file.py").resolve()
1263             path.touch()
1264             black.write_cache({}, [path], mode)
1265             one = black.read_cache(mode)
1266             self.assertIn(path, one)
1267             two = black.read_cache(short_mode)
1268             self.assertNotIn(path, two)
1269
1270     def test_tricky_unicode_symbols(self) -> None:
1271         source, expected = read_data("tricky_unicode_symbols")
1272         actual = fs(source)
1273         self.assertFormatEqual(expected, actual)
1274         black.assert_equivalent(source, actual)
1275         black.assert_stable(source, actual, black.FileMode())
1276
1277     def test_single_file_force_pyi(self) -> None:
1278         reg_mode = black.FileMode()
1279         pyi_mode = black.FileMode(is_pyi=True)
1280         contents, expected = read_data("force_pyi")
1281         with cache_dir() as workspace:
1282             path = (workspace / "file.py").resolve()
1283             with open(path, "w") as fh:
1284                 fh.write(contents)
1285             self.invokeBlack([str(path), "--pyi"])
1286             with open(path, "r") as fh:
1287                 actual = fh.read()
1288             # verify cache with --pyi is separate
1289             pyi_cache = black.read_cache(pyi_mode)
1290             self.assertIn(path, pyi_cache)
1291             normal_cache = black.read_cache(reg_mode)
1292             self.assertNotIn(path, normal_cache)
1293         self.assertEqual(actual, expected)
1294
1295     @event_loop(close=False)
1296     def test_multi_file_force_pyi(self) -> None:
1297         reg_mode = black.FileMode()
1298         pyi_mode = black.FileMode(is_pyi=True)
1299         contents, expected = read_data("force_pyi")
1300         with cache_dir() as workspace:
1301             paths = [
1302                 (workspace / "file1.py").resolve(),
1303                 (workspace / "file2.py").resolve(),
1304             ]
1305             for path in paths:
1306                 with open(path, "w") as fh:
1307                     fh.write(contents)
1308             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1309             for path in paths:
1310                 with open(path, "r") as fh:
1311                     actual = fh.read()
1312                 self.assertEqual(actual, expected)
1313             # verify cache with --pyi is separate
1314             pyi_cache = black.read_cache(pyi_mode)
1315             normal_cache = black.read_cache(reg_mode)
1316             for path in paths:
1317                 self.assertIn(path, pyi_cache)
1318                 self.assertNotIn(path, normal_cache)
1319
1320     def test_pipe_force_pyi(self) -> None:
1321         source, expected = read_data("force_pyi")
1322         result = CliRunner().invoke(
1323             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1324         )
1325         self.assertEqual(result.exit_code, 0)
1326         actual = result.output
1327         self.assertFormatEqual(actual, expected)
1328
1329     def test_single_file_force_py36(self) -> None:
1330         reg_mode = black.FileMode()
1331         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1332         source, expected = read_data("force_py36")
1333         with cache_dir() as workspace:
1334             path = (workspace / "file.py").resolve()
1335             with open(path, "w") as fh:
1336                 fh.write(source)
1337             self.invokeBlack([str(path), *PY36_ARGS])
1338             with open(path, "r") as fh:
1339                 actual = fh.read()
1340             # verify cache with --target-version is separate
1341             py36_cache = black.read_cache(py36_mode)
1342             self.assertIn(path, py36_cache)
1343             normal_cache = black.read_cache(reg_mode)
1344             self.assertNotIn(path, normal_cache)
1345         self.assertEqual(actual, expected)
1346
1347     @event_loop(close=False)
1348     def test_multi_file_force_py36(self) -> None:
1349         reg_mode = black.FileMode()
1350         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1351         source, expected = read_data("force_py36")
1352         with cache_dir() as workspace:
1353             paths = [
1354                 (workspace / "file1.py").resolve(),
1355                 (workspace / "file2.py").resolve(),
1356             ]
1357             for path in paths:
1358                 with open(path, "w") as fh:
1359                     fh.write(source)
1360             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1361             for path in paths:
1362                 with open(path, "r") as fh:
1363                     actual = fh.read()
1364                 self.assertEqual(actual, expected)
1365             # verify cache with --target-version is separate
1366             pyi_cache = black.read_cache(py36_mode)
1367             normal_cache = black.read_cache(reg_mode)
1368             for path in paths:
1369                 self.assertIn(path, pyi_cache)
1370                 self.assertNotIn(path, normal_cache)
1371
1372     def test_collections(self) -> None:
1373         source, expected = read_data("collections")
1374         actual = fs(source)
1375         self.assertFormatEqual(expected, actual)
1376         black.assert_equivalent(source, actual)
1377         black.assert_stable(source, actual, black.FileMode())
1378
1379     def test_pipe_force_py36(self) -> None:
1380         source, expected = read_data("force_py36")
1381         result = CliRunner().invoke(
1382             black.main,
1383             ["-", "-q", "--target-version=py36"],
1384             input=BytesIO(source.encode("utf8")),
1385         )
1386         self.assertEqual(result.exit_code, 0)
1387         actual = result.output
1388         self.assertFormatEqual(actual, expected)
1389
1390     def test_include_exclude(self) -> None:
1391         path = THIS_DIR / "data" / "include_exclude_tests"
1392         include = re.compile(r"\.pyi?$")
1393         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1394         report = black.Report()
1395         sources: List[Path] = []
1396         expected = [
1397             Path(path / "b/dont_exclude/a.py"),
1398             Path(path / "b/dont_exclude/a.pyi"),
1399         ]
1400         this_abs = THIS_DIR.resolve()
1401         sources.extend(
1402             black.gen_python_files_in_dir(path, this_abs, include, exclude, report)
1403         )
1404         self.assertEqual(sorted(expected), sorted(sources))
1405
1406     def test_empty_include(self) -> None:
1407         path = THIS_DIR / "data" / "include_exclude_tests"
1408         report = black.Report()
1409         empty = re.compile(r"")
1410         sources: List[Path] = []
1411         expected = [
1412             Path(path / "b/exclude/a.pie"),
1413             Path(path / "b/exclude/a.py"),
1414             Path(path / "b/exclude/a.pyi"),
1415             Path(path / "b/dont_exclude/a.pie"),
1416             Path(path / "b/dont_exclude/a.py"),
1417             Path(path / "b/dont_exclude/a.pyi"),
1418             Path(path / "b/.definitely_exclude/a.pie"),
1419             Path(path / "b/.definitely_exclude/a.py"),
1420             Path(path / "b/.definitely_exclude/a.pyi"),
1421         ]
1422         this_abs = THIS_DIR.resolve()
1423         sources.extend(
1424             black.gen_python_files_in_dir(
1425                 path, this_abs, empty, re.compile(black.DEFAULT_EXCLUDES), report
1426             )
1427         )
1428         self.assertEqual(sorted(expected), sorted(sources))
1429
1430     def test_empty_exclude(self) -> None:
1431         path = THIS_DIR / "data" / "include_exclude_tests"
1432         report = black.Report()
1433         empty = re.compile(r"")
1434         sources: List[Path] = []
1435         expected = [
1436             Path(path / "b/dont_exclude/a.py"),
1437             Path(path / "b/dont_exclude/a.pyi"),
1438             Path(path / "b/exclude/a.py"),
1439             Path(path / "b/exclude/a.pyi"),
1440             Path(path / "b/.definitely_exclude/a.py"),
1441             Path(path / "b/.definitely_exclude/a.pyi"),
1442         ]
1443         this_abs = THIS_DIR.resolve()
1444         sources.extend(
1445             black.gen_python_files_in_dir(
1446                 path, this_abs, re.compile(black.DEFAULT_INCLUDES), empty, report
1447             )
1448         )
1449         self.assertEqual(sorted(expected), sorted(sources))
1450
1451     def test_invalid_include_exclude(self) -> None:
1452         for option in ["--include", "--exclude"]:
1453             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1454
1455     def test_preserves_line_endings(self) -> None:
1456         with TemporaryDirectory() as workspace:
1457             test_file = Path(workspace) / "test.py"
1458             for nl in ["\n", "\r\n"]:
1459                 contents = nl.join(["def f(  ):", "    pass"])
1460                 test_file.write_bytes(contents.encode())
1461                 ff(test_file, write_back=black.WriteBack.YES)
1462                 updated_contents: bytes = test_file.read_bytes()
1463                 self.assertIn(nl.encode(), updated_contents)
1464                 if nl == "\n":
1465                     self.assertNotIn(b"\r\n", updated_contents)
1466
1467     def test_preserves_line_endings_via_stdin(self) -> None:
1468         for nl in ["\n", "\r\n"]:
1469             contents = nl.join(["def f(  ):", "    pass"])
1470             runner = BlackRunner()
1471             result = runner.invoke(
1472                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1473             )
1474             self.assertEqual(result.exit_code, 0)
1475             output = runner.stdout_bytes
1476             self.assertIn(nl.encode("utf8"), output)
1477             if nl == "\n":
1478                 self.assertNotIn(b"\r\n", output)
1479
1480     def test_assert_equivalent_different_asts(self) -> None:
1481         with self.assertRaises(AssertionError):
1482             black.assert_equivalent("{}", "None")
1483
1484     def test_symlink_out_of_root_directory(self) -> None:
1485         path = MagicMock()
1486         root = THIS_DIR
1487         child = MagicMock()
1488         include = re.compile(black.DEFAULT_INCLUDES)
1489         exclude = re.compile(black.DEFAULT_EXCLUDES)
1490         report = black.Report()
1491         # `child` should behave like a symlink which resolved path is clearly
1492         # outside of the `root` directory.
1493         path.iterdir.return_value = [child]
1494         child.resolve.return_value = Path("/a/b/c")
1495         child.is_symlink.return_value = True
1496         try:
1497             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1498         except ValueError as ve:
1499             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1500         path.iterdir.assert_called_once()
1501         child.resolve.assert_called_once()
1502         child.is_symlink.assert_called_once()
1503         # `child` should behave like a strange file which resolved path is clearly
1504         # outside of the `root` directory.
1505         child.is_symlink.return_value = False
1506         with self.assertRaises(ValueError):
1507             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1508         path.iterdir.assert_called()
1509         self.assertEqual(path.iterdir.call_count, 2)
1510         child.resolve.assert_called()
1511         self.assertEqual(child.resolve.call_count, 2)
1512         child.is_symlink.assert_called()
1513         self.assertEqual(child.is_symlink.call_count, 2)
1514
1515     def test_shhh_click(self) -> None:
1516         try:
1517             from click import _unicodefun  # type: ignore
1518         except ModuleNotFoundError:
1519             self.skipTest("Incompatible Click version")
1520         if not hasattr(_unicodefun, "_verify_python3_env"):
1521             self.skipTest("Incompatible Click version")
1522         # First, let's see if Click is crashing with a preferred ASCII charset.
1523         with patch("locale.getpreferredencoding") as gpe:
1524             gpe.return_value = "ASCII"
1525             with self.assertRaises(RuntimeError):
1526                 _unicodefun._verify_python3_env()
1527         # Now, let's silence Click...
1528         black.patch_click()
1529         # ...and confirm it's silent.
1530         with patch("locale.getpreferredencoding") as gpe:
1531             gpe.return_value = "ASCII"
1532             try:
1533                 _unicodefun._verify_python3_env()
1534             except RuntimeError as re:
1535                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1536
1537     def test_root_logger_not_used_directly(self) -> None:
1538         def fail(*args: Any, **kwargs: Any) -> None:
1539             self.fail("Record created with root logger")
1540
1541         with patch.multiple(
1542             logging.root,
1543             debug=fail,
1544             info=fail,
1545             warning=fail,
1546             error=fail,
1547             critical=fail,
1548             log=fail,
1549         ):
1550             ff(THIS_FILE)
1551
1552     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1553     def test_blackd_main(self) -> None:
1554         with patch("blackd.web.run_app"):
1555             result = CliRunner().invoke(blackd.main, [])
1556             if result.exception is not None:
1557                 raise result.exception
1558             self.assertEqual(result.exit_code, 0)
1559
1560
# The base class, the `web.Application` annotation and the `unittest_run_loop`
# decorator all come from the optional blackd/aiohttp imports guarded at the
# top of this file. Define the test case only when those imports succeeded, so
# that a missing optional dependency does not raise NameError while this module
# is being imported.
if has_blackd_deps:

    class BlackDTestCase(AioHTTPTestCase):
        async def get_application(self) -> web.Application:
            # Application under test, as built by blackd itself.
            return blackd.make_app()

        # TODO: remove these decorators once the below is released
        # https://github.com/aio-libs/aiohttp/pull/3727
        @skip_if_exception("ClientOSError")
        @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
        @unittest_run_loop
        async def test_blackd_request_needs_formatting(self) -> None:
            # Unformatted source: expect 200 with the reformatted body.
            response = await self.client.post("/", data=b"print('hello world')")
            self.assertEqual(response.status, 200)
            self.assertEqual(response.charset, "utf8")
            self.assertEqual(await response.read(), b'print("hello world")\n')

        @skip_if_exception("ClientOSError")
        @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
        @unittest_run_loop
        async def test_blackd_request_no_change(self) -> None:
            # Already formatted source: expect 204 with an empty body.
            response = await self.client.post("/", data=b'print("hello world")\n')
            self.assertEqual(response.status, 204)
            self.assertEqual(await response.read(), b"")

        @skip_if_exception("ClientOSError")
        @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
        @unittest_run_loop
        async def test_blackd_request_syntax_error(self) -> None:
            # Unparseable source: expect 400 and a "Cannot parse" message.
            response = await self.client.post("/", data=b"what even ( is")
            self.assertEqual(response.status, 400)
            content = await response.text()
            failure = f"Expected error to start with 'Cannot parse', got {repr(content)}"
            self.assertTrue(content.startswith("Cannot parse"), msg=failure)

        @skip_if_exception("ClientOSError")
        @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
        @unittest_run_loop
        async def test_blackd_unsupported_version(self) -> None:
            # Protocol version "2" is expected to be rejected with 501.
            response = await self.client.post(
                "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "2"}
            )
            self.assertEqual(response.status, 501)

        @skip_if_exception("ClientOSError")
        @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
        @unittest_run_loop
        async def test_blackd_supported_version(self) -> None:
            # Protocol version "1" is expected to be accepted.
            response = await self.client.post(
                "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "1"}
            )
            self.assertEqual(response.status, 200)

        @skip_if_exception("ClientOSError")
        @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
        @unittest_run_loop
        async def test_blackd_invalid_python_variant(self) -> None:
            # Each of the header values below is expected to be answered
            # with 400.
            async def check(header_value: str, expected_status: int = 400) -> None:
                response = await self.client.post(
                    "/",
                    data=b"what",
                    headers={blackd.PYTHON_VARIANT_HEADER: header_value},
                )
                self.assertEqual(response.status, expected_status)

            await check("lol")
            await check("ruby3.5")
            await check("pyi3.6")
            await check("py1.5")
            await check("2.8")
            await check("py2.8")
            await check("3.0")
            await check("pypy3.0")
            await check("jython3.4")

        @skip_if_exception("ClientOSError")
        @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
        @unittest_run_loop
        async def test_blackd_pyi(self) -> None:
            # The "pyi" variant should format the stub fixture as expected.
            source, expected = read_data("stub.pyi")
            response = await self.client.post(
                "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
            )
            self.assertEqual(response.status, 200)
            self.assertEqual(await response.text(), expected)

        @skip_if_exception("ClientOSError")
        @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
        @unittest_run_loop
        async def test_blackd_python_variant(self) -> None:
            # The same source is posted under several target-version headers;
            # the expected status is 200 for some and 204 for the others.
            code = (
                "def f(\n"
                "    and_has_a_bunch_of,\n"
                "    very_long_arguments_too,\n"
                "    and_lots_of_them_as_well_lol,\n"
                "    **and_very_long_keyword_arguments\n"
                "):\n"
                "    pass\n"
            )

            async def check(header_value: str, expected_status: int) -> None:
                response = await self.client.post(
                    "/",
                    data=code,
                    headers={blackd.PYTHON_VARIANT_HEADER: header_value},
                )
                # Include the response body in the failure output.
                self.assertEqual(
                    response.status, expected_status, msg=await response.text()
                )

            await check("3.6", 200)
            await check("py3.6", 200)
            await check("3.6,3.7", 200)
            await check("3.6,py3.7", 200)
            await check("py36,py37", 200)
            await check("36", 200)
            await check("3.6.4", 200)

            await check("2", 204)
            await check("2.7", 204)
            await check("py2.7", 204)
            await check("3.4", 204)
            await check("py3.4", 204)
            await check("py34,py36", 204)
            await check("34", 204)

        @skip_if_exception("ClientOSError")
        @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
        @unittest_run_loop
        async def test_blackd_line_length(self) -> None:
            # A numeric line-length header is accepted and the source is
            # reformatted (200).
            response = await self.client.post(
                "/",
                data=b'print("hello")\n',
                headers={blackd.LINE_LENGTH_HEADER: "7"},
            )
            self.assertEqual(response.status, 200)

        @skip_if_exception("ClientOSError")
        @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
        @unittest_run_loop
        async def test_blackd_invalid_line_length(self) -> None:
            # A non-numeric line-length header is expected to be rejected
            # with 400.
            response = await self.client.post(
                "/",
                data=b'print("hello")\n',
                headers={blackd.LINE_LENGTH_HEADER: "NaN"},
            )
            self.assertEqual(response.status, 400)

        @skip_if_exception("ClientOSError")
        @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
        @unittest_run_loop
        async def test_blackd_response_black_version_header(self) -> None:
            # Even a request without a body must carry the version header.
            response = await self.client.post("/")
            self.assertIsNotNone(response.headers.get(blackd.BLACK_VERSION_HEADER))
1708
if __name__ == "__main__":
    # Run the tests of the `test_black` module when this file is executed
    # directly as a script.
    unittest.main(module="test_black")