git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

[blib2to3] Fixed a typo and removed an unused import. (#848)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from functools import partial, wraps
from io import BytesIO, TextIOWrapper
import os
from pathlib import Path
import re
import sys
from tempfile import TemporaryDirectory
from typing import (
    Any,
    BinaryIO,
    Callable,
    Coroutine,
    Generator,
    List,
    Optional,
    Tuple,
    Iterator,
    TypeVar,
)
import unittest
from unittest.mock import patch, MagicMock

from click import unstyle
from click.testing import CliRunner

import black
from black import Feature, TargetVersion
32
33 try:
34     import blackd
35     from aiohttp.test_utils import TestClient, TestServer
36 except ImportError:
37     has_blackd_deps = False
38 else:
39     has_blackd_deps = True
40
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# Shorthands used by the tests: format a file in place (with fast=True) or a
# string, both using a default black.FileMode().
ff = partial(black.format_file_in_place, mode=black.FileMode(), fast=True)
fs = partial(black.format_str, mode=black.FileMode())
# Text that read_data() deletes from every line it loads.  It is built from two
# literals: if the full text appeared verbatim here, read_data() would also
# delete it from this very file when test_self loads it.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
# One --target-version flag per entry of black.PY36_VERSIONS.
PY36_ARGS = ["--target-version=" + v.name.lower() for v in black.PY36_VERSIONS]
T = TypeVar("T")
R = TypeVar("R")
51
52
def dump_to_stderr(*output: str) -> str:
    """Join *output* with newlines and pad the result with one newline each side.

    Tests patch this in for ``black.dump_to_file`` so that the dumped text itself
    is returned rather than the path of a dump file.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
55
56
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    Load a test case and split it into source and expected output.

    ``.py`` is appended to *name* unless it already ends in ``.py``, ``.pyi``,
    ``.out`` or ``.diff``.  The file is read from the ``data`` directory next to
    this module when *data* is true, otherwise from this module's directory.
    EMPTY_LINE is removed from every line.  Lines equal to ``# output`` (ignoring
    trailing whitespace) are dropped; lines before the first such marker form
    the input, all later lines the expected output.  When the input is non-empty
    but no output lines were collected, the input doubles as the output.  Both
    halves are stripped and terminated with a single newline.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    base_dir = THIS_DIR / "data" if data else THIS_DIR
    with open(base_dir / name, "r", encoding="utf8") as test:
        lines = test.readlines()
    sections: Tuple[List[str], List[str]] = ([], [])
    current = 0
    for raw_line in lines:
        line = raw_line.replace(EMPTY_LINE, "")
        if line.rstrip() == "# output":
            current = 1
        else:
            sections[current].append(line)
    source_lines, output_lines = sections
    if source_lines and not output_lines:
        # If there's no output marker, treat the entire file as already pre-formatted.
        output_lines = source_lines
    return "".join(source_lines).strip() + "\n", "".join(output_lines).strip() + "\n"
78
79
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch ``black.CACHE_DIR`` to a throwaway location and yield that path.

    The location lives inside a ``TemporaryDirectory`` that is removed on exit.
    When *exists* is false it is the ``new`` subdirectory, which this function
    does not create.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
88
89
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Run the managed block with a brand-new event loop installed as current.

    The loop that was current on entry is reinstated on exit; the new loop is
    closed afterwards only when *close* is true.
    """
    policy = asyncio.get_event_loop_policy()
    previous = policy.get_event_loop()
    fresh = policy.new_event_loop()
    policy.set_event_loop(fresh)
    try:
        yield
    finally:
        policy.set_event_loop(previous)
        if close:
            fresh.close()
103
104
def async_test(f: Callable[..., Coroutine[Any, None, R]]) -> Callable[..., None]:
    """Adapt the coroutine function *f* into a synchronous test callable.

    Every call runs ``f(*args, **kwargs)`` to completion inside
    ``event_loop(close=True)``, i.e. on a new event loop that is closed
    afterwards.
    """

    def run_to_completion(*args: Any, **kwargs: Any) -> None:
        loop = asyncio.get_event_loop()
        loop.run_until_complete(f(*args, **kwargs))

    return event_loop(close=True)(wraps(f)(run_to_completion))
112
113
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    This is a hack that can be removed once we depend on Click 7.x"""

    def __init__(self) -> None:
        # Receives everything written to sys.stderr inside ``isolation``.  It is
        # never reset, so output accumulates across invocations of one runner.
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        # Snapshots taken when the most recent ``isolation`` block exits.
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Wrap ``CliRunner.isolation`` with a separate stderr capture.

        While the block runs, ``sys.stderr`` writes into ``stderrbuf``.  On exit
        the previous ``sys.stderr`` is restored and the captured bytes are stored
        in ``stdout_bytes`` and ``stderr_bytes``.
        """
        with super().isolation(*args, **kwargs) as output:
            hold_stderr = sys.stderr
            stderr_wrapper = TextIOWrapper(self.stderrbuf, encoding=self.charset)
            sys.stderr = stderr_wrapper
            try:
                yield output
            finally:
                sys.stderr = hold_stderr
                try:
                    # Text written without an explicit flush may still sit in
                    # the wrapper's buffer; push it into stderrbuf first.
                    stderr_wrapper.flush()
                    self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                    self.stderr_bytes = self.stderrbuf.getvalue()
                finally:
                    # Detach instead of just dropping the wrapper: a collected
                    # TextIOWrapper closes its buffer, which would leave
                    # stderrbuf unusable for the next invocation.
                    stderr_wrapper.detach()
138
139 class BlackTestCase(unittest.TestCase):
140     maxDiff = None
141
142     def assertFormatEqual(self, expected: str, actual: str) -> None:
143         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
144             bdv: black.DebugVisitor[Any]
145             black.out("Expected tree:", fg="green")
146             try:
147                 exp_node = black.lib2to3_parse(expected)
148                 bdv = black.DebugVisitor()
149                 list(bdv.visit(exp_node))
150             except Exception as ve:
151                 black.err(str(ve))
152             black.out("Actual tree:", fg="red")
153             try:
154                 exp_node = black.lib2to3_parse(actual)
155                 bdv = black.DebugVisitor()
156                 list(bdv.visit(exp_node))
157             except Exception as ve:
158                 black.err(str(ve))
159         self.assertEqual(expected, actual)
160
161     def invokeBlack(
162         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
163     ) -> None:
164         runner = BlackRunner()
165         if ignore_config:
166             args = ["--config", str(THIS_DIR / "empty.toml"), *args]
167         result = runner.invoke(black.main, args)
168         self.assertEqual(result.exit_code, exit_code, msg=runner.stderr_bytes.decode())
169
170     @patch("black.dump_to_file", dump_to_stderr)
171     def test_empty(self) -> None:
172         source = expected = ""
173         actual = fs(source)
174         self.assertFormatEqual(expected, actual)
175         black.assert_equivalent(source, actual)
176         black.assert_stable(source, actual, black.FileMode())
177
178     def test_empty_ff(self) -> None:
179         expected = ""
180         tmp_file = Path(black.dump_to_file())
181         try:
182             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
183             with open(tmp_file, encoding="utf8") as f:
184                 actual = f.read()
185         finally:
186             os.unlink(tmp_file)
187         self.assertFormatEqual(expected, actual)
188
189     @patch("black.dump_to_file", dump_to_stderr)
190     def test_self(self) -> None:
191         source, expected = read_data("test_black", data=False)
192         actual = fs(source)
193         self.assertFormatEqual(expected, actual)
194         black.assert_equivalent(source, actual)
195         black.assert_stable(source, actual, black.FileMode())
196         self.assertFalse(ff(THIS_FILE))
197
198     @patch("black.dump_to_file", dump_to_stderr)
199     def test_black(self) -> None:
200         source, expected = read_data("../black", data=False)
201         actual = fs(source)
202         self.assertFormatEqual(expected, actual)
203         black.assert_equivalent(source, actual)
204         black.assert_stable(source, actual, black.FileMode())
205         self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
206
207     def test_piping(self) -> None:
208         source, expected = read_data("../black", data=False)
209         result = BlackRunner().invoke(
210             black.main,
211             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
212             input=BytesIO(source.encode("utf8")),
213         )
214         self.assertEqual(result.exit_code, 0)
215         self.assertFormatEqual(expected, result.output)
216         black.assert_equivalent(source, result.output)
217         black.assert_stable(source, result.output, black.FileMode())
218
219     def test_piping_diff(self) -> None:
220         diff_header = re.compile(
221             rf"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d "
222             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
223         )
224         source, _ = read_data("expression.py")
225         expected, _ = read_data("expression.diff")
226         config = THIS_DIR / "data" / "empty_pyproject.toml"
227         args = [
228             "-",
229             "--fast",
230             f"--line-length={black.DEFAULT_LINE_LENGTH}",
231             "--diff",
232             f"--config={config}",
233         ]
234         result = BlackRunner().invoke(
235             black.main, args, input=BytesIO(source.encode("utf8"))
236         )
237         self.assertEqual(result.exit_code, 0)
238         actual = diff_header.sub("[Deterministic header]", result.output)
239         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
240         self.assertEqual(expected, actual)
241
242     @patch("black.dump_to_file", dump_to_stderr)
243     def test_setup(self) -> None:
244         source, expected = read_data("../setup", data=False)
245         actual = fs(source)
246         self.assertFormatEqual(expected, actual)
247         black.assert_equivalent(source, actual)
248         black.assert_stable(source, actual, black.FileMode())
249         self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
250
251     @patch("black.dump_to_file", dump_to_stderr)
252     def test_function(self) -> None:
253         source, expected = read_data("function")
254         actual = fs(source)
255         self.assertFormatEqual(expected, actual)
256         black.assert_equivalent(source, actual)
257         black.assert_stable(source, actual, black.FileMode())
258
259     @patch("black.dump_to_file", dump_to_stderr)
260     def test_function2(self) -> None:
261         source, expected = read_data("function2")
262         actual = fs(source)
263         self.assertFormatEqual(expected, actual)
264         black.assert_equivalent(source, actual)
265         black.assert_stable(source, actual, black.FileMode())
266
267     @patch("black.dump_to_file", dump_to_stderr)
268     def test_expression(self) -> None:
269         source, expected = read_data("expression")
270         actual = fs(source)
271         self.assertFormatEqual(expected, actual)
272         black.assert_equivalent(source, actual)
273         black.assert_stable(source, actual, black.FileMode())
274
275     def test_expression_ff(self) -> None:
276         source, expected = read_data("expression")
277         tmp_file = Path(black.dump_to_file(source))
278         try:
279             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
280             with open(tmp_file, encoding="utf8") as f:
281                 actual = f.read()
282         finally:
283             os.unlink(tmp_file)
284         self.assertFormatEqual(expected, actual)
285         with patch("black.dump_to_file", dump_to_stderr):
286             black.assert_equivalent(source, actual)
287             black.assert_stable(source, actual, black.FileMode())
288
289     def test_expression_diff(self) -> None:
290         source, _ = read_data("expression.py")
291         expected, _ = read_data("expression.diff")
292         tmp_file = Path(black.dump_to_file(source))
293         diff_header = re.compile(
294             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
295             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
296         )
297         try:
298             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
299             self.assertEqual(result.exit_code, 0)
300         finally:
301             os.unlink(tmp_file)
302         actual = result.output
303         actual = diff_header.sub("[Deterministic header]", actual)
304         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
305         if expected != actual:
306             dump = black.dump_to_file(actual)
307             msg = (
308                 f"Expected diff isn't equal to the actual. If you made changes "
309                 f"to expression.py and this is an anticipated difference, "
310                 f"overwrite tests/data/expression.diff with {dump}"
311             )
312             self.assertEqual(expected, actual, msg)
313
314     @patch("black.dump_to_file", dump_to_stderr)
315     def test_fstring(self) -> None:
316         source, expected = read_data("fstring")
317         actual = fs(source)
318         self.assertFormatEqual(expected, actual)
319         black.assert_equivalent(source, actual)
320         black.assert_stable(source, actual, black.FileMode())
321
322     @patch("black.dump_to_file", dump_to_stderr)
323     def test_string_quotes(self) -> None:
324         source, expected = read_data("string_quotes")
325         actual = fs(source)
326         self.assertFormatEqual(expected, actual)
327         black.assert_equivalent(source, actual)
328         black.assert_stable(source, actual, black.FileMode())
329         mode = black.FileMode(string_normalization=False)
330         not_normalized = fs(source, mode=mode)
331         self.assertFormatEqual(source, not_normalized)
332         black.assert_equivalent(source, not_normalized)
333         black.assert_stable(source, not_normalized, mode=mode)
334
335     @patch("black.dump_to_file", dump_to_stderr)
336     def test_slices(self) -> None:
337         source, expected = read_data("slices")
338         actual = fs(source)
339         self.assertFormatEqual(expected, actual)
340         black.assert_equivalent(source, actual)
341         black.assert_stable(source, actual, black.FileMode())
342
343     @patch("black.dump_to_file", dump_to_stderr)
344     def test_comments(self) -> None:
345         source, expected = read_data("comments")
346         actual = fs(source)
347         self.assertFormatEqual(expected, actual)
348         black.assert_equivalent(source, actual)
349         black.assert_stable(source, actual, black.FileMode())
350
351     @patch("black.dump_to_file", dump_to_stderr)
352     def test_comments2(self) -> None:
353         source, expected = read_data("comments2")
354         actual = fs(source)
355         self.assertFormatEqual(expected, actual)
356         black.assert_equivalent(source, actual)
357         black.assert_stable(source, actual, black.FileMode())
358
359     @patch("black.dump_to_file", dump_to_stderr)
360     def test_comments3(self) -> None:
361         source, expected = read_data("comments3")
362         actual = fs(source)
363         self.assertFormatEqual(expected, actual)
364         black.assert_equivalent(source, actual)
365         black.assert_stable(source, actual, black.FileMode())
366
367     @patch("black.dump_to_file", dump_to_stderr)
368     def test_comments4(self) -> None:
369         source, expected = read_data("comments4")
370         actual = fs(source)
371         self.assertFormatEqual(expected, actual)
372         black.assert_equivalent(source, actual)
373         black.assert_stable(source, actual, black.FileMode())
374
375     @patch("black.dump_to_file", dump_to_stderr)
376     def test_comments5(self) -> None:
377         source, expected = read_data("comments5")
378         actual = fs(source)
379         self.assertFormatEqual(expected, actual)
380         black.assert_equivalent(source, actual)
381         black.assert_stable(source, actual, black.FileMode())
382
383     @patch("black.dump_to_file", dump_to_stderr)
384     def test_comments6(self) -> None:
385         source, expected = read_data("comments6")
386         actual = fs(source)
387         self.assertFormatEqual(expected, actual)
388         black.assert_equivalent(source, actual)
389         black.assert_stable(source, actual, black.FileMode())
390
391     @patch("black.dump_to_file", dump_to_stderr)
392     def test_comments7(self) -> None:
393         source, expected = read_data("comments7")
394         actual = fs(source)
395         self.assertFormatEqual(expected, actual)
396         black.assert_equivalent(source, actual)
397         black.assert_stable(source, actual, black.FileMode())
398
399     @patch("black.dump_to_file", dump_to_stderr)
400     def test_comment_after_escaped_newline(self) -> None:
401         source, expected = read_data("comment_after_escaped_newline")
402         actual = fs(source)
403         self.assertFormatEqual(expected, actual)
404         black.assert_equivalent(source, actual)
405         black.assert_stable(source, actual, black.FileMode())
406
407     @patch("black.dump_to_file", dump_to_stderr)
408     def test_cantfit(self) -> None:
409         source, expected = read_data("cantfit")
410         actual = fs(source)
411         self.assertFormatEqual(expected, actual)
412         black.assert_equivalent(source, actual)
413         black.assert_stable(source, actual, black.FileMode())
414
415     @patch("black.dump_to_file", dump_to_stderr)
416     def test_import_spacing(self) -> None:
417         source, expected = read_data("import_spacing")
418         actual = fs(source)
419         self.assertFormatEqual(expected, actual)
420         black.assert_equivalent(source, actual)
421         black.assert_stable(source, actual, black.FileMode())
422
423     @patch("black.dump_to_file", dump_to_stderr)
424     def test_composition(self) -> None:
425         source, expected = read_data("composition")
426         actual = fs(source)
427         self.assertFormatEqual(expected, actual)
428         black.assert_equivalent(source, actual)
429         black.assert_stable(source, actual, black.FileMode())
430
431     @patch("black.dump_to_file", dump_to_stderr)
432     def test_empty_lines(self) -> None:
433         source, expected = read_data("empty_lines")
434         actual = fs(source)
435         self.assertFormatEqual(expected, actual)
436         black.assert_equivalent(source, actual)
437         black.assert_stable(source, actual, black.FileMode())
438
439     @patch("black.dump_to_file", dump_to_stderr)
440     def test_remove_parens(self) -> None:
441         source, expected = read_data("remove_parens")
442         actual = fs(source)
443         self.assertFormatEqual(expected, actual)
444         black.assert_equivalent(source, actual)
445         black.assert_stable(source, actual, black.FileMode())
446
447     @patch("black.dump_to_file", dump_to_stderr)
448     def test_string_prefixes(self) -> None:
449         source, expected = read_data("string_prefixes")
450         actual = fs(source)
451         self.assertFormatEqual(expected, actual)
452         black.assert_equivalent(source, actual)
453         black.assert_stable(source, actual, black.FileMode())
454
455     @patch("black.dump_to_file", dump_to_stderr)
456     def test_numeric_literals(self) -> None:
457         source, expected = read_data("numeric_literals")
458         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
459         actual = fs(source, mode=mode)
460         self.assertFormatEqual(expected, actual)
461         black.assert_equivalent(source, actual)
462         black.assert_stable(source, actual, mode)
463
464     @patch("black.dump_to_file", dump_to_stderr)
465     def test_numeric_literals_ignoring_underscores(self) -> None:
466         source, expected = read_data("numeric_literals_skip_underscores")
467         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
468         actual = fs(source, mode=mode)
469         self.assertFormatEqual(expected, actual)
470         black.assert_equivalent(source, actual)
471         black.assert_stable(source, actual, mode)
472
473     @patch("black.dump_to_file", dump_to_stderr)
474     def test_numeric_literals_py2(self) -> None:
475         source, expected = read_data("numeric_literals_py2")
476         actual = fs(source)
477         self.assertFormatEqual(expected, actual)
478         black.assert_stable(source, actual, black.FileMode())
479
480     @patch("black.dump_to_file", dump_to_stderr)
481     def test_python2(self) -> None:
482         source, expected = read_data("python2")
483         actual = fs(source)
484         self.assertFormatEqual(expected, actual)
485         black.assert_equivalent(source, actual)
486         black.assert_stable(source, actual, black.FileMode())
487
488     @patch("black.dump_to_file", dump_to_stderr)
489     def test_python2_print_function(self) -> None:
490         source, expected = read_data("python2_print_function")
491         mode = black.FileMode(target_versions={TargetVersion.PY27})
492         actual = fs(source, mode=mode)
493         self.assertFormatEqual(expected, actual)
494         black.assert_equivalent(source, actual)
495         black.assert_stable(source, actual, mode)
496
497     @patch("black.dump_to_file", dump_to_stderr)
498     def test_python2_unicode_literals(self) -> None:
499         source, expected = read_data("python2_unicode_literals")
500         actual = fs(source)
501         self.assertFormatEqual(expected, actual)
502         black.assert_equivalent(source, actual)
503         black.assert_stable(source, actual, black.FileMode())
504
505     @patch("black.dump_to_file", dump_to_stderr)
506     def test_stub(self) -> None:
507         mode = black.FileMode(is_pyi=True)
508         source, expected = read_data("stub.pyi")
509         actual = fs(source, mode=mode)
510         self.assertFormatEqual(expected, actual)
511         black.assert_stable(source, actual, mode)
512
513     @patch("black.dump_to_file", dump_to_stderr)
514     def test_async_as_identifier(self) -> None:
515         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
516         source, expected = read_data("async_as_identifier")
517         actual = fs(source)
518         self.assertFormatEqual(expected, actual)
519         major, minor = sys.version_info[:2]
520         if major < 3 or (major <= 3 and minor < 7):
521             black.assert_equivalent(source, actual)
522         black.assert_stable(source, actual, black.FileMode())
523         # ensure black can parse this when the target is 3.6
524         self.invokeBlack([str(source_path), "--target-version", "py36"])
525         # but not on 3.7, because async/await is no longer an identifier
526         self.invokeBlack([str(source_path), "--target-version", "py37"], exit_code=123)
527
528     @patch("black.dump_to_file", dump_to_stderr)
529     def test_python37(self) -> None:
530         source_path = (THIS_DIR / "data" / "python37.py").resolve()
531         source, expected = read_data("python37")
532         actual = fs(source)
533         self.assertFormatEqual(expected, actual)
534         major, minor = sys.version_info[:2]
535         if major > 3 or (major == 3 and minor >= 7):
536             black.assert_equivalent(source, actual)
537         black.assert_stable(source, actual, black.FileMode())
538         # ensure black can parse this when the target is 3.7
539         self.invokeBlack([str(source_path), "--target-version", "py37"])
540         # but not on 3.6, because we use async as a reserved keyword
541         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
542
543     @patch("black.dump_to_file", dump_to_stderr)
544     def test_fmtonoff(self) -> None:
545         source, expected = read_data("fmtonoff")
546         actual = fs(source)
547         self.assertFormatEqual(expected, actual)
548         black.assert_equivalent(source, actual)
549         black.assert_stable(source, actual, black.FileMode())
550
551     @patch("black.dump_to_file", dump_to_stderr)
552     def test_fmtonoff2(self) -> None:
553         source, expected = read_data("fmtonoff2")
554         actual = fs(source)
555         self.assertFormatEqual(expected, actual)
556         black.assert_equivalent(source, actual)
557         black.assert_stable(source, actual, black.FileMode())
558
559     @patch("black.dump_to_file", dump_to_stderr)
560     def test_remove_empty_parentheses_after_class(self) -> None:
561         source, expected = read_data("class_blank_parentheses")
562         actual = fs(source)
563         self.assertFormatEqual(expected, actual)
564         black.assert_equivalent(source, actual)
565         black.assert_stable(source, actual, black.FileMode())
566
567     @patch("black.dump_to_file", dump_to_stderr)
568     def test_new_line_between_class_and_code(self) -> None:
569         source, expected = read_data("class_methods_new_line")
570         actual = fs(source)
571         self.assertFormatEqual(expected, actual)
572         black.assert_equivalent(source, actual)
573         black.assert_stable(source, actual, black.FileMode())
574
575     @patch("black.dump_to_file", dump_to_stderr)
576     def test_bracket_match(self) -> None:
577         source, expected = read_data("bracketmatch")
578         actual = fs(source)
579         self.assertFormatEqual(expected, actual)
580         black.assert_equivalent(source, actual)
581         black.assert_stable(source, actual, black.FileMode())
582
583     @patch("black.dump_to_file", dump_to_stderr)
584     def test_tuple_assign(self) -> None:
585         source, expected = read_data("tupleassign")
586         actual = fs(source)
587         self.assertFormatEqual(expected, actual)
588         black.assert_equivalent(source, actual)
589         black.assert_stable(source, actual, black.FileMode())
590
591     def test_tab_comment_indentation(self) -> None:
592         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
593         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
594         self.assertFormatEqual(contents_spc, fs(contents_spc))
595         self.assertFormatEqual(contents_spc, fs(contents_tab))
596
597         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
598         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
599         self.assertFormatEqual(contents_spc, fs(contents_spc))
600         self.assertFormatEqual(contents_spc, fs(contents_tab))
601
602         # mixed tabs and spaces (valid Python 2 code)
603         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
604         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
605         self.assertFormatEqual(contents_spc, fs(contents_spc))
606         self.assertFormatEqual(contents_spc, fs(contents_tab))
607
608         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
609         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
610         self.assertFormatEqual(contents_spc, fs(contents_spc))
611         self.assertFormatEqual(contents_spc, fs(contents_tab))
612
613     def test_report_verbose(self) -> None:
614         report = black.Report(verbose=True)
615         out_lines = []
616         err_lines = []
617
618         def out(msg: str, **kwargs: Any) -> None:
619             out_lines.append(msg)
620
621         def err(msg: str, **kwargs: Any) -> None:
622             err_lines.append(msg)
623
624         with patch("black.out", out), patch("black.err", err):
625             report.done(Path("f1"), black.Changed.NO)
626             self.assertEqual(len(out_lines), 1)
627             self.assertEqual(len(err_lines), 0)
628             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
629             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
630             self.assertEqual(report.return_code, 0)
631             report.done(Path("f2"), black.Changed.YES)
632             self.assertEqual(len(out_lines), 2)
633             self.assertEqual(len(err_lines), 0)
634             self.assertEqual(out_lines[-1], "reformatted f2")
635             self.assertEqual(
636                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
637             )
638             report.done(Path("f3"), black.Changed.CACHED)
639             self.assertEqual(len(out_lines), 3)
640             self.assertEqual(len(err_lines), 0)
641             self.assertEqual(
642                 out_lines[-1], "f3 wasn't modified on disk since last run."
643             )
644             self.assertEqual(
645                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
646             )
647             self.assertEqual(report.return_code, 0)
648             report.check = True
649             self.assertEqual(report.return_code, 1)
650             report.check = False
651             report.failed(Path("e1"), "boom")
652             self.assertEqual(len(out_lines), 3)
653             self.assertEqual(len(err_lines), 1)
654             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
655             self.assertEqual(
656                 unstyle(str(report)),
657                 "1 file reformatted, 2 files left unchanged, "
658                 "1 file failed to reformat.",
659             )
660             self.assertEqual(report.return_code, 123)
661             report.done(Path("f3"), black.Changed.YES)
662             self.assertEqual(len(out_lines), 4)
663             self.assertEqual(len(err_lines), 1)
664             self.assertEqual(out_lines[-1], "reformatted f3")
665             self.assertEqual(
666                 unstyle(str(report)),
667                 "2 files reformatted, 2 files left unchanged, "
668                 "1 file failed to reformat.",
669             )
670             self.assertEqual(report.return_code, 123)
671             report.failed(Path("e2"), "boom")
672             self.assertEqual(len(out_lines), 4)
673             self.assertEqual(len(err_lines), 2)
674             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
675             self.assertEqual(
676                 unstyle(str(report)),
677                 "2 files reformatted, 2 files left unchanged, "
678                 "2 files failed to reformat.",
679             )
680             self.assertEqual(report.return_code, 123)
681             report.path_ignored(Path("wat"), "no match")
682             self.assertEqual(len(out_lines), 5)
683             self.assertEqual(len(err_lines), 2)
684             self.assertEqual(out_lines[-1], "wat ignored: no match")
685             self.assertEqual(
686                 unstyle(str(report)),
687                 "2 files reformatted, 2 files left unchanged, "
688                 "2 files failed to reformat.",
689             )
690             self.assertEqual(report.return_code, 123)
691             report.done(Path("f4"), black.Changed.NO)
692             self.assertEqual(len(out_lines), 6)
693             self.assertEqual(len(err_lines), 2)
694             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
695             self.assertEqual(
696                 unstyle(str(report)),
697                 "2 files reformatted, 3 files left unchanged, "
698                 "2 files failed to reformat.",
699             )
700             self.assertEqual(report.return_code, 123)
701             report.check = True
702             self.assertEqual(
703                 unstyle(str(report)),
704                 "2 files would be reformatted, 3 files would be left unchanged, "
705                 "2 files would fail to reformat.",
706             )
707
708     def test_report_quiet(self) -> None:
709         report = black.Report(quiet=True)
710         out_lines = []
711         err_lines = []
712
713         def out(msg: str, **kwargs: Any) -> None:
714             out_lines.append(msg)
715
716         def err(msg: str, **kwargs: Any) -> None:
717             err_lines.append(msg)
718
719         with patch("black.out", out), patch("black.err", err):
720             report.done(Path("f1"), black.Changed.NO)
721             self.assertEqual(len(out_lines), 0)
722             self.assertEqual(len(err_lines), 0)
723             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
724             self.assertEqual(report.return_code, 0)
725             report.done(Path("f2"), black.Changed.YES)
726             self.assertEqual(len(out_lines), 0)
727             self.assertEqual(len(err_lines), 0)
728             self.assertEqual(
729                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
730             )
731             report.done(Path("f3"), black.Changed.CACHED)
732             self.assertEqual(len(out_lines), 0)
733             self.assertEqual(len(err_lines), 0)
734             self.assertEqual(
735                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
736             )
737             self.assertEqual(report.return_code, 0)
738             report.check = True
739             self.assertEqual(report.return_code, 1)
740             report.check = False
741             report.failed(Path("e1"), "boom")
742             self.assertEqual(len(out_lines), 0)
743             self.assertEqual(len(err_lines), 1)
744             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
745             self.assertEqual(
746                 unstyle(str(report)),
747                 "1 file reformatted, 2 files left unchanged, "
748                 "1 file failed to reformat.",
749             )
750             self.assertEqual(report.return_code, 123)
751             report.done(Path("f3"), black.Changed.YES)
752             self.assertEqual(len(out_lines), 0)
753             self.assertEqual(len(err_lines), 1)
754             self.assertEqual(
755                 unstyle(str(report)),
756                 "2 files reformatted, 2 files left unchanged, "
757                 "1 file failed to reformat.",
758             )
759             self.assertEqual(report.return_code, 123)
760             report.failed(Path("e2"), "boom")
761             self.assertEqual(len(out_lines), 0)
762             self.assertEqual(len(err_lines), 2)
763             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
764             self.assertEqual(
765                 unstyle(str(report)),
766                 "2 files reformatted, 2 files left unchanged, "
767                 "2 files failed to reformat.",
768             )
769             self.assertEqual(report.return_code, 123)
770             report.path_ignored(Path("wat"), "no match")
771             self.assertEqual(len(out_lines), 0)
772             self.assertEqual(len(err_lines), 2)
773             self.assertEqual(
774                 unstyle(str(report)),
775                 "2 files reformatted, 2 files left unchanged, "
776                 "2 files failed to reformat.",
777             )
778             self.assertEqual(report.return_code, 123)
779             report.done(Path("f4"), black.Changed.NO)
780             self.assertEqual(len(out_lines), 0)
781             self.assertEqual(len(err_lines), 2)
782             self.assertEqual(
783                 unstyle(str(report)),
784                 "2 files reformatted, 3 files left unchanged, "
785                 "2 files failed to reformat.",
786             )
787             self.assertEqual(report.return_code, 123)
788             report.check = True
789             self.assertEqual(
790                 unstyle(str(report)),
791                 "2 files would be reformatted, 3 files would be left unchanged, "
792                 "2 files would fail to reformat.",
793             )
794
795     def test_report_normal(self) -> None:
796         report = black.Report()
797         out_lines = []
798         err_lines = []
799
800         def out(msg: str, **kwargs: Any) -> None:
801             out_lines.append(msg)
802
803         def err(msg: str, **kwargs: Any) -> None:
804             err_lines.append(msg)
805
806         with patch("black.out", out), patch("black.err", err):
807             report.done(Path("f1"), black.Changed.NO)
808             self.assertEqual(len(out_lines), 0)
809             self.assertEqual(len(err_lines), 0)
810             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
811             self.assertEqual(report.return_code, 0)
812             report.done(Path("f2"), black.Changed.YES)
813             self.assertEqual(len(out_lines), 1)
814             self.assertEqual(len(err_lines), 0)
815             self.assertEqual(out_lines[-1], "reformatted f2")
816             self.assertEqual(
817                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
818             )
819             report.done(Path("f3"), black.Changed.CACHED)
820             self.assertEqual(len(out_lines), 1)
821             self.assertEqual(len(err_lines), 0)
822             self.assertEqual(out_lines[-1], "reformatted f2")
823             self.assertEqual(
824                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
825             )
826             self.assertEqual(report.return_code, 0)
827             report.check = True
828             self.assertEqual(report.return_code, 1)
829             report.check = False
830             report.failed(Path("e1"), "boom")
831             self.assertEqual(len(out_lines), 1)
832             self.assertEqual(len(err_lines), 1)
833             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
834             self.assertEqual(
835                 unstyle(str(report)),
836                 "1 file reformatted, 2 files left unchanged, "
837                 "1 file failed to reformat.",
838             )
839             self.assertEqual(report.return_code, 123)
840             report.done(Path("f3"), black.Changed.YES)
841             self.assertEqual(len(out_lines), 2)
842             self.assertEqual(len(err_lines), 1)
843             self.assertEqual(out_lines[-1], "reformatted f3")
844             self.assertEqual(
845                 unstyle(str(report)),
846                 "2 files reformatted, 2 files left unchanged, "
847                 "1 file failed to reformat.",
848             )
849             self.assertEqual(report.return_code, 123)
850             report.failed(Path("e2"), "boom")
851             self.assertEqual(len(out_lines), 2)
852             self.assertEqual(len(err_lines), 2)
853             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
854             self.assertEqual(
855                 unstyle(str(report)),
856                 "2 files reformatted, 2 files left unchanged, "
857                 "2 files failed to reformat.",
858             )
859             self.assertEqual(report.return_code, 123)
860             report.path_ignored(Path("wat"), "no match")
861             self.assertEqual(len(out_lines), 2)
862             self.assertEqual(len(err_lines), 2)
863             self.assertEqual(
864                 unstyle(str(report)),
865                 "2 files reformatted, 2 files left unchanged, "
866                 "2 files failed to reformat.",
867             )
868             self.assertEqual(report.return_code, 123)
869             report.done(Path("f4"), black.Changed.NO)
870             self.assertEqual(len(out_lines), 2)
871             self.assertEqual(len(err_lines), 2)
872             self.assertEqual(
873                 unstyle(str(report)),
874                 "2 files reformatted, 3 files left unchanged, "
875                 "2 files failed to reformat.",
876             )
877             self.assertEqual(report.return_code, 123)
878             report.check = True
879             self.assertEqual(
880                 unstyle(str(report)),
881                 "2 files would be reformatted, 3 files would be left unchanged, "
882                 "2 files would fail to reformat.",
883             )
884
885     def test_lib2to3_parse(self) -> None:
886         with self.assertRaises(black.InvalidInput):
887             black.lib2to3_parse("invalid syntax")
888
889         straddling = "x + y"
890         black.lib2to3_parse(straddling)
891         black.lib2to3_parse(straddling, {TargetVersion.PY27})
892         black.lib2to3_parse(straddling, {TargetVersion.PY36})
893         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
894
895         py2_only = "print x"
896         black.lib2to3_parse(py2_only)
897         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
898         with self.assertRaises(black.InvalidInput):
899             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
900         with self.assertRaises(black.InvalidInput):
901             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
902
903         py3_only = "exec(x, end=y)"
904         black.lib2to3_parse(py3_only)
905         with self.assertRaises(black.InvalidInput):
906             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
907         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
908         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
909
910     def test_get_features_used(self) -> None:
911         node = black.lib2to3_parse("def f(*, arg): ...\n")
912         self.assertEqual(black.get_features_used(node), set())
913         node = black.lib2to3_parse("def f(*, arg,): ...\n")
914         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
915         node = black.lib2to3_parse("f(*arg,)\n")
916         self.assertEqual(
917             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
918         )
919         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
920         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
921         node = black.lib2to3_parse("123_456\n")
922         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
923         node = black.lib2to3_parse("123456\n")
924         self.assertEqual(black.get_features_used(node), set())
925         source, expected = read_data("function")
926         node = black.lib2to3_parse(source)
927         expected_features = {
928             Feature.TRAILING_COMMA_IN_CALL,
929             Feature.TRAILING_COMMA_IN_DEF,
930             Feature.F_STRINGS,
931         }
932         self.assertEqual(black.get_features_used(node), expected_features)
933         node = black.lib2to3_parse(expected)
934         self.assertEqual(black.get_features_used(node), expected_features)
935         source, expected = read_data("expression")
936         node = black.lib2to3_parse(source)
937         self.assertEqual(black.get_features_used(node), set())
938         node = black.lib2to3_parse(expected)
939         self.assertEqual(black.get_features_used(node), set())
940
941     def test_get_future_imports(self) -> None:
942         node = black.lib2to3_parse("\n")
943         self.assertEqual(set(), black.get_future_imports(node))
944         node = black.lib2to3_parse("from __future__ import black\n")
945         self.assertEqual({"black"}, black.get_future_imports(node))
946         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
947         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
948         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
949         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
950         node = black.lib2to3_parse(
951             "from __future__ import multiple\nfrom __future__ import imports\n"
952         )
953         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
954         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
955         self.assertEqual({"black"}, black.get_future_imports(node))
956         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
957         self.assertEqual({"black"}, black.get_future_imports(node))
958         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
959         self.assertEqual(set(), black.get_future_imports(node))
960         node = black.lib2to3_parse("from some.module import black\n")
961         self.assertEqual(set(), black.get_future_imports(node))
962         node = black.lib2to3_parse(
963             "from __future__ import unicode_literals as _unicode_literals"
964         )
965         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
966         node = black.lib2to3_parse(
967             "from __future__ import unicode_literals as _lol, print"
968         )
969         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
970
971     def test_debug_visitor(self) -> None:
972         source, _ = read_data("debug_visitor.py")
973         expected, _ = read_data("debug_visitor.out")
974         out_lines = []
975         err_lines = []
976
977         def out(msg: str, **kwargs: Any) -> None:
978             out_lines.append(msg)
979
980         def err(msg: str, **kwargs: Any) -> None:
981             err_lines.append(msg)
982
983         with patch("black.out", out), patch("black.err", err):
984             black.DebugVisitor.show(source)
985         actual = "\n".join(out_lines) + "\n"
986         log_name = ""
987         if expected != actual:
988             log_name = black.dump_to_file(*out_lines)
989         self.assertEqual(
990             expected,
991             actual,
992             f"AST print out is different. Actual version dumped to {log_name}",
993         )
994
995     def test_format_file_contents(self) -> None:
996         empty = ""
997         mode = black.FileMode()
998         with self.assertRaises(black.NothingChanged):
999             black.format_file_contents(empty, mode=mode, fast=False)
1000         just_nl = "\n"
1001         with self.assertRaises(black.NothingChanged):
1002             black.format_file_contents(just_nl, mode=mode, fast=False)
1003         same = "l = [1, 2, 3]\n"
1004         with self.assertRaises(black.NothingChanged):
1005             black.format_file_contents(same, mode=mode, fast=False)
1006         different = "l = [1,2,3]"
1007         expected = same
1008         actual = black.format_file_contents(different, mode=mode, fast=False)
1009         self.assertEqual(expected, actual)
1010         invalid = "return if you can"
1011         with self.assertRaises(black.InvalidInput) as e:
1012             black.format_file_contents(invalid, mode=mode, fast=False)
1013         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
1014
1015     def test_endmarker(self) -> None:
1016         n = black.lib2to3_parse("\n")
1017         self.assertEqual(n.type, black.syms.file_input)
1018         self.assertEqual(len(n.children), 1)
1019         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
1020
1021     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
1022     def test_assertFormatEqual(self) -> None:
1023         out_lines = []
1024         err_lines = []
1025
1026         def out(msg: str, **kwargs: Any) -> None:
1027             out_lines.append(msg)
1028
1029         def err(msg: str, **kwargs: Any) -> None:
1030             err_lines.append(msg)
1031
1032         with patch("black.out", out), patch("black.err", err):
1033             with self.assertRaises(AssertionError):
1034                 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
1035
1036         out_str = "".join(out_lines)
1037         self.assertTrue("Expected tree:" in out_str)
1038         self.assertTrue("Actual tree:" in out_str)
1039         self.assertEqual("".join(err_lines), "")
1040
1041     def test_cache_broken_file(self) -> None:
1042         mode = black.FileMode()
1043         with cache_dir() as workspace:
1044             cache_file = black.get_cache_file(mode)
1045             with cache_file.open("w") as fobj:
1046                 fobj.write("this is not a pickle")
1047             self.assertEqual(black.read_cache(mode), {})
1048             src = (workspace / "test.py").resolve()
1049             with src.open("w") as fobj:
1050                 fobj.write("print('hello')")
1051             self.invokeBlack([str(src)])
1052             cache = black.read_cache(mode)
1053             self.assertIn(src, cache)
1054
1055     def test_cache_single_file_already_cached(self) -> None:
1056         mode = black.FileMode()
1057         with cache_dir() as workspace:
1058             src = (workspace / "test.py").resolve()
1059             with src.open("w") as fobj:
1060                 fobj.write("print('hello')")
1061             black.write_cache({}, [src], mode)
1062             self.invokeBlack([str(src)])
1063             with src.open("r") as fobj:
1064                 self.assertEqual(fobj.read(), "print('hello')")
1065
1066     @event_loop(close=False)
1067     def test_cache_multiple_files(self) -> None:
1068         mode = black.FileMode()
1069         with cache_dir() as workspace, patch(
1070             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1071         ):
1072             one = (workspace / "one.py").resolve()
1073             with one.open("w") as fobj:
1074                 fobj.write("print('hello')")
1075             two = (workspace / "two.py").resolve()
1076             with two.open("w") as fobj:
1077                 fobj.write("print('hello')")
1078             black.write_cache({}, [one], mode)
1079             self.invokeBlack([str(workspace)])
1080             with one.open("r") as fobj:
1081                 self.assertEqual(fobj.read(), "print('hello')")
1082             with two.open("r") as fobj:
1083                 self.assertEqual(fobj.read(), 'print("hello")\n')
1084             cache = black.read_cache(mode)
1085             self.assertIn(one, cache)
1086             self.assertIn(two, cache)
1087
1088     def test_no_cache_when_writeback_diff(self) -> None:
1089         mode = black.FileMode()
1090         with cache_dir() as workspace:
1091             src = (workspace / "test.py").resolve()
1092             with src.open("w") as fobj:
1093                 fobj.write("print('hello')")
1094             self.invokeBlack([str(src), "--diff"])
1095             cache_file = black.get_cache_file(mode)
1096             self.assertFalse(cache_file.exists())
1097
1098     def test_no_cache_when_stdin(self) -> None:
1099         mode = black.FileMode()
1100         with cache_dir():
1101             result = CliRunner().invoke(
1102                 black.main, ["-"], input=BytesIO(b"print('hello')")
1103             )
1104             self.assertEqual(result.exit_code, 0)
1105             cache_file = black.get_cache_file(mode)
1106             self.assertFalse(cache_file.exists())
1107
1108     def test_read_cache_no_cachefile(self) -> None:
1109         mode = black.FileMode()
1110         with cache_dir():
1111             self.assertEqual(black.read_cache(mode), {})
1112
1113     def test_write_cache_read_cache(self) -> None:
1114         mode = black.FileMode()
1115         with cache_dir() as workspace:
1116             src = (workspace / "test.py").resolve()
1117             src.touch()
1118             black.write_cache({}, [src], mode)
1119             cache = black.read_cache(mode)
1120             self.assertIn(src, cache)
1121             self.assertEqual(cache[src], black.get_cache_info(src))
1122
1123     def test_filter_cached(self) -> None:
1124         with TemporaryDirectory() as workspace:
1125             path = Path(workspace)
1126             uncached = (path / "uncached").resolve()
1127             cached = (path / "cached").resolve()
1128             cached_but_changed = (path / "changed").resolve()
1129             uncached.touch()
1130             cached.touch()
1131             cached_but_changed.touch()
1132             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
1133             todo, done = black.filter_cached(
1134                 cache, {uncached, cached, cached_but_changed}
1135             )
1136             self.assertEqual(todo, {uncached, cached_but_changed})
1137             self.assertEqual(done, {cached})
1138
1139     def test_write_cache_creates_directory_if_needed(self) -> None:
1140         mode = black.FileMode()
1141         with cache_dir(exists=False) as workspace:
1142             self.assertFalse(workspace.exists())
1143             black.write_cache({}, [], mode)
1144             self.assertTrue(workspace.exists())
1145
1146     @event_loop(close=False)
1147     def test_failed_formatting_does_not_get_cached(self) -> None:
1148         mode = black.FileMode()
1149         with cache_dir() as workspace, patch(
1150             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1151         ):
1152             failing = (workspace / "failing.py").resolve()
1153             with failing.open("w") as fobj:
1154                 fobj.write("not actually python")
1155             clean = (workspace / "clean.py").resolve()
1156             with clean.open("w") as fobj:
1157                 fobj.write('print("hello")\n')
1158             self.invokeBlack([str(workspace)], exit_code=123)
1159             cache = black.read_cache(mode)
1160             self.assertNotIn(failing, cache)
1161             self.assertIn(clean, cache)
1162
1163     def test_write_cache_write_fail(self) -> None:
1164         mode = black.FileMode()
1165         with cache_dir(), patch.object(Path, "open") as mock:
1166             mock.side_effect = OSError
1167             black.write_cache({}, [], mode)
1168
1169     @event_loop(close=False)
1170     def test_check_diff_use_together(self) -> None:
1171         with cache_dir():
1172             # Files which will be reformatted.
1173             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1174             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1175             # Files which will not be reformatted.
1176             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1177             self.invokeBlack([str(src2), "--diff", "--check"])
1178             # Multi file command.
1179             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1180
1181     def test_no_files(self) -> None:
1182         with cache_dir():
1183             # Without an argument, black exits with error code 0.
1184             self.invokeBlack([])
1185
1186     def test_broken_symlink(self) -> None:
1187         with cache_dir() as workspace:
1188             symlink = workspace / "broken_link.py"
1189             try:
1190                 symlink.symlink_to("nonexistent.py")
1191             except OSError as e:
1192                 self.skipTest(f"Can't create symlinks: {e}")
1193             self.invokeBlack([str(workspace.resolve())])
1194
1195     def test_read_cache_line_lengths(self) -> None:
1196         mode = black.FileMode()
1197         short_mode = black.FileMode(line_length=1)
1198         with cache_dir() as workspace:
1199             path = (workspace / "file.py").resolve()
1200             path.touch()
1201             black.write_cache({}, [path], mode)
1202             one = black.read_cache(mode)
1203             self.assertIn(path, one)
1204             two = black.read_cache(short_mode)
1205             self.assertNotIn(path, two)
1206
1207     def test_single_file_force_pyi(self) -> None:
1208         reg_mode = black.FileMode()
1209         pyi_mode = black.FileMode(is_pyi=True)
1210         contents, expected = read_data("force_pyi")
1211         with cache_dir() as workspace:
1212             path = (workspace / "file.py").resolve()
1213             with open(path, "w") as fh:
1214                 fh.write(contents)
1215             self.invokeBlack([str(path), "--pyi"])
1216             with open(path, "r") as fh:
1217                 actual = fh.read()
1218             # verify cache with --pyi is separate
1219             pyi_cache = black.read_cache(pyi_mode)
1220             self.assertIn(path, pyi_cache)
1221             normal_cache = black.read_cache(reg_mode)
1222             self.assertNotIn(path, normal_cache)
1223         self.assertEqual(actual, expected)
1224
1225     @event_loop(close=False)
1226     def test_multi_file_force_pyi(self) -> None:
1227         reg_mode = black.FileMode()
1228         pyi_mode = black.FileMode(is_pyi=True)
1229         contents, expected = read_data("force_pyi")
1230         with cache_dir() as workspace:
1231             paths = [
1232                 (workspace / "file1.py").resolve(),
1233                 (workspace / "file2.py").resolve(),
1234             ]
1235             for path in paths:
1236                 with open(path, "w") as fh:
1237                     fh.write(contents)
1238             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1239             for path in paths:
1240                 with open(path, "r") as fh:
1241                     actual = fh.read()
1242                 self.assertEqual(actual, expected)
1243             # verify cache with --pyi is separate
1244             pyi_cache = black.read_cache(pyi_mode)
1245             normal_cache = black.read_cache(reg_mode)
1246             for path in paths:
1247                 self.assertIn(path, pyi_cache)
1248                 self.assertNotIn(path, normal_cache)
1249
1250     def test_pipe_force_pyi(self) -> None:
1251         source, expected = read_data("force_pyi")
1252         result = CliRunner().invoke(
1253             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1254         )
1255         self.assertEqual(result.exit_code, 0)
1256         actual = result.output
1257         self.assertFormatEqual(actual, expected)
1258
1259     def test_single_file_force_py36(self) -> None:
1260         reg_mode = black.FileMode()
1261         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1262         source, expected = read_data("force_py36")
1263         with cache_dir() as workspace:
1264             path = (workspace / "file.py").resolve()
1265             with open(path, "w") as fh:
1266                 fh.write(source)
1267             self.invokeBlack([str(path), *PY36_ARGS])
1268             with open(path, "r") as fh:
1269                 actual = fh.read()
1270             # verify cache with --target-version is separate
1271             py36_cache = black.read_cache(py36_mode)
1272             self.assertIn(path, py36_cache)
1273             normal_cache = black.read_cache(reg_mode)
1274             self.assertNotIn(path, normal_cache)
1275         self.assertEqual(actual, expected)
1276
1277     @event_loop(close=False)
1278     def test_multi_file_force_py36(self) -> None:
1279         reg_mode = black.FileMode()
1280         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1281         source, expected = read_data("force_py36")
1282         with cache_dir() as workspace:
1283             paths = [
1284                 (workspace / "file1.py").resolve(),
1285                 (workspace / "file2.py").resolve(),
1286             ]
1287             for path in paths:
1288                 with open(path, "w") as fh:
1289                     fh.write(source)
1290             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1291             for path in paths:
1292                 with open(path, "r") as fh:
1293                     actual = fh.read()
1294                 self.assertEqual(actual, expected)
1295             # verify cache with --target-version is separate
1296             pyi_cache = black.read_cache(py36_mode)
1297             normal_cache = black.read_cache(reg_mode)
1298             for path in paths:
1299                 self.assertIn(path, pyi_cache)
1300                 self.assertNotIn(path, normal_cache)
1301
1302     def test_pipe_force_py36(self) -> None:
1303         source, expected = read_data("force_py36")
1304         result = CliRunner().invoke(
1305             black.main,
1306             ["-", "-q", "--target-version=py36"],
1307             input=BytesIO(source.encode("utf8")),
1308         )
1309         self.assertEqual(result.exit_code, 0)
1310         actual = result.output
1311         self.assertFormatEqual(actual, expected)
1312
1313     def test_include_exclude(self) -> None:
1314         path = THIS_DIR / "data" / "include_exclude_tests"
1315         include = re.compile(r"\.pyi?$")
1316         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1317         report = black.Report()
1318         sources: List[Path] = []
1319         expected = [
1320             Path(path / "b/dont_exclude/a.py"),
1321             Path(path / "b/dont_exclude/a.pyi"),
1322         ]
1323         this_abs = THIS_DIR.resolve()
1324         sources.extend(
1325             black.gen_python_files_in_dir(path, this_abs, include, exclude, report)
1326         )
1327         self.assertEqual(sorted(expected), sorted(sources))
1328
1329     def test_empty_include(self) -> None:
1330         path = THIS_DIR / "data" / "include_exclude_tests"
1331         report = black.Report()
1332         empty = re.compile(r"")
1333         sources: List[Path] = []
1334         expected = [
1335             Path(path / "b/exclude/a.pie"),
1336             Path(path / "b/exclude/a.py"),
1337             Path(path / "b/exclude/a.pyi"),
1338             Path(path / "b/dont_exclude/a.pie"),
1339             Path(path / "b/dont_exclude/a.py"),
1340             Path(path / "b/dont_exclude/a.pyi"),
1341             Path(path / "b/.definitely_exclude/a.pie"),
1342             Path(path / "b/.definitely_exclude/a.py"),
1343             Path(path / "b/.definitely_exclude/a.pyi"),
1344         ]
1345         this_abs = THIS_DIR.resolve()
1346         sources.extend(
1347             black.gen_python_files_in_dir(
1348                 path, this_abs, empty, re.compile(black.DEFAULT_EXCLUDES), report
1349             )
1350         )
1351         self.assertEqual(sorted(expected), sorted(sources))
1352
1353     def test_empty_exclude(self) -> None:
1354         path = THIS_DIR / "data" / "include_exclude_tests"
1355         report = black.Report()
1356         empty = re.compile(r"")
1357         sources: List[Path] = []
1358         expected = [
1359             Path(path / "b/dont_exclude/a.py"),
1360             Path(path / "b/dont_exclude/a.pyi"),
1361             Path(path / "b/exclude/a.py"),
1362             Path(path / "b/exclude/a.pyi"),
1363             Path(path / "b/.definitely_exclude/a.py"),
1364             Path(path / "b/.definitely_exclude/a.pyi"),
1365         ]
1366         this_abs = THIS_DIR.resolve()
1367         sources.extend(
1368             black.gen_python_files_in_dir(
1369                 path, this_abs, re.compile(black.DEFAULT_INCLUDES), empty, report
1370             )
1371         )
1372         self.assertEqual(sorted(expected), sorted(sources))
1373
1374     def test_invalid_include_exclude(self) -> None:
1375         for option in ["--include", "--exclude"]:
1376             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1377
1378     def test_preserves_line_endings(self) -> None:
1379         with TemporaryDirectory() as workspace:
1380             test_file = Path(workspace) / "test.py"
1381             for nl in ["\n", "\r\n"]:
1382                 contents = nl.join(["def f(  ):", "    pass"])
1383                 test_file.write_bytes(contents.encode())
1384                 ff(test_file, write_back=black.WriteBack.YES)
1385                 updated_contents: bytes = test_file.read_bytes()
1386                 self.assertIn(nl.encode(), updated_contents)
1387                 if nl == "\n":
1388                     self.assertNotIn(b"\r\n", updated_contents)
1389
1390     def test_preserves_line_endings_via_stdin(self) -> None:
1391         for nl in ["\n", "\r\n"]:
1392             contents = nl.join(["def f(  ):", "    pass"])
1393             runner = BlackRunner()
1394             result = runner.invoke(
1395                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1396             )
1397             self.assertEqual(result.exit_code, 0)
1398             output = runner.stdout_bytes
1399             self.assertIn(nl.encode("utf8"), output)
1400             if nl == "\n":
1401                 self.assertNotIn(b"\r\n", output)
1402
1403     def test_assert_equivalent_different_asts(self) -> None:
1404         with self.assertRaises(AssertionError):
1405             black.assert_equivalent("{}", "None")
1406
1407     def test_symlink_out_of_root_directory(self) -> None:
1408         path = MagicMock()
1409         root = THIS_DIR
1410         child = MagicMock()
1411         include = re.compile(black.DEFAULT_INCLUDES)
1412         exclude = re.compile(black.DEFAULT_EXCLUDES)
1413         report = black.Report()
1414         # `child` should behave like a symlink which resolved path is clearly
1415         # outside of the `root` directory.
1416         path.iterdir.return_value = [child]
1417         child.resolve.return_value = Path("/a/b/c")
1418         child.is_symlink.return_value = True
1419         try:
1420             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1421         except ValueError as ve:
1422             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1423         path.iterdir.assert_called_once()
1424         child.resolve.assert_called_once()
1425         child.is_symlink.assert_called_once()
1426         # `child` should behave like a strange file which resolved path is clearly
1427         # outside of the `root` directory.
1428         child.is_symlink.return_value = False
1429         with self.assertRaises(ValueError):
1430             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1431         path.iterdir.assert_called()
1432         self.assertEqual(path.iterdir.call_count, 2)
1433         child.resolve.assert_called()
1434         self.assertEqual(child.resolve.call_count, 2)
1435         child.is_symlink.assert_called()
1436         self.assertEqual(child.is_symlink.call_count, 2)
1437
1438     def test_shhh_click(self) -> None:
1439         try:
1440             from click import _unicodefun  # type: ignore
1441         except ModuleNotFoundError:
1442             self.skipTest("Incompatible Click version")
1443         if not hasattr(_unicodefun, "_verify_python3_env"):
1444             self.skipTest("Incompatible Click version")
1445         # First, let's see if Click is crashing with a preferred ASCII charset.
1446         with patch("locale.getpreferredencoding") as gpe:
1447             gpe.return_value = "ASCII"
1448             with self.assertRaises(RuntimeError):
1449                 _unicodefun._verify_python3_env()
1450         # Now, let's silence Click...
1451         black.patch_click()
1452         # ...and confirm it's silent.
1453         with patch("locale.getpreferredencoding") as gpe:
1454             gpe.return_value = "ASCII"
1455             try:
1456                 _unicodefun._verify_python3_env()
1457             except RuntimeError as re:
1458                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1459
1460     def test_root_logger_not_used_directly(self) -> None:
1461         def fail(*args: Any, **kwargs: Any) -> None:
1462             self.fail("Record created with root logger")
1463
1464         with patch.multiple(
1465             logging.root,
1466             debug=fail,
1467             info=fail,
1468             warning=fail,
1469             error=fail,
1470             critical=fail,
1471             log=fail,
1472         ):
1473             ff(THIS_FILE)
1474
1475     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1476     @async_test
1477     async def test_blackd_request_needs_formatting(self) -> None:
1478         app = blackd.make_app()
1479         async with TestClient(TestServer(app)) as client:
1480             response = await client.post("/", data=b"print('hello world')")
1481             self.assertEqual(response.status, 200)
1482             self.assertEqual(response.charset, "utf8")
1483             self.assertEqual(await response.read(), b'print("hello world")\n')
1484
1485     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1486     @async_test
1487     async def test_blackd_request_no_change(self) -> None:
1488         app = blackd.make_app()
1489         async with TestClient(TestServer(app)) as client:
1490             response = await client.post("/", data=b'print("hello world")\n')
1491             self.assertEqual(response.status, 204)
1492             self.assertEqual(await response.read(), b"")
1493
1494     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1495     @async_test
1496     async def test_blackd_request_syntax_error(self) -> None:
1497         app = blackd.make_app()
1498         async with TestClient(TestServer(app)) as client:
1499             response = await client.post("/", data=b"what even ( is")
1500             self.assertEqual(response.status, 400)
1501             content = await response.text()
1502             self.assertTrue(
1503                 content.startswith("Cannot parse"),
1504                 msg=f"Expected error to start with 'Cannot parse', got {repr(content)}",
1505             )
1506
1507     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1508     @async_test
1509     async def test_blackd_unsupported_version(self) -> None:
1510         app = blackd.make_app()
1511         async with TestClient(TestServer(app)) as client:
1512             response = await client.post(
1513                 "/", data=b"what", headers={blackd.VERSION_HEADER: "2"}
1514             )
1515             self.assertEqual(response.status, 501)
1516
1517     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1518     @async_test
1519     async def test_blackd_supported_version(self) -> None:
1520         app = blackd.make_app()
1521         async with TestClient(TestServer(app)) as client:
1522             response = await client.post(
1523                 "/", data=b"what", headers={blackd.VERSION_HEADER: "1"}
1524             )
1525             self.assertEqual(response.status, 200)
1526
1527     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1528     @async_test
1529     async def test_blackd_invalid_python_variant(self) -> None:
1530         app = blackd.make_app()
1531         async with TestClient(TestServer(app)) as client:
1532
1533             async def check(header_value: str, expected_status: int = 400) -> None:
1534                 response = await client.post(
1535                     "/",
1536                     data=b"what",
1537                     headers={blackd.PYTHON_VARIANT_HEADER: header_value},
1538                 )
1539                 self.assertEqual(response.status, expected_status)
1540
1541             await check("lol")
1542             await check("ruby3.5")
1543             await check("pyi3.6")
1544             await check("py1.5")
1545             await check("2.8")
1546             await check("py2.8")
1547             await check("3.0")
1548             await check("pypy3.0")
1549             await check("jython3.4")
1550
1551     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1552     @async_test
1553     async def test_blackd_pyi(self) -> None:
1554         app = blackd.make_app()
1555         async with TestClient(TestServer(app)) as client:
1556             source, expected = read_data("stub.pyi")
1557             response = await client.post(
1558                 "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
1559             )
1560             self.assertEqual(response.status, 200)
1561             self.assertEqual(await response.text(), expected)
1562
1563     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1564     @async_test
1565     async def test_blackd_python_variant(self) -> None:
1566         app = blackd.make_app()
1567         code = (
1568             "def f(\n"
1569             "    and_has_a_bunch_of,\n"
1570             "    very_long_arguments_too,\n"
1571             "    and_lots_of_them_as_well_lol,\n"
1572             "    **and_very_long_keyword_arguments\n"
1573             "):\n"
1574             "    pass\n"
1575         )
1576         async with TestClient(TestServer(app)) as client:
1577
1578             async def check(header_value: str, expected_status: int) -> None:
1579                 response = await client.post(
1580                     "/", data=code, headers={blackd.PYTHON_VARIANT_HEADER: header_value}
1581                 )
1582                 self.assertEqual(response.status, expected_status)
1583
1584             await check("3.6", 200)
1585             await check("py3.6", 200)
1586             await check("3.6,3.7", 200)
1587             await check("3.6,py3.7", 200)
1588
1589             await check("2", 204)
1590             await check("2.7", 204)
1591             await check("py2.7", 204)
1592             await check("3.4", 204)
1593             await check("py3.4", 204)
1594
1595     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1596     @async_test
1597     async def test_blackd_line_length(self) -> None:
1598         app = blackd.make_app()
1599         async with TestClient(TestServer(app)) as client:
1600             response = await client.post(
1601                 "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "7"}
1602             )
1603             self.assertEqual(response.status, 200)
1604
1605     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1606     @async_test
1607     async def test_blackd_invalid_line_length(self) -> None:
1608         app = blackd.make_app()
1609         async with TestClient(TestServer(app)) as client:
1610             response = await client.post(
1611                 "/",
1612                 data=b'print("hello")\n',
1613                 headers={blackd.LINE_LENGTH_HEADER: "NaN"},
1614             )
1615             self.assertEqual(response.status, 400)
1616
1617     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1618     def test_blackd_main(self) -> None:
1619         with patch("blackd.web.run_app"):
1620             result = CliRunner().invoke(blackd.main, [])
1621             if result.exception is not None:
1622                 raise result.exception
1623             self.assertEqual(result.exit_code, 0)
1624
1625
if __name__ == "__main__":
    # Allow running this file directly. The tests are loaded from the module
    # imported by the name "test_black" rather than from `__main__`.
    unittest.main(module="test_black")