git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Target version option kebab-style (#770)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 import logging
4 from concurrent.futures import ThreadPoolExecutor
5 from contextlib import contextmanager, redirect_stderr
6 from functools import partial, wraps
7 from io import BytesIO, TextIOWrapper
8 import os
9 from pathlib import Path
10 import re
11 import sys
12 from tempfile import TemporaryDirectory
13 from typing import (
14     Any,
15     BinaryIO,
16     Callable,
17     Coroutine,
18     Generator,
19     List,
20     Tuple,
21     Iterator,
22     TypeVar,
23 )
24 import unittest
25 from unittest.mock import patch, MagicMock
26
27 from click import unstyle
28 from click.testing import CliRunner
29
30 import black
31 from black import Feature, TargetVersion
32
33 try:
34     import blackd
35     from aiohttp.test_utils import TestClient, TestServer
36 except ImportError:
37     has_blackd_deps = False
38 else:
39     has_blackd_deps = True
40
# Shortcut: format a file in place using the default FileMode with fast=True.
ff = partial(black.format_file_in_place, mode=black.FileMode(), fast=True)
# Shortcut: format a source string using the default FileMode.
fs = partial(black.format_str, mode=black.FileMode())
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# Text that read_data() deletes from every line it loads. It is assembled from
# two literals, presumably so this file's own source never contains the whole
# marker (test_self feeds this very file through read_data()).
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
# One "--target-version=<name>" CLI flag per entry in black.PY36_VERSIONS,
# using the lower-cased TargetVersion member name.
PY36_ARGS = [
    f"--target-version={version.name.lower()}" for version in black.PY36_VERSIONS
]
T = TypeVar("T")
R = TypeVar("R")
51
52
def dump_to_stderr(*output: str) -> str:
    """Stand-in for ``black.dump_to_file`` that returns text instead of a path.

    The arguments are joined with newlines and the result is framed by one
    leading and one trailing newline; nothing is written anywhere.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
55
56
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    ``name`` gets a ``.py`` suffix unless it already ends in .py, .pyi, .out or
    .diff. The file is looked up in the ``data`` directory next to this module,
    or directly next to this module when ``data`` is false. Every occurrence
    of ``EMPTY_LINE`` is deleted first. Lines before the first output-marker
    line form the input; lines after it (minus any further marker lines) form
    the expected output. If the output part is empty while the input is not,
    the input doubles as the expected output. Both halves are stripped and
    returned with a single trailing newline.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    directory = THIS_DIR / "data" if data else THIS_DIR
    with open(directory / name, "r", encoding="utf8") as test:
        lines = [line.replace(EMPTY_LINE, "") for line in test]
    marker_rows = [i for i, line in enumerate(lines) if line.rstrip() == "# output"]
    if marker_rows:
        split_at = marker_rows[0]
        _input = lines[:split_at]
        _output = [
            line for line in lines[split_at + 1 :] if line.rstrip() != "# output"
        ]
    else:
        _input, _output = lines, []
    if _input and not _output:
        # No expected-output section: the input is considered already formatted.
        _output = _input[:]
    return "".join(_input).strip() + "\n", "".join(_output).strip() + "\n"
78
79
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point ``black.CACHE_DIR`` at a throwaway directory for the block's duration.

    When ``exists`` is false, the yielded path is a not-yet-created ``new``
    subdirectory of the temporary workspace. The workspace is deleted on exit.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
88
89
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Install a brand-new asyncio event loop as current for the block's duration.

    The loop that was current beforehand is reinstated on exit. The temporary
    loop is closed afterwards only if ``close`` is true.
    """
    policy = asyncio.get_event_loop_policy()
    previous = policy.get_event_loop()
    fresh = policy.new_event_loop()
    policy.set_event_loop(fresh)
    try:
        yield
    finally:
        policy.set_event_loop(previous)
        if close:
            fresh.close()
103
104
def async_test(f: Callable[..., Coroutine[Any, None, R]]) -> Callable[..., None]:
    """Wrap coroutine function ``f`` into a plain synchronous callable.

    Each call runs ``f(*args, **kwargs)`` to completion on a fresh event loop
    (via ``event_loop``) that is closed afterwards.
    """

    @wraps(f)
    def wrapper(*args: Any, **kwargs: Any) -> None:
        with event_loop(close=True):
            asyncio.get_event_loop().run_until_complete(f(*args, **kwargs))

    return wrapper
112
113
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    This is a hack that can be removed once we depend on Click 7.x"""

    def __init__(self) -> None:
        # Byte sink that receives whatever is written to sys.stderr while a
        # command runs inside isolation().
        self.stderrbuf = BytesIO()
        # Not written to by isolation() below; kept for callers that use it.
        self.stdoutbuf = BytesIO()
        # Raw snapshots of both streams, recorded when isolation() exits.
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Run CliRunner's isolation with sys.stderr diverted to ``stderrbuf``.

        On exit, both streams are flushed and their bytes are saved in
        ``stdout_bytes`` / ``stderr_bytes``; then the previous sys.stderr is
        put back.
        """
        with super().isolation(*args, **kwargs) as output:
            hold_stderr = sys.stderr
            try:
                sys.stderr = TextIOWrapper(self.stderrbuf, encoding=self.charset)
                yield output
            finally:
                # TextIOWrapper keeps written text in its own buffer until it
                # is flushed; without this, unflushed stderr writes would be
                # absent from the BytesIO and therefore from stderr_bytes.
                sys.stdout.flush()
                sys.stderr.flush()
                self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                self.stderr_bytes = sys.stderr.buffer.getvalue()  # type: ignore
                sys.stderr = hold_stderr
137
138
139 class BlackTestCase(unittest.TestCase):
140     maxDiff = None
141
142     def assertFormatEqual(self, expected: str, actual: str) -> None:
143         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
144             bdv: black.DebugVisitor[Any]
145             black.out("Expected tree:", fg="green")
146             try:
147                 exp_node = black.lib2to3_parse(expected)
148                 bdv = black.DebugVisitor()
149                 list(bdv.visit(exp_node))
150             except Exception as ve:
151                 black.err(str(ve))
152             black.out("Actual tree:", fg="red")
153             try:
154                 exp_node = black.lib2to3_parse(actual)
155                 bdv = black.DebugVisitor()
156                 list(bdv.visit(exp_node))
157             except Exception as ve:
158                 black.err(str(ve))
159         self.assertEqual(expected, actual)
160
161     def invokeBlack(
162         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
163     ) -> None:
164         runner = BlackRunner()
165         if ignore_config:
166             args = ["--config", str(THIS_DIR / "empty.toml"), *args]
167         result = runner.invoke(black.main, args)
168         self.assertEqual(result.exit_code, exit_code, msg=runner.stderr_bytes.decode())
169
170     @patch("black.dump_to_file", dump_to_stderr)
171     def test_empty(self) -> None:
172         source = expected = ""
173         actual = fs(source)
174         self.assertFormatEqual(expected, actual)
175         black.assert_equivalent(source, actual)
176         black.assert_stable(source, actual, black.FileMode())
177
178     def test_empty_ff(self) -> None:
179         expected = ""
180         tmp_file = Path(black.dump_to_file())
181         try:
182             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
183             with open(tmp_file, encoding="utf8") as f:
184                 actual = f.read()
185         finally:
186             os.unlink(tmp_file)
187         self.assertFormatEqual(expected, actual)
188
189     @patch("black.dump_to_file", dump_to_stderr)
190     def test_self(self) -> None:
191         source, expected = read_data("test_black", data=False)
192         actual = fs(source)
193         self.assertFormatEqual(expected, actual)
194         black.assert_equivalent(source, actual)
195         black.assert_stable(source, actual, black.FileMode())
196         self.assertFalse(ff(THIS_FILE))
197
198     @patch("black.dump_to_file", dump_to_stderr)
199     def test_black(self) -> None:
200         source, expected = read_data("../black", data=False)
201         actual = fs(source)
202         self.assertFormatEqual(expected, actual)
203         black.assert_equivalent(source, actual)
204         black.assert_stable(source, actual, black.FileMode())
205         self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
206
207     def test_piping(self) -> None:
208         source, expected = read_data("../black", data=False)
209         result = BlackRunner().invoke(
210             black.main,
211             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
212             input=BytesIO(source.encode("utf8")),
213         )
214         self.assertEqual(result.exit_code, 0)
215         self.assertFormatEqual(expected, result.output)
216         black.assert_equivalent(source, result.output)
217         black.assert_stable(source, result.output, black.FileMode())
218
219     def test_piping_diff(self) -> None:
220         diff_header = re.compile(
221             rf"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d "
222             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
223         )
224         source, _ = read_data("expression.py")
225         expected, _ = read_data("expression.diff")
226         config = THIS_DIR / "data" / "empty_pyproject.toml"
227         args = [
228             "-",
229             "--fast",
230             f"--line-length={black.DEFAULT_LINE_LENGTH}",
231             "--diff",
232             f"--config={config}",
233         ]
234         result = BlackRunner().invoke(
235             black.main, args, input=BytesIO(source.encode("utf8"))
236         )
237         self.assertEqual(result.exit_code, 0)
238         actual = diff_header.sub("[Deterministic header]", result.output)
239         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
240         self.assertEqual(expected, actual)
241
242     @patch("black.dump_to_file", dump_to_stderr)
243     def test_setup(self) -> None:
244         source, expected = read_data("../setup", data=False)
245         actual = fs(source)
246         self.assertFormatEqual(expected, actual)
247         black.assert_equivalent(source, actual)
248         black.assert_stable(source, actual, black.FileMode())
249         self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
250
251     @patch("black.dump_to_file", dump_to_stderr)
252     def test_function(self) -> None:
253         source, expected = read_data("function")
254         actual = fs(source)
255         self.assertFormatEqual(expected, actual)
256         black.assert_equivalent(source, actual)
257         black.assert_stable(source, actual, black.FileMode())
258
259     @patch("black.dump_to_file", dump_to_stderr)
260     def test_function2(self) -> None:
261         source, expected = read_data("function2")
262         actual = fs(source)
263         self.assertFormatEqual(expected, actual)
264         black.assert_equivalent(source, actual)
265         black.assert_stable(source, actual, black.FileMode())
266
267     @patch("black.dump_to_file", dump_to_stderr)
268     def test_expression(self) -> None:
269         source, expected = read_data("expression")
270         actual = fs(source)
271         self.assertFormatEqual(expected, actual)
272         black.assert_equivalent(source, actual)
273         black.assert_stable(source, actual, black.FileMode())
274
275     def test_expression_ff(self) -> None:
276         source, expected = read_data("expression")
277         tmp_file = Path(black.dump_to_file(source))
278         try:
279             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
280             with open(tmp_file, encoding="utf8") as f:
281                 actual = f.read()
282         finally:
283             os.unlink(tmp_file)
284         self.assertFormatEqual(expected, actual)
285         with patch("black.dump_to_file", dump_to_stderr):
286             black.assert_equivalent(source, actual)
287             black.assert_stable(source, actual, black.FileMode())
288
289     def test_expression_diff(self) -> None:
290         source, _ = read_data("expression.py")
291         expected, _ = read_data("expression.diff")
292         tmp_file = Path(black.dump_to_file(source))
293         diff_header = re.compile(
294             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
295             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
296         )
297         try:
298             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
299             self.assertEqual(result.exit_code, 0)
300         finally:
301             os.unlink(tmp_file)
302         actual = result.output
303         actual = diff_header.sub("[Deterministic header]", actual)
304         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
305         if expected != actual:
306             dump = black.dump_to_file(actual)
307             msg = (
308                 f"Expected diff isn't equal to the actual. If you made changes "
309                 f"to expression.py and this is an anticipated difference, "
310                 f"overwrite tests/data/expression.diff with {dump}"
311             )
312             self.assertEqual(expected, actual, msg)
313
314     @patch("black.dump_to_file", dump_to_stderr)
315     def test_fstring(self) -> None:
316         source, expected = read_data("fstring")
317         actual = fs(source)
318         self.assertFormatEqual(expected, actual)
319         black.assert_equivalent(source, actual)
320         black.assert_stable(source, actual, black.FileMode())
321
322     @patch("black.dump_to_file", dump_to_stderr)
323     def test_string_quotes(self) -> None:
324         source, expected = read_data("string_quotes")
325         actual = fs(source)
326         self.assertFormatEqual(expected, actual)
327         black.assert_equivalent(source, actual)
328         black.assert_stable(source, actual, black.FileMode())
329         mode = black.FileMode(string_normalization=False)
330         not_normalized = fs(source, mode=mode)
331         self.assertFormatEqual(source, not_normalized)
332         black.assert_equivalent(source, not_normalized)
333         black.assert_stable(source, not_normalized, mode=mode)
334
335     @patch("black.dump_to_file", dump_to_stderr)
336     def test_slices(self) -> None:
337         source, expected = read_data("slices")
338         actual = fs(source)
339         self.assertFormatEqual(expected, actual)
340         black.assert_equivalent(source, actual)
341         black.assert_stable(source, actual, black.FileMode())
342
343     @patch("black.dump_to_file", dump_to_stderr)
344     def test_comments(self) -> None:
345         source, expected = read_data("comments")
346         actual = fs(source)
347         self.assertFormatEqual(expected, actual)
348         black.assert_equivalent(source, actual)
349         black.assert_stable(source, actual, black.FileMode())
350
351     @patch("black.dump_to_file", dump_to_stderr)
352     def test_comments2(self) -> None:
353         source, expected = read_data("comments2")
354         actual = fs(source)
355         self.assertFormatEqual(expected, actual)
356         black.assert_equivalent(source, actual)
357         black.assert_stable(source, actual, black.FileMode())
358
359     @patch("black.dump_to_file", dump_to_stderr)
360     def test_comments3(self) -> None:
361         source, expected = read_data("comments3")
362         actual = fs(source)
363         self.assertFormatEqual(expected, actual)
364         black.assert_equivalent(source, actual)
365         black.assert_stable(source, actual, black.FileMode())
366
367     @patch("black.dump_to_file", dump_to_stderr)
368     def test_comments4(self) -> None:
369         source, expected = read_data("comments4")
370         actual = fs(source)
371         self.assertFormatEqual(expected, actual)
372         black.assert_equivalent(source, actual)
373         black.assert_stable(source, actual, black.FileMode())
374
375     @patch("black.dump_to_file", dump_to_stderr)
376     def test_comments5(self) -> None:
377         source, expected = read_data("comments5")
378         actual = fs(source)
379         self.assertFormatEqual(expected, actual)
380         black.assert_equivalent(source, actual)
381         black.assert_stable(source, actual, black.FileMode())
382
383     @patch("black.dump_to_file", dump_to_stderr)
384     def test_comments6(self) -> None:
385         source, expected = read_data("comments6")
386         actual = fs(source)
387         self.assertFormatEqual(expected, actual)
388         black.assert_equivalent(source, actual)
389         black.assert_stable(source, actual, black.FileMode())
390
391     @patch("black.dump_to_file", dump_to_stderr)
392     def test_cantfit(self) -> None:
393         source, expected = read_data("cantfit")
394         actual = fs(source)
395         self.assertFormatEqual(expected, actual)
396         black.assert_equivalent(source, actual)
397         black.assert_stable(source, actual, black.FileMode())
398
399     @patch("black.dump_to_file", dump_to_stderr)
400     def test_import_spacing(self) -> None:
401         source, expected = read_data("import_spacing")
402         actual = fs(source)
403         self.assertFormatEqual(expected, actual)
404         black.assert_equivalent(source, actual)
405         black.assert_stable(source, actual, black.FileMode())
406
407     @patch("black.dump_to_file", dump_to_stderr)
408     def test_composition(self) -> None:
409         source, expected = read_data("composition")
410         actual = fs(source)
411         self.assertFormatEqual(expected, actual)
412         black.assert_equivalent(source, actual)
413         black.assert_stable(source, actual, black.FileMode())
414
415     @patch("black.dump_to_file", dump_to_stderr)
416     def test_empty_lines(self) -> None:
417         source, expected = read_data("empty_lines")
418         actual = fs(source)
419         self.assertFormatEqual(expected, actual)
420         black.assert_equivalent(source, actual)
421         black.assert_stable(source, actual, black.FileMode())
422
423     @patch("black.dump_to_file", dump_to_stderr)
424     def test_string_prefixes(self) -> None:
425         source, expected = read_data("string_prefixes")
426         actual = fs(source)
427         self.assertFormatEqual(expected, actual)
428         black.assert_equivalent(source, actual)
429         black.assert_stable(source, actual, black.FileMode())
430
431     @patch("black.dump_to_file", dump_to_stderr)
432     def test_numeric_literals(self) -> None:
433         source, expected = read_data("numeric_literals")
434         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
435         actual = fs(source, mode=mode)
436         self.assertFormatEqual(expected, actual)
437         black.assert_equivalent(source, actual)
438         black.assert_stable(source, actual, mode)
439
440     @patch("black.dump_to_file", dump_to_stderr)
441     def test_numeric_literals_ignoring_underscores(self) -> None:
442         source, expected = read_data("numeric_literals_skip_underscores")
443         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
444         actual = fs(source, mode=mode)
445         self.assertFormatEqual(expected, actual)
446         black.assert_equivalent(source, actual)
447         black.assert_stable(source, actual, mode)
448
449     @patch("black.dump_to_file", dump_to_stderr)
450     def test_numeric_literals_py2(self) -> None:
451         source, expected = read_data("numeric_literals_py2")
452         actual = fs(source)
453         self.assertFormatEqual(expected, actual)
454         black.assert_stable(source, actual, black.FileMode())
455
456     @patch("black.dump_to_file", dump_to_stderr)
457     def test_python2(self) -> None:
458         source, expected = read_data("python2")
459         actual = fs(source)
460         self.assertFormatEqual(expected, actual)
461         # black.assert_equivalent(source, actual)
462         black.assert_stable(source, actual, black.FileMode())
463
464     @patch("black.dump_to_file", dump_to_stderr)
465     def test_python2_print_function(self) -> None:
466         source, expected = read_data("python2_print_function")
467         mode = black.FileMode(target_versions={TargetVersion.PY27})
468         actual = fs(source, mode=mode)
469         self.assertFormatEqual(expected, actual)
470         black.assert_stable(source, actual, mode)
471
472     @patch("black.dump_to_file", dump_to_stderr)
473     def test_python2_unicode_literals(self) -> None:
474         source, expected = read_data("python2_unicode_literals")
475         actual = fs(source)
476         self.assertFormatEqual(expected, actual)
477         black.assert_stable(source, actual, black.FileMode())
478
479     @patch("black.dump_to_file", dump_to_stderr)
480     def test_stub(self) -> None:
481         mode = black.FileMode(is_pyi=True)
482         source, expected = read_data("stub.pyi")
483         actual = fs(source, mode=mode)
484         self.assertFormatEqual(expected, actual)
485         black.assert_stable(source, actual, mode)
486
487     @patch("black.dump_to_file", dump_to_stderr)
488     def test_python37(self) -> None:
489         source, expected = read_data("python37")
490         actual = fs(source)
491         self.assertFormatEqual(expected, actual)
492         major, minor = sys.version_info[:2]
493         if major > 3 or (major == 3 and minor >= 7):
494             black.assert_equivalent(source, actual)
495         black.assert_stable(source, actual, black.FileMode())
496
497     @patch("black.dump_to_file", dump_to_stderr)
498     def test_fmtonoff(self) -> None:
499         source, expected = read_data("fmtonoff")
500         actual = fs(source)
501         self.assertFormatEqual(expected, actual)
502         black.assert_equivalent(source, actual)
503         black.assert_stable(source, actual, black.FileMode())
504
505     @patch("black.dump_to_file", dump_to_stderr)
506     def test_fmtonoff2(self) -> None:
507         source, expected = read_data("fmtonoff2")
508         actual = fs(source)
509         self.assertFormatEqual(expected, actual)
510         black.assert_equivalent(source, actual)
511         black.assert_stable(source, actual, black.FileMode())
512
513     @patch("black.dump_to_file", dump_to_stderr)
514     def test_remove_empty_parentheses_after_class(self) -> None:
515         source, expected = read_data("class_blank_parentheses")
516         actual = fs(source)
517         self.assertFormatEqual(expected, actual)
518         black.assert_equivalent(source, actual)
519         black.assert_stable(source, actual, black.FileMode())
520
521     @patch("black.dump_to_file", dump_to_stderr)
522     def test_new_line_between_class_and_code(self) -> None:
523         source, expected = read_data("class_methods_new_line")
524         actual = fs(source)
525         self.assertFormatEqual(expected, actual)
526         black.assert_equivalent(source, actual)
527         black.assert_stable(source, actual, black.FileMode())
528
529     @patch("black.dump_to_file", dump_to_stderr)
530     def test_bracket_match(self) -> None:
531         source, expected = read_data("bracketmatch")
532         actual = fs(source)
533         self.assertFormatEqual(expected, actual)
534         black.assert_equivalent(source, actual)
535         black.assert_stable(source, actual, black.FileMode())
536
537     def test_tab_comment_indentation(self) -> None:
538         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
539         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
540         self.assertFormatEqual(contents_spc, fs(contents_spc))
541         self.assertFormatEqual(contents_spc, fs(contents_tab))
542
543         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
544         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
545         self.assertFormatEqual(contents_spc, fs(contents_spc))
546         self.assertFormatEqual(contents_spc, fs(contents_tab))
547
548         # mixed tabs and spaces (valid Python 2 code)
549         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
550         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
551         self.assertFormatEqual(contents_spc, fs(contents_spc))
552         self.assertFormatEqual(contents_spc, fs(contents_tab))
553
554         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
555         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
556         self.assertFormatEqual(contents_spc, fs(contents_spc))
557         self.assertFormatEqual(contents_spc, fs(contents_tab))
558
    def test_report_verbose(self) -> None:
        """Drive a verbose ``black.Report`` through every outcome in sequence.

        ``black.out`` and ``black.err`` are patched to record messages, so after
        each report call the test checks which stream got what, the summary
        text of ``str(report)`` and ``report.return_code``. The assertions are
        cumulative: each step depends on all the calls made before it.
        """
        report = black.Report(verbose=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        # Stand-ins for black.out / black.err that collect messages instead of
        # printing them.
        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A CACHED result is counted as "left unchanged" in the summary.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, having reformatted files turns return code 0 into 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures go to the err stream and force return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Reporting f3 again adds another "reformatted" count: tallies are
            # per call, not per unique path.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are announced in verbose mode but do not change the
            # summary counts.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # In check mode the summary switches to conditional wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )
653
    def test_report_quiet(self) -> None:
        """Drive a quiet ``black.Report`` through every outcome in sequence.

        Same scenario as the verbose variant, but with ``quiet=True`` nothing is
        sent to ``black.out``; only failures reach ``black.err``. Summary text
        and return codes are checked after each call, cumulatively.
        """
        report = black.Report(quiet=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        # Stand-ins for black.out / black.err that collect messages instead of
        # printing them.
        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, having reformatted files turns return code 0 into 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Errors are still reported on the err stream even when quiet.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths produce no output and do not change the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # In check mode the summary switches to conditional wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )
740
741     def test_report_normal(self) -> None:
742         report = black.Report()
743         out_lines = []
744         err_lines = []
745
746         def out(msg: str, **kwargs: Any) -> None:
747             out_lines.append(msg)
748
749         def err(msg: str, **kwargs: Any) -> None:
750             err_lines.append(msg)
751
752         with patch("black.out", out), patch("black.err", err):
753             report.done(Path("f1"), black.Changed.NO)
754             self.assertEqual(len(out_lines), 0)
755             self.assertEqual(len(err_lines), 0)
756             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
757             self.assertEqual(report.return_code, 0)
758             report.done(Path("f2"), black.Changed.YES)
759             self.assertEqual(len(out_lines), 1)
760             self.assertEqual(len(err_lines), 0)
761             self.assertEqual(out_lines[-1], "reformatted f2")
762             self.assertEqual(
763                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
764             )
765             report.done(Path("f3"), black.Changed.CACHED)
766             self.assertEqual(len(out_lines), 1)
767             self.assertEqual(len(err_lines), 0)
768             self.assertEqual(out_lines[-1], "reformatted f2")
769             self.assertEqual(
770                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
771             )
772             self.assertEqual(report.return_code, 0)
773             report.check = True
774             self.assertEqual(report.return_code, 1)
775             report.check = False
776             report.failed(Path("e1"), "boom")
777             self.assertEqual(len(out_lines), 1)
778             self.assertEqual(len(err_lines), 1)
779             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
780             self.assertEqual(
781                 unstyle(str(report)),
782                 "1 file reformatted, 2 files left unchanged, "
783                 "1 file failed to reformat.",
784             )
785             self.assertEqual(report.return_code, 123)
786             report.done(Path("f3"), black.Changed.YES)
787             self.assertEqual(len(out_lines), 2)
788             self.assertEqual(len(err_lines), 1)
789             self.assertEqual(out_lines[-1], "reformatted f3")
790             self.assertEqual(
791                 unstyle(str(report)),
792                 "2 files reformatted, 2 files left unchanged, "
793                 "1 file failed to reformat.",
794             )
795             self.assertEqual(report.return_code, 123)
796             report.failed(Path("e2"), "boom")
797             self.assertEqual(len(out_lines), 2)
798             self.assertEqual(len(err_lines), 2)
799             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
800             self.assertEqual(
801                 unstyle(str(report)),
802                 "2 files reformatted, 2 files left unchanged, "
803                 "2 files failed to reformat.",
804             )
805             self.assertEqual(report.return_code, 123)
806             report.path_ignored(Path("wat"), "no match")
807             self.assertEqual(len(out_lines), 2)
808             self.assertEqual(len(err_lines), 2)
809             self.assertEqual(
810                 unstyle(str(report)),
811                 "2 files reformatted, 2 files left unchanged, "
812                 "2 files failed to reformat.",
813             )
814             self.assertEqual(report.return_code, 123)
815             report.done(Path("f4"), black.Changed.NO)
816             self.assertEqual(len(out_lines), 2)
817             self.assertEqual(len(err_lines), 2)
818             self.assertEqual(
819                 unstyle(str(report)),
820                 "2 files reformatted, 3 files left unchanged, "
821                 "2 files failed to reformat.",
822             )
823             self.assertEqual(report.return_code, 123)
824             report.check = True
825             self.assertEqual(
826                 unstyle(str(report)),
827                 "2 files would be reformatted, 3 files would be left unchanged, "
828                 "2 files would fail to reformat.",
829             )
830
831     def test_lib2to3_parse(self) -> None:
832         with self.assertRaises(black.InvalidInput):
833             black.lib2to3_parse("invalid syntax")
834
835         straddling = "x + y"
836         black.lib2to3_parse(straddling)
837         black.lib2to3_parse(straddling, {TargetVersion.PY27})
838         black.lib2to3_parse(straddling, {TargetVersion.PY36})
839         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
840
841         py2_only = "print x"
842         black.lib2to3_parse(py2_only)
843         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
844         with self.assertRaises(black.InvalidInput):
845             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
846         with self.assertRaises(black.InvalidInput):
847             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
848
849         py3_only = "exec(x, end=y)"
850         black.lib2to3_parse(py3_only)
851         with self.assertRaises(black.InvalidInput):
852             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
853         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
854         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
855
856     def test_get_features_used(self) -> None:
857         node = black.lib2to3_parse("def f(*, arg): ...\n")
858         self.assertEqual(black.get_features_used(node), set())
859         node = black.lib2to3_parse("def f(*, arg,): ...\n")
860         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA})
861         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
862         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
863         node = black.lib2to3_parse("123_456\n")
864         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
865         node = black.lib2to3_parse("123456\n")
866         self.assertEqual(black.get_features_used(node), set())
867         source, expected = read_data("function")
868         node = black.lib2to3_parse(source)
869         self.assertEqual(
870             black.get_features_used(node), {Feature.TRAILING_COMMA, Feature.F_STRINGS}
871         )
872         node = black.lib2to3_parse(expected)
873         self.assertEqual(
874             black.get_features_used(node), {Feature.TRAILING_COMMA, Feature.F_STRINGS}
875         )
876         source, expected = read_data("expression")
877         node = black.lib2to3_parse(source)
878         self.assertEqual(black.get_features_used(node), set())
879         node = black.lib2to3_parse(expected)
880         self.assertEqual(black.get_features_used(node), set())
881
882     def test_get_future_imports(self) -> None:
883         node = black.lib2to3_parse("\n")
884         self.assertEqual(set(), black.get_future_imports(node))
885         node = black.lib2to3_parse("from __future__ import black\n")
886         self.assertEqual({"black"}, black.get_future_imports(node))
887         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
888         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
889         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
890         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
891         node = black.lib2to3_parse(
892             "from __future__ import multiple\nfrom __future__ import imports\n"
893         )
894         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
895         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
896         self.assertEqual({"black"}, black.get_future_imports(node))
897         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
898         self.assertEqual({"black"}, black.get_future_imports(node))
899         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
900         self.assertEqual(set(), black.get_future_imports(node))
901         node = black.lib2to3_parse("from some.module import black\n")
902         self.assertEqual(set(), black.get_future_imports(node))
903         node = black.lib2to3_parse(
904             "from __future__ import unicode_literals as _unicode_literals"
905         )
906         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
907         node = black.lib2to3_parse(
908             "from __future__ import unicode_literals as _lol, print"
909         )
910         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
911
912     def test_debug_visitor(self) -> None:
913         source, _ = read_data("debug_visitor.py")
914         expected, _ = read_data("debug_visitor.out")
915         out_lines = []
916         err_lines = []
917
918         def out(msg: str, **kwargs: Any) -> None:
919             out_lines.append(msg)
920
921         def err(msg: str, **kwargs: Any) -> None:
922             err_lines.append(msg)
923
924         with patch("black.out", out), patch("black.err", err):
925             black.DebugVisitor.show(source)
926         actual = "\n".join(out_lines) + "\n"
927         log_name = ""
928         if expected != actual:
929             log_name = black.dump_to_file(*out_lines)
930         self.assertEqual(
931             expected,
932             actual,
933             f"AST print out is different. Actual version dumped to {log_name}",
934         )
935
936     def test_format_file_contents(self) -> None:
937         empty = ""
938         mode = black.FileMode()
939         with self.assertRaises(black.NothingChanged):
940             black.format_file_contents(empty, mode=mode, fast=False)
941         just_nl = "\n"
942         with self.assertRaises(black.NothingChanged):
943             black.format_file_contents(just_nl, mode=mode, fast=False)
944         same = "l = [1, 2, 3]\n"
945         with self.assertRaises(black.NothingChanged):
946             black.format_file_contents(same, mode=mode, fast=False)
947         different = "l = [1,2,3]"
948         expected = same
949         actual = black.format_file_contents(different, mode=mode, fast=False)
950         self.assertEqual(expected, actual)
951         invalid = "return if you can"
952         with self.assertRaises(black.InvalidInput) as e:
953             black.format_file_contents(invalid, mode=mode, fast=False)
954         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
955
956     def test_endmarker(self) -> None:
957         n = black.lib2to3_parse("\n")
958         self.assertEqual(n.type, black.syms.file_input)
959         self.assertEqual(len(n.children), 1)
960         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
961
962     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
963     def test_assertFormatEqual(self) -> None:
964         out_lines = []
965         err_lines = []
966
967         def out(msg: str, **kwargs: Any) -> None:
968             out_lines.append(msg)
969
970         def err(msg: str, **kwargs: Any) -> None:
971             err_lines.append(msg)
972
973         with patch("black.out", out), patch("black.err", err):
974             with self.assertRaises(AssertionError):
975                 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
976
977         out_str = "".join(out_lines)
978         self.assertTrue("Expected tree:" in out_str)
979         self.assertTrue("Actual tree:" in out_str)
980         self.assertEqual("".join(err_lines), "")
981
982     def test_cache_broken_file(self) -> None:
983         mode = black.FileMode()
984         with cache_dir() as workspace:
985             cache_file = black.get_cache_file(mode)
986             with cache_file.open("w") as fobj:
987                 fobj.write("this is not a pickle")
988             self.assertEqual(black.read_cache(mode), {})
989             src = (workspace / "test.py").resolve()
990             with src.open("w") as fobj:
991                 fobj.write("print('hello')")
992             self.invokeBlack([str(src)])
993             cache = black.read_cache(mode)
994             self.assertIn(src, cache)
995
996     def test_cache_single_file_already_cached(self) -> None:
997         mode = black.FileMode()
998         with cache_dir() as workspace:
999             src = (workspace / "test.py").resolve()
1000             with src.open("w") as fobj:
1001                 fobj.write("print('hello')")
1002             black.write_cache({}, [src], mode)
1003             self.invokeBlack([str(src)])
1004             with src.open("r") as fobj:
1005                 self.assertEqual(fobj.read(), "print('hello')")
1006
1007     @event_loop(close=False)
1008     def test_cache_multiple_files(self) -> None:
1009         mode = black.FileMode()
1010         with cache_dir() as workspace, patch(
1011             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1012         ):
1013             one = (workspace / "one.py").resolve()
1014             with one.open("w") as fobj:
1015                 fobj.write("print('hello')")
1016             two = (workspace / "two.py").resolve()
1017             with two.open("w") as fobj:
1018                 fobj.write("print('hello')")
1019             black.write_cache({}, [one], mode)
1020             self.invokeBlack([str(workspace)])
1021             with one.open("r") as fobj:
1022                 self.assertEqual(fobj.read(), "print('hello')")
1023             with two.open("r") as fobj:
1024                 self.assertEqual(fobj.read(), 'print("hello")\n')
1025             cache = black.read_cache(mode)
1026             self.assertIn(one, cache)
1027             self.assertIn(two, cache)
1028
1029     def test_no_cache_when_writeback_diff(self) -> None:
1030         mode = black.FileMode()
1031         with cache_dir() as workspace:
1032             src = (workspace / "test.py").resolve()
1033             with src.open("w") as fobj:
1034                 fobj.write("print('hello')")
1035             self.invokeBlack([str(src), "--diff"])
1036             cache_file = black.get_cache_file(mode)
1037             self.assertFalse(cache_file.exists())
1038
1039     def test_no_cache_when_stdin(self) -> None:
1040         mode = black.FileMode()
1041         with cache_dir():
1042             result = CliRunner().invoke(
1043                 black.main, ["-"], input=BytesIO(b"print('hello')")
1044             )
1045             self.assertEqual(result.exit_code, 0)
1046             cache_file = black.get_cache_file(mode)
1047             self.assertFalse(cache_file.exists())
1048
1049     def test_read_cache_no_cachefile(self) -> None:
1050         mode = black.FileMode()
1051         with cache_dir():
1052             self.assertEqual(black.read_cache(mode), {})
1053
1054     def test_write_cache_read_cache(self) -> None:
1055         mode = black.FileMode()
1056         with cache_dir() as workspace:
1057             src = (workspace / "test.py").resolve()
1058             src.touch()
1059             black.write_cache({}, [src], mode)
1060             cache = black.read_cache(mode)
1061             self.assertIn(src, cache)
1062             self.assertEqual(cache[src], black.get_cache_info(src))
1063
1064     def test_filter_cached(self) -> None:
1065         with TemporaryDirectory() as workspace:
1066             path = Path(workspace)
1067             uncached = (path / "uncached").resolve()
1068             cached = (path / "cached").resolve()
1069             cached_but_changed = (path / "changed").resolve()
1070             uncached.touch()
1071             cached.touch()
1072             cached_but_changed.touch()
1073             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
1074             todo, done = black.filter_cached(
1075                 cache, {uncached, cached, cached_but_changed}
1076             )
1077             self.assertEqual(todo, {uncached, cached_but_changed})
1078             self.assertEqual(done, {cached})
1079
1080     def test_write_cache_creates_directory_if_needed(self) -> None:
1081         mode = black.FileMode()
1082         with cache_dir(exists=False) as workspace:
1083             self.assertFalse(workspace.exists())
1084             black.write_cache({}, [], mode)
1085             self.assertTrue(workspace.exists())
1086
1087     @event_loop(close=False)
1088     def test_failed_formatting_does_not_get_cached(self) -> None:
1089         mode = black.FileMode()
1090         with cache_dir() as workspace, patch(
1091             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1092         ):
1093             failing = (workspace / "failing.py").resolve()
1094             with failing.open("w") as fobj:
1095                 fobj.write("not actually python")
1096             clean = (workspace / "clean.py").resolve()
1097             with clean.open("w") as fobj:
1098                 fobj.write('print("hello")\n')
1099             self.invokeBlack([str(workspace)], exit_code=123)
1100             cache = black.read_cache(mode)
1101             self.assertNotIn(failing, cache)
1102             self.assertIn(clean, cache)
1103
1104     def test_write_cache_write_fail(self) -> None:
1105         mode = black.FileMode()
1106         with cache_dir(), patch.object(Path, "open") as mock:
1107             mock.side_effect = OSError
1108             black.write_cache({}, [], mode)
1109
1110     @event_loop(close=False)
1111     def test_check_diff_use_together(self) -> None:
1112         with cache_dir():
1113             # Files which will be reformatted.
1114             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1115             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1116             # Files which will not be reformatted.
1117             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1118             self.invokeBlack([str(src2), "--diff", "--check"])
1119             # Multi file command.
1120             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1121
1122     def test_no_files(self) -> None:
1123         with cache_dir():
1124             # Without an argument, black exits with error code 0.
1125             self.invokeBlack([])
1126
1127     def test_broken_symlink(self) -> None:
1128         with cache_dir() as workspace:
1129             symlink = workspace / "broken_link.py"
1130             try:
1131                 symlink.symlink_to("nonexistent.py")
1132             except OSError as e:
1133                 self.skipTest(f"Can't create symlinks: {e}")
1134             self.invokeBlack([str(workspace.resolve())])
1135
1136     def test_read_cache_line_lengths(self) -> None:
1137         mode = black.FileMode()
1138         short_mode = black.FileMode(line_length=1)
1139         with cache_dir() as workspace:
1140             path = (workspace / "file.py").resolve()
1141             path.touch()
1142             black.write_cache({}, [path], mode)
1143             one = black.read_cache(mode)
1144             self.assertIn(path, one)
1145             two = black.read_cache(short_mode)
1146             self.assertNotIn(path, two)
1147
1148     def test_single_file_force_pyi(self) -> None:
1149         reg_mode = black.FileMode()
1150         pyi_mode = black.FileMode(is_pyi=True)
1151         contents, expected = read_data("force_pyi")
1152         with cache_dir() as workspace:
1153             path = (workspace / "file.py").resolve()
1154             with open(path, "w") as fh:
1155                 fh.write(contents)
1156             self.invokeBlack([str(path), "--pyi"])
1157             with open(path, "r") as fh:
1158                 actual = fh.read()
1159             # verify cache with --pyi is separate
1160             pyi_cache = black.read_cache(pyi_mode)
1161             self.assertIn(path, pyi_cache)
1162             normal_cache = black.read_cache(reg_mode)
1163             self.assertNotIn(path, normal_cache)
1164         self.assertEqual(actual, expected)
1165
1166     @event_loop(close=False)
1167     def test_multi_file_force_pyi(self) -> None:
1168         reg_mode = black.FileMode()
1169         pyi_mode = black.FileMode(is_pyi=True)
1170         contents, expected = read_data("force_pyi")
1171         with cache_dir() as workspace:
1172             paths = [
1173                 (workspace / "file1.py").resolve(),
1174                 (workspace / "file2.py").resolve(),
1175             ]
1176             for path in paths:
1177                 with open(path, "w") as fh:
1178                     fh.write(contents)
1179             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1180             for path in paths:
1181                 with open(path, "r") as fh:
1182                     actual = fh.read()
1183                 self.assertEqual(actual, expected)
1184             # verify cache with --pyi is separate
1185             pyi_cache = black.read_cache(pyi_mode)
1186             normal_cache = black.read_cache(reg_mode)
1187             for path in paths:
1188                 self.assertIn(path, pyi_cache)
1189                 self.assertNotIn(path, normal_cache)
1190
1191     def test_pipe_force_pyi(self) -> None:
1192         source, expected = read_data("force_pyi")
1193         result = CliRunner().invoke(
1194             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1195         )
1196         self.assertEqual(result.exit_code, 0)
1197         actual = result.output
1198         self.assertFormatEqual(actual, expected)
1199
1200     def test_single_file_force_py36(self) -> None:
1201         reg_mode = black.FileMode()
1202         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1203         source, expected = read_data("force_py36")
1204         with cache_dir() as workspace:
1205             path = (workspace / "file.py").resolve()
1206             with open(path, "w") as fh:
1207                 fh.write(source)
1208             self.invokeBlack([str(path), *PY36_ARGS])
1209             with open(path, "r") as fh:
1210                 actual = fh.read()
1211             # verify cache with --target-version is separate
1212             py36_cache = black.read_cache(py36_mode)
1213             self.assertIn(path, py36_cache)
1214             normal_cache = black.read_cache(reg_mode)
1215             self.assertNotIn(path, normal_cache)
1216         self.assertEqual(actual, expected)
1217
1218     @event_loop(close=False)
1219     def test_multi_file_force_py36(self) -> None:
1220         reg_mode = black.FileMode()
1221         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1222         source, expected = read_data("force_py36")
1223         with cache_dir() as workspace:
1224             paths = [
1225                 (workspace / "file1.py").resolve(),
1226                 (workspace / "file2.py").resolve(),
1227             ]
1228             for path in paths:
1229                 with open(path, "w") as fh:
1230                     fh.write(source)
1231             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1232             for path in paths:
1233                 with open(path, "r") as fh:
1234                     actual = fh.read()
1235                 self.assertEqual(actual, expected)
1236             # verify cache with --target-version is separate
1237             pyi_cache = black.read_cache(py36_mode)
1238             normal_cache = black.read_cache(reg_mode)
1239             for path in paths:
1240                 self.assertIn(path, pyi_cache)
1241                 self.assertNotIn(path, normal_cache)
1242
1243     def test_pipe_force_py36(self) -> None:
1244         source, expected = read_data("force_py36")
1245         result = CliRunner().invoke(
1246             black.main,
1247             ["-", "-q", "--target-version=py36"],
1248             input=BytesIO(source.encode("utf8")),
1249         )
1250         self.assertEqual(result.exit_code, 0)
1251         actual = result.output
1252         self.assertFormatEqual(actual, expected)
1253
1254     def test_include_exclude(self) -> None:
1255         path = THIS_DIR / "data" / "include_exclude_tests"
1256         include = re.compile(r"\.pyi?$")
1257         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1258         report = black.Report()
1259         sources: List[Path] = []
1260         expected = [
1261             Path(path / "b/dont_exclude/a.py"),
1262             Path(path / "b/dont_exclude/a.pyi"),
1263         ]
1264         this_abs = THIS_DIR.resolve()
1265         sources.extend(
1266             black.gen_python_files_in_dir(path, this_abs, include, exclude, report)
1267         )
1268         self.assertEqual(sorted(expected), sorted(sources))
1269
1270     def test_empty_include(self) -> None:
1271         path = THIS_DIR / "data" / "include_exclude_tests"
1272         report = black.Report()
1273         empty = re.compile(r"")
1274         sources: List[Path] = []
1275         expected = [
1276             Path(path / "b/exclude/a.pie"),
1277             Path(path / "b/exclude/a.py"),
1278             Path(path / "b/exclude/a.pyi"),
1279             Path(path / "b/dont_exclude/a.pie"),
1280             Path(path / "b/dont_exclude/a.py"),
1281             Path(path / "b/dont_exclude/a.pyi"),
1282             Path(path / "b/.definitely_exclude/a.pie"),
1283             Path(path / "b/.definitely_exclude/a.py"),
1284             Path(path / "b/.definitely_exclude/a.pyi"),
1285         ]
1286         this_abs = THIS_DIR.resolve()
1287         sources.extend(
1288             black.gen_python_files_in_dir(
1289                 path, this_abs, empty, re.compile(black.DEFAULT_EXCLUDES), report
1290             )
1291         )
1292         self.assertEqual(sorted(expected), sorted(sources))
1293
1294     def test_empty_exclude(self) -> None:
1295         path = THIS_DIR / "data" / "include_exclude_tests"
1296         report = black.Report()
1297         empty = re.compile(r"")
1298         sources: List[Path] = []
1299         expected = [
1300             Path(path / "b/dont_exclude/a.py"),
1301             Path(path / "b/dont_exclude/a.pyi"),
1302             Path(path / "b/exclude/a.py"),
1303             Path(path / "b/exclude/a.pyi"),
1304             Path(path / "b/.definitely_exclude/a.py"),
1305             Path(path / "b/.definitely_exclude/a.pyi"),
1306         ]
1307         this_abs = THIS_DIR.resolve()
1308         sources.extend(
1309             black.gen_python_files_in_dir(
1310                 path, this_abs, re.compile(black.DEFAULT_INCLUDES), empty, report
1311             )
1312         )
1313         self.assertEqual(sorted(expected), sorted(sources))
1314
1315     def test_invalid_include_exclude(self) -> None:
1316         for option in ["--include", "--exclude"]:
1317             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1318
1319     def test_preserves_line_endings(self) -> None:
1320         with TemporaryDirectory() as workspace:
1321             test_file = Path(workspace) / "test.py"
1322             for nl in ["\n", "\r\n"]:
1323                 contents = nl.join(["def f(  ):", "    pass"])
1324                 test_file.write_bytes(contents.encode())
1325                 ff(test_file, write_back=black.WriteBack.YES)
1326                 updated_contents: bytes = test_file.read_bytes()
1327                 self.assertIn(nl.encode(), updated_contents)
1328                 if nl == "\n":
1329                     self.assertNotIn(b"\r\n", updated_contents)
1330
1331     def test_preserves_line_endings_via_stdin(self) -> None:
1332         for nl in ["\n", "\r\n"]:
1333             contents = nl.join(["def f(  ):", "    pass"])
1334             runner = BlackRunner()
1335             result = runner.invoke(
1336                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1337             )
1338             self.assertEqual(result.exit_code, 0)
1339             output = runner.stdout_bytes
1340             self.assertIn(nl.encode("utf8"), output)
1341             if nl == "\n":
1342                 self.assertNotIn(b"\r\n", output)
1343
1344     def test_assert_equivalent_different_asts(self) -> None:
1345         with self.assertRaises(AssertionError):
1346             black.assert_equivalent("{}", "None")
1347
1348     def test_symlink_out_of_root_directory(self) -> None:
1349         path = MagicMock()
1350         root = THIS_DIR
1351         child = MagicMock()
1352         include = re.compile(black.DEFAULT_INCLUDES)
1353         exclude = re.compile(black.DEFAULT_EXCLUDES)
1354         report = black.Report()
1355         # `child` should behave like a symlink which resolved path is clearly
1356         # outside of the `root` directory.
1357         path.iterdir.return_value = [child]
1358         child.resolve.return_value = Path("/a/b/c")
1359         child.is_symlink.return_value = True
1360         try:
1361             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1362         except ValueError as ve:
1363             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1364         path.iterdir.assert_called_once()
1365         child.resolve.assert_called_once()
1366         child.is_symlink.assert_called_once()
1367         # `child` should behave like a strange file which resolved path is clearly
1368         # outside of the `root` directory.
1369         child.is_symlink.return_value = False
1370         with self.assertRaises(ValueError):
1371             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1372         path.iterdir.assert_called()
1373         self.assertEqual(path.iterdir.call_count, 2)
1374         child.resolve.assert_called()
1375         self.assertEqual(child.resolve.call_count, 2)
1376         child.is_symlink.assert_called()
1377         self.assertEqual(child.is_symlink.call_count, 2)
1378
1379     def test_shhh_click(self) -> None:
1380         try:
1381             from click import _unicodefun  # type: ignore
1382         except ModuleNotFoundError:
1383             self.skipTest("Incompatible Click version")
1384         if not hasattr(_unicodefun, "_verify_python3_env"):
1385             self.skipTest("Incompatible Click version")
1386         # First, let's see if Click is crashing with a preferred ASCII charset.
1387         with patch("locale.getpreferredencoding") as gpe:
1388             gpe.return_value = "ASCII"
1389             with self.assertRaises(RuntimeError):
1390                 _unicodefun._verify_python3_env()
1391         # Now, let's silence Click...
1392         black.patch_click()
1393         # ...and confirm it's silent.
1394         with patch("locale.getpreferredencoding") as gpe:
1395             gpe.return_value = "ASCII"
1396             try:
1397                 _unicodefun._verify_python3_env()
1398             except RuntimeError as re:
1399                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1400
1401     def test_root_logger_not_used_directly(self) -> None:
1402         def fail(*args: Any, **kwargs: Any) -> None:
1403             self.fail("Record created with root logger")
1404
1405         with patch.multiple(
1406             logging.root,
1407             debug=fail,
1408             info=fail,
1409             warning=fail,
1410             error=fail,
1411             critical=fail,
1412             log=fail,
1413         ):
1414             ff(THIS_FILE)
1415
1416     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1417     @async_test
1418     async def test_blackd_request_needs_formatting(self) -> None:
1419         app = blackd.make_app()
1420         async with TestClient(TestServer(app)) as client:
1421             response = await client.post("/", data=b"print('hello world')")
1422             self.assertEqual(response.status, 200)
1423             self.assertEqual(response.charset, "utf8")
1424             self.assertEqual(await response.read(), b'print("hello world")\n')
1425
1426     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1427     @async_test
1428     async def test_blackd_request_no_change(self) -> None:
1429         app = blackd.make_app()
1430         async with TestClient(TestServer(app)) as client:
1431             response = await client.post("/", data=b'print("hello world")\n')
1432             self.assertEqual(response.status, 204)
1433             self.assertEqual(await response.read(), b"")
1434
1435     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1436     @async_test
1437     async def test_blackd_request_syntax_error(self) -> None:
1438         app = blackd.make_app()
1439         async with TestClient(TestServer(app)) as client:
1440             response = await client.post("/", data=b"what even ( is")
1441             self.assertEqual(response.status, 400)
1442             content = await response.text()
1443             self.assertTrue(
1444                 content.startswith("Cannot parse"),
1445                 msg=f"Expected error to start with 'Cannot parse', got {repr(content)}",
1446             )
1447
1448     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1449     @async_test
1450     async def test_blackd_unsupported_version(self) -> None:
1451         app = blackd.make_app()
1452         async with TestClient(TestServer(app)) as client:
1453             response = await client.post(
1454                 "/", data=b"what", headers={blackd.VERSION_HEADER: "2"}
1455             )
1456             self.assertEqual(response.status, 501)
1457
1458     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1459     @async_test
1460     async def test_blackd_supported_version(self) -> None:
1461         app = blackd.make_app()
1462         async with TestClient(TestServer(app)) as client:
1463             response = await client.post(
1464                 "/", data=b"what", headers={blackd.VERSION_HEADER: "1"}
1465             )
1466             self.assertEqual(response.status, 200)
1467
1468     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1469     @async_test
1470     async def test_blackd_invalid_python_variant(self) -> None:
1471         app = blackd.make_app()
1472         async with TestClient(TestServer(app)) as client:
1473
1474             async def check(header_value: str, expected_status: int = 400) -> None:
1475                 response = await client.post(
1476                     "/",
1477                     data=b"what",
1478                     headers={blackd.PYTHON_VARIANT_HEADER: header_value},
1479                 )
1480                 self.assertEqual(response.status, expected_status)
1481
1482             await check("lol")
1483             await check("ruby3.5")
1484             await check("pyi3.6")
1485             await check("py1.5")
1486             await check("2.8")
1487             await check("py2.8")
1488             await check("3.0")
1489             await check("pypy3.0")
1490             await check("jython3.4")
1491
1492     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1493     @async_test
1494     async def test_blackd_pyi(self) -> None:
1495         app = blackd.make_app()
1496         async with TestClient(TestServer(app)) as client:
1497             source, expected = read_data("stub.pyi")
1498             response = await client.post(
1499                 "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
1500             )
1501             self.assertEqual(response.status, 200)
1502             self.assertEqual(await response.text(), expected)
1503
1504     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1505     @async_test
1506     async def test_blackd_python_variant(self) -> None:
1507         app = blackd.make_app()
1508         code = (
1509             "def f(\n"
1510             "    and_has_a_bunch_of,\n"
1511             "    very_long_arguments_too,\n"
1512             "    and_lots_of_them_as_well_lol,\n"
1513             "    **and_very_long_keyword_arguments\n"
1514             "):\n"
1515             "    pass\n"
1516         )
1517         async with TestClient(TestServer(app)) as client:
1518
1519             async def check(header_value: str, expected_status: int) -> None:
1520                 response = await client.post(
1521                     "/", data=code, headers={blackd.PYTHON_VARIANT_HEADER: header_value}
1522                 )
1523                 self.assertEqual(response.status, expected_status)
1524
1525             await check("3.6", 200)
1526             await check("py3.6", 200)
1527             await check("3.5,3.7", 200)
1528             await check("3.5,py3.7", 200)
1529
1530             await check("2", 204)
1531             await check("2.7", 204)
1532             await check("py2.7", 204)
1533             await check("3.4", 204)
1534             await check("py3.4", 204)
1535
1536     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1537     @async_test
1538     async def test_blackd_fast(self) -> None:
1539         with open(os.devnull, "w") as dn, redirect_stderr(dn):
1540             app = blackd.make_app()
1541             async with TestClient(TestServer(app)) as client:
1542                 response = await client.post("/", data=b"ur'hello'")
1543                 self.assertEqual(response.status, 500)
1544                 self.assertIn("failed to parse source file", await response.text())
1545                 response = await client.post(
1546                     "/", data=b"ur'hello'", headers={blackd.FAST_OR_SAFE_HEADER: "fast"}
1547                 )
1548                 self.assertEqual(response.status, 200)
1549
1550     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1551     @async_test
1552     async def test_blackd_line_length(self) -> None:
1553         app = blackd.make_app()
1554         async with TestClient(TestServer(app)) as client:
1555             response = await client.post(
1556                 "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "7"}
1557             )
1558             self.assertEqual(response.status, 200)
1559
1560     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1561     @async_test
1562     async def test_blackd_invalid_line_length(self) -> None:
1563         app = blackd.make_app()
1564         async with TestClient(TestServer(app)) as client:
1565             response = await client.post(
1566                 "/",
1567                 data=b'print("hello")\n',
1568                 headers={blackd.LINE_LENGTH_HEADER: "NaN"},
1569             )
1570             self.assertEqual(response.status, 400)
1571
1572     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1573     def test_blackd_main(self) -> None:
1574         with patch("blackd.web.run_app"):
1575             result = CliRunner().invoke(blackd.main, [])
1576             if result.exception is not None:
1577                 raise result.exception
1578             self.assertEqual(result.exit_code, 0)
1579
1580
if __name__ == "__main__":
    # Allow running this file directly; tests are loaded from the module
    # named "test_black" rather than from `__main__`.
    unittest.main(module="test_black")