git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

ambv/black -> python/black (#819)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 import logging
4 from concurrent.futures import ThreadPoolExecutor
5 from contextlib import contextmanager, redirect_stderr
6 from functools import partial, wraps
7 from io import BytesIO, TextIOWrapper
8 import os
9 from pathlib import Path
10 import re
11 import sys
12 from tempfile import TemporaryDirectory
13 from typing import (
14     Any,
15     BinaryIO,
16     Callable,
17     Coroutine,
18     Generator,
19     List,
20     Tuple,
21     Iterator,
22     TypeVar,
23 )
24 import unittest
25 from unittest.mock import patch, MagicMock
26
27 from click import unstyle
28 from click.testing import CliRunner
29
30 import black
31 from black import Feature, TargetVersion
32
33 try:
34     import blackd
35     from aiohttp.test_utils import TestClient, TestServer
36 except ImportError:
37     has_blackd_deps = False
38 else:
39     has_blackd_deps = True
40
# Shorthands used throughout the tests: ``ff`` reformats a file in place (fast
# mode) and ``fs`` reformats a string, both with a default FileMode.
ff = partial(black.format_file_in_place, fast=True, mode=black.FileMode())
fs = partial(black.format_str, mode=black.FileMode())
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# Deliberately assembled from two pieces: read_data() deletes this marker from
# every line it reads, and test_self reads this very file through read_data(),
# so the complete marker must never appear verbatim in this source.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
PY36_ARGS = [f"--target-version={v.name.lower()}" for v in black.PY36_VERSIONS]
T = TypeVar("T")
R = TypeVar("R")
51
52
def dump_to_stderr(*output: str) -> str:
    """Return *output* joined by newlines, with one newline before and after.

    Tests patch this in place of ``black.dump_to_file`` so that debug dumps
    come back as a string instead of being written to a temporary file.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
55
56
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """Return the ``(input, expected_output)`` pair stored in test file *name*.

    ``.py`` is appended unless *name* already ends in ``.py``, ``.pyi``,
    ``.out`` or ``.diff``. The file is looked up in ``THIS_DIR / "data"``, or
    directly in ``THIS_DIR`` when *data* is false. Lines after a line reading
    ``# output`` form the expected output; if there is no such line, the input
    doubles as the expected output. The ``EMPTY_LINE`` marker is removed from
    every line, and each half is stripped and terminated by one newline.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    root = THIS_DIR / "data" if data else THIS_DIR
    with open(root / name, "r", encoding="utf8") as test:
        raw_lines = test.readlines()

    source_lines: List[str] = []
    expected_lines: List[str] = []
    target = source_lines
    for raw in raw_lines:
        cleaned = raw.replace(EMPTY_LINE, "")
        if cleaned.rstrip() == "# output":
            # Everything after the marker belongs to the expected output.
            target = expected_lines
        else:
            target.append(cleaned)

    if source_lines and not expected_lines:
        # No output section: the input is already in its formatted form.
        expected_lines = list(source_lines)
    source = "".join(source_lines).strip() + "\n"
    expected = "".join(expected_lines).strip() + "\n"
    return source, expected
78
79
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point ``black.CACHE_DIR`` at a throwaway directory for the block's duration.

    With *exists* false the patched path is a ``new`` subdirectory of the
    temporary directory that has not been created. The temporary directory is
    removed on exit, and the yielded path is the one patched in.
    """
    with TemporaryDirectory() as workspace:
        base = Path(workspace)
        target = base if exists else base / "new"
        with patch("black.CACHE_DIR", target):
            yield target
88
89
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Install a brand-new asyncio event loop for the duration of the block.

    The loop that was current beforehand is reinstated on exit. The new loop
    is closed afterwards only if *close* is true.
    """
    policy = asyncio.get_event_loop_policy()
    previous = policy.get_event_loop()
    fresh = policy.new_event_loop()
    asyncio.set_event_loop(fresh)
    try:
        yield
    finally:
        policy.set_event_loop(previous)
        if close:
            fresh.close()
103
104
def async_test(f: Callable[..., Coroutine[Any, None, R]]) -> Callable[..., None]:
    """Turn coroutine test function *f* into a plain synchronous callable.

    Each call runs ``f(*args, **kwargs)`` to completion inside ``event_loop``
    with ``close=True``, i.e. on a new loop that is closed afterwards.
    """

    @event_loop(close=True)
    @wraps(f)
    def run_sync(*args: Any, **kwargs: Any) -> None:
        coro = f(*args, **kwargs)
        asyncio.get_event_loop().run_until_complete(coro)

    return run_sync
112
113
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    This is a hack that can be removed once we depend on Click 7.x

    After an ``invoke`` the captured output is available as raw bytes in
    ``stdout_bytes`` and ``stderr_bytes``.
    """

    def __init__(self) -> None:
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        with super().isolation(*args, **kwargs) as output:
            # Save the real stderr before touching anything, so the finally
            # clause never runs against a stream we did not install.
            hold_stderr = sys.stderr
            stderr_wrapper = TextIOWrapper(self.stderrbuf, encoding=self.charset)
            sys.stderr = stderr_wrapper
            try:
                yield output
            finally:
                # Restore first so stderr is put back even if reading fails.
                sys.stderr = hold_stderr
                # TextIOWrapper holds written text in an internal buffer until
                # flushed; without flushing, getvalue() can miss output.
                sys.stdout.flush()
                stderr_wrapper.flush()
                self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                # Read our own buffer rather than whatever sys.stderr is now.
                self.stderr_bytes = self.stderrbuf.getvalue()
137
138
139 class BlackTestCase(unittest.TestCase):
140     maxDiff = None
141
142     def assertFormatEqual(self, expected: str, actual: str) -> None:
143         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
144             bdv: black.DebugVisitor[Any]
145             black.out("Expected tree:", fg="green")
146             try:
147                 exp_node = black.lib2to3_parse(expected)
148                 bdv = black.DebugVisitor()
149                 list(bdv.visit(exp_node))
150             except Exception as ve:
151                 black.err(str(ve))
152             black.out("Actual tree:", fg="red")
153             try:
154                 exp_node = black.lib2to3_parse(actual)
155                 bdv = black.DebugVisitor()
156                 list(bdv.visit(exp_node))
157             except Exception as ve:
158                 black.err(str(ve))
159         self.assertEqual(expected, actual)
160
161     def invokeBlack(
162         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
163     ) -> None:
164         runner = BlackRunner()
165         if ignore_config:
166             args = ["--config", str(THIS_DIR / "empty.toml"), *args]
167         result = runner.invoke(black.main, args)
168         self.assertEqual(result.exit_code, exit_code, msg=runner.stderr_bytes.decode())
169
170     @patch("black.dump_to_file", dump_to_stderr)
171     def test_empty(self) -> None:
172         source = expected = ""
173         actual = fs(source)
174         self.assertFormatEqual(expected, actual)
175         black.assert_equivalent(source, actual)
176         black.assert_stable(source, actual, black.FileMode())
177
178     def test_empty_ff(self) -> None:
179         expected = ""
180         tmp_file = Path(black.dump_to_file())
181         try:
182             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
183             with open(tmp_file, encoding="utf8") as f:
184                 actual = f.read()
185         finally:
186             os.unlink(tmp_file)
187         self.assertFormatEqual(expected, actual)
188
189     @patch("black.dump_to_file", dump_to_stderr)
190     def test_self(self) -> None:
191         source, expected = read_data("test_black", data=False)
192         actual = fs(source)
193         self.assertFormatEqual(expected, actual)
194         black.assert_equivalent(source, actual)
195         black.assert_stable(source, actual, black.FileMode())
196         self.assertFalse(ff(THIS_FILE))
197
198     @patch("black.dump_to_file", dump_to_stderr)
199     def test_black(self) -> None:
200         source, expected = read_data("../black", data=False)
201         actual = fs(source)
202         self.assertFormatEqual(expected, actual)
203         black.assert_equivalent(source, actual)
204         black.assert_stable(source, actual, black.FileMode())
205         self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
206
207     def test_piping(self) -> None:
208         source, expected = read_data("../black", data=False)
209         result = BlackRunner().invoke(
210             black.main,
211             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
212             input=BytesIO(source.encode("utf8")),
213         )
214         self.assertEqual(result.exit_code, 0)
215         self.assertFormatEqual(expected, result.output)
216         black.assert_equivalent(source, result.output)
217         black.assert_stable(source, result.output, black.FileMode())
218
219     def test_piping_diff(self) -> None:
220         diff_header = re.compile(
221             rf"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d "
222             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
223         )
224         source, _ = read_data("expression.py")
225         expected, _ = read_data("expression.diff")
226         config = THIS_DIR / "data" / "empty_pyproject.toml"
227         args = [
228             "-",
229             "--fast",
230             f"--line-length={black.DEFAULT_LINE_LENGTH}",
231             "--diff",
232             f"--config={config}",
233         ]
234         result = BlackRunner().invoke(
235             black.main, args, input=BytesIO(source.encode("utf8"))
236         )
237         self.assertEqual(result.exit_code, 0)
238         actual = diff_header.sub("[Deterministic header]", result.output)
239         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
240         self.assertEqual(expected, actual)
241
242     @patch("black.dump_to_file", dump_to_stderr)
243     def test_setup(self) -> None:
244         source, expected = read_data("../setup", data=False)
245         actual = fs(source)
246         self.assertFormatEqual(expected, actual)
247         black.assert_equivalent(source, actual)
248         black.assert_stable(source, actual, black.FileMode())
249         self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
250
251     @patch("black.dump_to_file", dump_to_stderr)
252     def test_function(self) -> None:
253         source, expected = read_data("function")
254         actual = fs(source)
255         self.assertFormatEqual(expected, actual)
256         black.assert_equivalent(source, actual)
257         black.assert_stable(source, actual, black.FileMode())
258
259     @patch("black.dump_to_file", dump_to_stderr)
260     def test_function2(self) -> None:
261         source, expected = read_data("function2")
262         actual = fs(source)
263         self.assertFormatEqual(expected, actual)
264         black.assert_equivalent(source, actual)
265         black.assert_stable(source, actual, black.FileMode())
266
267     @patch("black.dump_to_file", dump_to_stderr)
268     def test_expression(self) -> None:
269         source, expected = read_data("expression")
270         actual = fs(source)
271         self.assertFormatEqual(expected, actual)
272         black.assert_equivalent(source, actual)
273         black.assert_stable(source, actual, black.FileMode())
274
275     def test_expression_ff(self) -> None:
276         source, expected = read_data("expression")
277         tmp_file = Path(black.dump_to_file(source))
278         try:
279             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
280             with open(tmp_file, encoding="utf8") as f:
281                 actual = f.read()
282         finally:
283             os.unlink(tmp_file)
284         self.assertFormatEqual(expected, actual)
285         with patch("black.dump_to_file", dump_to_stderr):
286             black.assert_equivalent(source, actual)
287             black.assert_stable(source, actual, black.FileMode())
288
289     def test_expression_diff(self) -> None:
290         source, _ = read_data("expression.py")
291         expected, _ = read_data("expression.diff")
292         tmp_file = Path(black.dump_to_file(source))
293         diff_header = re.compile(
294             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
295             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
296         )
297         try:
298             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
299             self.assertEqual(result.exit_code, 0)
300         finally:
301             os.unlink(tmp_file)
302         actual = result.output
303         actual = diff_header.sub("[Deterministic header]", actual)
304         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
305         if expected != actual:
306             dump = black.dump_to_file(actual)
307             msg = (
308                 f"Expected diff isn't equal to the actual. If you made changes "
309                 f"to expression.py and this is an anticipated difference, "
310                 f"overwrite tests/data/expression.diff with {dump}"
311             )
312             self.assertEqual(expected, actual, msg)
313
314     @patch("black.dump_to_file", dump_to_stderr)
315     def test_fstring(self) -> None:
316         source, expected = read_data("fstring")
317         actual = fs(source)
318         self.assertFormatEqual(expected, actual)
319         black.assert_equivalent(source, actual)
320         black.assert_stable(source, actual, black.FileMode())
321
322     @patch("black.dump_to_file", dump_to_stderr)
323     def test_string_quotes(self) -> None:
324         source, expected = read_data("string_quotes")
325         actual = fs(source)
326         self.assertFormatEqual(expected, actual)
327         black.assert_equivalent(source, actual)
328         black.assert_stable(source, actual, black.FileMode())
329         mode = black.FileMode(string_normalization=False)
330         not_normalized = fs(source, mode=mode)
331         self.assertFormatEqual(source, not_normalized)
332         black.assert_equivalent(source, not_normalized)
333         black.assert_stable(source, not_normalized, mode=mode)
334
335     @patch("black.dump_to_file", dump_to_stderr)
336     def test_slices(self) -> None:
337         source, expected = read_data("slices")
338         actual = fs(source)
339         self.assertFormatEqual(expected, actual)
340         black.assert_equivalent(source, actual)
341         black.assert_stable(source, actual, black.FileMode())
342
343     @patch("black.dump_to_file", dump_to_stderr)
344     def test_comments(self) -> None:
345         source, expected = read_data("comments")
346         actual = fs(source)
347         self.assertFormatEqual(expected, actual)
348         black.assert_equivalent(source, actual)
349         black.assert_stable(source, actual, black.FileMode())
350
351     @patch("black.dump_to_file", dump_to_stderr)
352     def test_comments2(self) -> None:
353         source, expected = read_data("comments2")
354         actual = fs(source)
355         self.assertFormatEqual(expected, actual)
356         black.assert_equivalent(source, actual)
357         black.assert_stable(source, actual, black.FileMode())
358
359     @patch("black.dump_to_file", dump_to_stderr)
360     def test_comments3(self) -> None:
361         source, expected = read_data("comments3")
362         actual = fs(source)
363         self.assertFormatEqual(expected, actual)
364         black.assert_equivalent(source, actual)
365         black.assert_stable(source, actual, black.FileMode())
366
367     @patch("black.dump_to_file", dump_to_stderr)
368     def test_comments4(self) -> None:
369         source, expected = read_data("comments4")
370         actual = fs(source)
371         self.assertFormatEqual(expected, actual)
372         black.assert_equivalent(source, actual)
373         black.assert_stable(source, actual, black.FileMode())
374
375     @patch("black.dump_to_file", dump_to_stderr)
376     def test_comments5(self) -> None:
377         source, expected = read_data("comments5")
378         actual = fs(source)
379         self.assertFormatEqual(expected, actual)
380         black.assert_equivalent(source, actual)
381         black.assert_stable(source, actual, black.FileMode())
382
383     @patch("black.dump_to_file", dump_to_stderr)
384     def test_comments6(self) -> None:
385         source, expected = read_data("comments6")
386         actual = fs(source)
387         self.assertFormatEqual(expected, actual)
388         black.assert_equivalent(source, actual)
389         black.assert_stable(source, actual, black.FileMode())
390
391     @patch("black.dump_to_file", dump_to_stderr)
392     def test_cantfit(self) -> None:
393         source, expected = read_data("cantfit")
394         actual = fs(source)
395         self.assertFormatEqual(expected, actual)
396         black.assert_equivalent(source, actual)
397         black.assert_stable(source, actual, black.FileMode())
398
399     @patch("black.dump_to_file", dump_to_stderr)
400     def test_import_spacing(self) -> None:
401         source, expected = read_data("import_spacing")
402         actual = fs(source)
403         self.assertFormatEqual(expected, actual)
404         black.assert_equivalent(source, actual)
405         black.assert_stable(source, actual, black.FileMode())
406
407     @patch("black.dump_to_file", dump_to_stderr)
408     def test_composition(self) -> None:
409         source, expected = read_data("composition")
410         actual = fs(source)
411         self.assertFormatEqual(expected, actual)
412         black.assert_equivalent(source, actual)
413         black.assert_stable(source, actual, black.FileMode())
414
415     @patch("black.dump_to_file", dump_to_stderr)
416     def test_empty_lines(self) -> None:
417         source, expected = read_data("empty_lines")
418         actual = fs(source)
419         self.assertFormatEqual(expected, actual)
420         black.assert_equivalent(source, actual)
421         black.assert_stable(source, actual, black.FileMode())
422
423     @patch("black.dump_to_file", dump_to_stderr)
424     def test_string_prefixes(self) -> None:
425         source, expected = read_data("string_prefixes")
426         actual = fs(source)
427         self.assertFormatEqual(expected, actual)
428         black.assert_equivalent(source, actual)
429         black.assert_stable(source, actual, black.FileMode())
430
431     @patch("black.dump_to_file", dump_to_stderr)
432     def test_numeric_literals(self) -> None:
433         source, expected = read_data("numeric_literals")
434         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
435         actual = fs(source, mode=mode)
436         self.assertFormatEqual(expected, actual)
437         black.assert_equivalent(source, actual)
438         black.assert_stable(source, actual, mode)
439
440     @patch("black.dump_to_file", dump_to_stderr)
441     def test_numeric_literals_ignoring_underscores(self) -> None:
442         source, expected = read_data("numeric_literals_skip_underscores")
443         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
444         actual = fs(source, mode=mode)
445         self.assertFormatEqual(expected, actual)
446         black.assert_equivalent(source, actual)
447         black.assert_stable(source, actual, mode)
448
449     @patch("black.dump_to_file", dump_to_stderr)
450     def test_numeric_literals_py2(self) -> None:
451         source, expected = read_data("numeric_literals_py2")
452         actual = fs(source)
453         self.assertFormatEqual(expected, actual)
454         black.assert_stable(source, actual, black.FileMode())
455
456     @patch("black.dump_to_file", dump_to_stderr)
457     def test_python2(self) -> None:
458         source, expected = read_data("python2")
459         actual = fs(source)
460         self.assertFormatEqual(expected, actual)
461         # black.assert_equivalent(source, actual)
462         black.assert_stable(source, actual, black.FileMode())
463
464     @patch("black.dump_to_file", dump_to_stderr)
465     def test_python2_print_function(self) -> None:
466         source, expected = read_data("python2_print_function")
467         mode = black.FileMode(target_versions={TargetVersion.PY27})
468         actual = fs(source, mode=mode)
469         self.assertFormatEqual(expected, actual)
470         black.assert_stable(source, actual, mode)
471
472     @patch("black.dump_to_file", dump_to_stderr)
473     def test_python2_unicode_literals(self) -> None:
474         source, expected = read_data("python2_unicode_literals")
475         actual = fs(source)
476         self.assertFormatEqual(expected, actual)
477         black.assert_stable(source, actual, black.FileMode())
478
479     @patch("black.dump_to_file", dump_to_stderr)
480     def test_stub(self) -> None:
481         mode = black.FileMode(is_pyi=True)
482         source, expected = read_data("stub.pyi")
483         actual = fs(source, mode=mode)
484         self.assertFormatEqual(expected, actual)
485         black.assert_stable(source, actual, mode)
486
487     @patch("black.dump_to_file", dump_to_stderr)
488     def test_python37(self) -> None:
489         source, expected = read_data("python37")
490         actual = fs(source)
491         self.assertFormatEqual(expected, actual)
492         major, minor = sys.version_info[:2]
493         if major > 3 or (major == 3 and minor >= 7):
494             black.assert_equivalent(source, actual)
495         black.assert_stable(source, actual, black.FileMode())
496
497     @patch("black.dump_to_file", dump_to_stderr)
498     def test_fmtonoff(self) -> None:
499         source, expected = read_data("fmtonoff")
500         actual = fs(source)
501         self.assertFormatEqual(expected, actual)
502         black.assert_equivalent(source, actual)
503         black.assert_stable(source, actual, black.FileMode())
504
505     @patch("black.dump_to_file", dump_to_stderr)
506     def test_fmtonoff2(self) -> None:
507         source, expected = read_data("fmtonoff2")
508         actual = fs(source)
509         self.assertFormatEqual(expected, actual)
510         black.assert_equivalent(source, actual)
511         black.assert_stable(source, actual, black.FileMode())
512
513     @patch("black.dump_to_file", dump_to_stderr)
514     def test_remove_empty_parentheses_after_class(self) -> None:
515         source, expected = read_data("class_blank_parentheses")
516         actual = fs(source)
517         self.assertFormatEqual(expected, actual)
518         black.assert_equivalent(source, actual)
519         black.assert_stable(source, actual, black.FileMode())
520
521     @patch("black.dump_to_file", dump_to_stderr)
522     def test_new_line_between_class_and_code(self) -> None:
523         source, expected = read_data("class_methods_new_line")
524         actual = fs(source)
525         self.assertFormatEqual(expected, actual)
526         black.assert_equivalent(source, actual)
527         black.assert_stable(source, actual, black.FileMode())
528
529     @patch("black.dump_to_file", dump_to_stderr)
530     def test_bracket_match(self) -> None:
531         source, expected = read_data("bracketmatch")
532         actual = fs(source)
533         self.assertFormatEqual(expected, actual)
534         black.assert_equivalent(source, actual)
535         black.assert_stable(source, actual, black.FileMode())
536
537     def test_tab_comment_indentation(self) -> None:
538         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
539         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
540         self.assertFormatEqual(contents_spc, fs(contents_spc))
541         self.assertFormatEqual(contents_spc, fs(contents_tab))
542
543         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
544         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
545         self.assertFormatEqual(contents_spc, fs(contents_spc))
546         self.assertFormatEqual(contents_spc, fs(contents_tab))
547
548         # mixed tabs and spaces (valid Python 2 code)
549         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
550         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
551         self.assertFormatEqual(contents_spc, fs(contents_spc))
552         self.assertFormatEqual(contents_spc, fs(contents_tab))
553
554         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
555         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
556         self.assertFormatEqual(contents_spc, fs(contents_spc))
557         self.assertFormatEqual(contents_spc, fs(contents_tab))
558
559     def test_report_verbose(self) -> None:
560         report = black.Report(verbose=True)
561         out_lines = []
562         err_lines = []
563
564         def out(msg: str, **kwargs: Any) -> None:
565             out_lines.append(msg)
566
567         def err(msg: str, **kwargs: Any) -> None:
568             err_lines.append(msg)
569
570         with patch("black.out", out), patch("black.err", err):
571             report.done(Path("f1"), black.Changed.NO)
572             self.assertEqual(len(out_lines), 1)
573             self.assertEqual(len(err_lines), 0)
574             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
575             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
576             self.assertEqual(report.return_code, 0)
577             report.done(Path("f2"), black.Changed.YES)
578             self.assertEqual(len(out_lines), 2)
579             self.assertEqual(len(err_lines), 0)
580             self.assertEqual(out_lines[-1], "reformatted f2")
581             self.assertEqual(
582                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
583             )
584             report.done(Path("f3"), black.Changed.CACHED)
585             self.assertEqual(len(out_lines), 3)
586             self.assertEqual(len(err_lines), 0)
587             self.assertEqual(
588                 out_lines[-1], "f3 wasn't modified on disk since last run."
589             )
590             self.assertEqual(
591                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
592             )
593             self.assertEqual(report.return_code, 0)
594             report.check = True
595             self.assertEqual(report.return_code, 1)
596             report.check = False
597             report.failed(Path("e1"), "boom")
598             self.assertEqual(len(out_lines), 3)
599             self.assertEqual(len(err_lines), 1)
600             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
601             self.assertEqual(
602                 unstyle(str(report)),
603                 "1 file reformatted, 2 files left unchanged, "
604                 "1 file failed to reformat.",
605             )
606             self.assertEqual(report.return_code, 123)
607             report.done(Path("f3"), black.Changed.YES)
608             self.assertEqual(len(out_lines), 4)
609             self.assertEqual(len(err_lines), 1)
610             self.assertEqual(out_lines[-1], "reformatted f3")
611             self.assertEqual(
612                 unstyle(str(report)),
613                 "2 files reformatted, 2 files left unchanged, "
614                 "1 file failed to reformat.",
615             )
616             self.assertEqual(report.return_code, 123)
617             report.failed(Path("e2"), "boom")
618             self.assertEqual(len(out_lines), 4)
619             self.assertEqual(len(err_lines), 2)
620             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
621             self.assertEqual(
622                 unstyle(str(report)),
623                 "2 files reformatted, 2 files left unchanged, "
624                 "2 files failed to reformat.",
625             )
626             self.assertEqual(report.return_code, 123)
627             report.path_ignored(Path("wat"), "no match")
628             self.assertEqual(len(out_lines), 5)
629             self.assertEqual(len(err_lines), 2)
630             self.assertEqual(out_lines[-1], "wat ignored: no match")
631             self.assertEqual(
632                 unstyle(str(report)),
633                 "2 files reformatted, 2 files left unchanged, "
634                 "2 files failed to reformat.",
635             )
636             self.assertEqual(report.return_code, 123)
637             report.done(Path("f4"), black.Changed.NO)
638             self.assertEqual(len(out_lines), 6)
639             self.assertEqual(len(err_lines), 2)
640             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
641             self.assertEqual(
642                 unstyle(str(report)),
643                 "2 files reformatted, 3 files left unchanged, "
644                 "2 files failed to reformat.",
645             )
646             self.assertEqual(report.return_code, 123)
647             report.check = True
648             self.assertEqual(
649                 unstyle(str(report)),
650                 "2 files would be reformatted, 3 files would be left unchanged, "
651                 "2 files would fail to reformat.",
652             )
653
654     def test_report_quiet(self) -> None:
655         report = black.Report(quiet=True)
656         out_lines = []
657         err_lines = []
658
659         def out(msg: str, **kwargs: Any) -> None:
660             out_lines.append(msg)
661
662         def err(msg: str, **kwargs: Any) -> None:
663             err_lines.append(msg)
664
665         with patch("black.out", out), patch("black.err", err):
666             report.done(Path("f1"), black.Changed.NO)
667             self.assertEqual(len(out_lines), 0)
668             self.assertEqual(len(err_lines), 0)
669             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
670             self.assertEqual(report.return_code, 0)
671             report.done(Path("f2"), black.Changed.YES)
672             self.assertEqual(len(out_lines), 0)
673             self.assertEqual(len(err_lines), 0)
674             self.assertEqual(
675                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
676             )
677             report.done(Path("f3"), black.Changed.CACHED)
678             self.assertEqual(len(out_lines), 0)
679             self.assertEqual(len(err_lines), 0)
680             self.assertEqual(
681                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
682             )
683             self.assertEqual(report.return_code, 0)
684             report.check = True
685             self.assertEqual(report.return_code, 1)
686             report.check = False
687             report.failed(Path("e1"), "boom")
688             self.assertEqual(len(out_lines), 0)
689             self.assertEqual(len(err_lines), 1)
690             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
691             self.assertEqual(
692                 unstyle(str(report)),
693                 "1 file reformatted, 2 files left unchanged, "
694                 "1 file failed to reformat.",
695             )
696             self.assertEqual(report.return_code, 123)
697             report.done(Path("f3"), black.Changed.YES)
698             self.assertEqual(len(out_lines), 0)
699             self.assertEqual(len(err_lines), 1)
700             self.assertEqual(
701                 unstyle(str(report)),
702                 "2 files reformatted, 2 files left unchanged, "
703                 "1 file failed to reformat.",
704             )
705             self.assertEqual(report.return_code, 123)
706             report.failed(Path("e2"), "boom")
707             self.assertEqual(len(out_lines), 0)
708             self.assertEqual(len(err_lines), 2)
709             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
710             self.assertEqual(
711                 unstyle(str(report)),
712                 "2 files reformatted, 2 files left unchanged, "
713                 "2 files failed to reformat.",
714             )
715             self.assertEqual(report.return_code, 123)
716             report.path_ignored(Path("wat"), "no match")
717             self.assertEqual(len(out_lines), 0)
718             self.assertEqual(len(err_lines), 2)
719             self.assertEqual(
720                 unstyle(str(report)),
721                 "2 files reformatted, 2 files left unchanged, "
722                 "2 files failed to reformat.",
723             )
724             self.assertEqual(report.return_code, 123)
725             report.done(Path("f4"), black.Changed.NO)
726             self.assertEqual(len(out_lines), 0)
727             self.assertEqual(len(err_lines), 2)
728             self.assertEqual(
729                 unstyle(str(report)),
730                 "2 files reformatted, 3 files left unchanged, "
731                 "2 files failed to reformat.",
732             )
733             self.assertEqual(report.return_code, 123)
734             report.check = True
735             self.assertEqual(
736                 unstyle(str(report)),
737                 "2 files would be reformatted, 3 files would be left unchanged, "
738                 "2 files would fail to reformat.",
739             )
740
741     def test_report_normal(self) -> None:
742         report = black.Report()
743         out_lines = []
744         err_lines = []
745
746         def out(msg: str, **kwargs: Any) -> None:
747             out_lines.append(msg)
748
749         def err(msg: str, **kwargs: Any) -> None:
750             err_lines.append(msg)
751
752         with patch("black.out", out), patch("black.err", err):
753             report.done(Path("f1"), black.Changed.NO)
754             self.assertEqual(len(out_lines), 0)
755             self.assertEqual(len(err_lines), 0)
756             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
757             self.assertEqual(report.return_code, 0)
758             report.done(Path("f2"), black.Changed.YES)
759             self.assertEqual(len(out_lines), 1)
760             self.assertEqual(len(err_lines), 0)
761             self.assertEqual(out_lines[-1], "reformatted f2")
762             self.assertEqual(
763                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
764             )
765             report.done(Path("f3"), black.Changed.CACHED)
766             self.assertEqual(len(out_lines), 1)
767             self.assertEqual(len(err_lines), 0)
768             self.assertEqual(out_lines[-1], "reformatted f2")
769             self.assertEqual(
770                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
771             )
772             self.assertEqual(report.return_code, 0)
773             report.check = True
774             self.assertEqual(report.return_code, 1)
775             report.check = False
776             report.failed(Path("e1"), "boom")
777             self.assertEqual(len(out_lines), 1)
778             self.assertEqual(len(err_lines), 1)
779             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
780             self.assertEqual(
781                 unstyle(str(report)),
782                 "1 file reformatted, 2 files left unchanged, "
783                 "1 file failed to reformat.",
784             )
785             self.assertEqual(report.return_code, 123)
786             report.done(Path("f3"), black.Changed.YES)
787             self.assertEqual(len(out_lines), 2)
788             self.assertEqual(len(err_lines), 1)
789             self.assertEqual(out_lines[-1], "reformatted f3")
790             self.assertEqual(
791                 unstyle(str(report)),
792                 "2 files reformatted, 2 files left unchanged, "
793                 "1 file failed to reformat.",
794             )
795             self.assertEqual(report.return_code, 123)
796             report.failed(Path("e2"), "boom")
797             self.assertEqual(len(out_lines), 2)
798             self.assertEqual(len(err_lines), 2)
799             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
800             self.assertEqual(
801                 unstyle(str(report)),
802                 "2 files reformatted, 2 files left unchanged, "
803                 "2 files failed to reformat.",
804             )
805             self.assertEqual(report.return_code, 123)
806             report.path_ignored(Path("wat"), "no match")
807             self.assertEqual(len(out_lines), 2)
808             self.assertEqual(len(err_lines), 2)
809             self.assertEqual(
810                 unstyle(str(report)),
811                 "2 files reformatted, 2 files left unchanged, "
812                 "2 files failed to reformat.",
813             )
814             self.assertEqual(report.return_code, 123)
815             report.done(Path("f4"), black.Changed.NO)
816             self.assertEqual(len(out_lines), 2)
817             self.assertEqual(len(err_lines), 2)
818             self.assertEqual(
819                 unstyle(str(report)),
820                 "2 files reformatted, 3 files left unchanged, "
821                 "2 files failed to reformat.",
822             )
823             self.assertEqual(report.return_code, 123)
824             report.check = True
825             self.assertEqual(
826                 unstyle(str(report)),
827                 "2 files would be reformatted, 3 files would be left unchanged, "
828                 "2 files would fail to reformat.",
829             )
830
831     def test_lib2to3_parse(self) -> None:
832         with self.assertRaises(black.InvalidInput):
833             black.lib2to3_parse("invalid syntax")
834
835         straddling = "x + y"
836         black.lib2to3_parse(straddling)
837         black.lib2to3_parse(straddling, {TargetVersion.PY27})
838         black.lib2to3_parse(straddling, {TargetVersion.PY36})
839         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
840
841         py2_only = "print x"
842         black.lib2to3_parse(py2_only)
843         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
844         with self.assertRaises(black.InvalidInput):
845             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
846         with self.assertRaises(black.InvalidInput):
847             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
848
849         py3_only = "exec(x, end=y)"
850         black.lib2to3_parse(py3_only)
851         with self.assertRaises(black.InvalidInput):
852             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
853         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
854         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
855
856     def test_get_features_used(self) -> None:
857         node = black.lib2to3_parse("def f(*, arg): ...\n")
858         self.assertEqual(black.get_features_used(node), set())
859         node = black.lib2to3_parse("def f(*, arg,): ...\n")
860         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
861         node = black.lib2to3_parse("f(*arg,)\n")
862         self.assertEqual(
863             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
864         )
865         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
866         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
867         node = black.lib2to3_parse("123_456\n")
868         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
869         node = black.lib2to3_parse("123456\n")
870         self.assertEqual(black.get_features_used(node), set())
871         source, expected = read_data("function")
872         node = black.lib2to3_parse(source)
873         expected_features = {
874             Feature.TRAILING_COMMA_IN_CALL,
875             Feature.TRAILING_COMMA_IN_DEF,
876             Feature.F_STRINGS,
877         }
878         self.assertEqual(black.get_features_used(node), expected_features)
879         node = black.lib2to3_parse(expected)
880         self.assertEqual(black.get_features_used(node), expected_features)
881         source, expected = read_data("expression")
882         node = black.lib2to3_parse(source)
883         self.assertEqual(black.get_features_used(node), set())
884         node = black.lib2to3_parse(expected)
885         self.assertEqual(black.get_features_used(node), set())
886
887     def test_get_future_imports(self) -> None:
888         node = black.lib2to3_parse("\n")
889         self.assertEqual(set(), black.get_future_imports(node))
890         node = black.lib2to3_parse("from __future__ import black\n")
891         self.assertEqual({"black"}, black.get_future_imports(node))
892         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
893         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
894         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
895         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
896         node = black.lib2to3_parse(
897             "from __future__ import multiple\nfrom __future__ import imports\n"
898         )
899         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
900         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
901         self.assertEqual({"black"}, black.get_future_imports(node))
902         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
903         self.assertEqual({"black"}, black.get_future_imports(node))
904         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
905         self.assertEqual(set(), black.get_future_imports(node))
906         node = black.lib2to3_parse("from some.module import black\n")
907         self.assertEqual(set(), black.get_future_imports(node))
908         node = black.lib2to3_parse(
909             "from __future__ import unicode_literals as _unicode_literals"
910         )
911         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
912         node = black.lib2to3_parse(
913             "from __future__ import unicode_literals as _lol, print"
914         )
915         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
916
917     def test_debug_visitor(self) -> None:
918         source, _ = read_data("debug_visitor.py")
919         expected, _ = read_data("debug_visitor.out")
920         out_lines = []
921         err_lines = []
922
923         def out(msg: str, **kwargs: Any) -> None:
924             out_lines.append(msg)
925
926         def err(msg: str, **kwargs: Any) -> None:
927             err_lines.append(msg)
928
929         with patch("black.out", out), patch("black.err", err):
930             black.DebugVisitor.show(source)
931         actual = "\n".join(out_lines) + "\n"
932         log_name = ""
933         if expected != actual:
934             log_name = black.dump_to_file(*out_lines)
935         self.assertEqual(
936             expected,
937             actual,
938             f"AST print out is different. Actual version dumped to {log_name}",
939         )
940
941     def test_format_file_contents(self) -> None:
942         empty = ""
943         mode = black.FileMode()
944         with self.assertRaises(black.NothingChanged):
945             black.format_file_contents(empty, mode=mode, fast=False)
946         just_nl = "\n"
947         with self.assertRaises(black.NothingChanged):
948             black.format_file_contents(just_nl, mode=mode, fast=False)
949         same = "l = [1, 2, 3]\n"
950         with self.assertRaises(black.NothingChanged):
951             black.format_file_contents(same, mode=mode, fast=False)
952         different = "l = [1,2,3]"
953         expected = same
954         actual = black.format_file_contents(different, mode=mode, fast=False)
955         self.assertEqual(expected, actual)
956         invalid = "return if you can"
957         with self.assertRaises(black.InvalidInput) as e:
958             black.format_file_contents(invalid, mode=mode, fast=False)
959         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
960
961     def test_endmarker(self) -> None:
962         n = black.lib2to3_parse("\n")
963         self.assertEqual(n.type, black.syms.file_input)
964         self.assertEqual(len(n.children), 1)
965         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
966
967     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
968     def test_assertFormatEqual(self) -> None:
969         out_lines = []
970         err_lines = []
971
972         def out(msg: str, **kwargs: Any) -> None:
973             out_lines.append(msg)
974
975         def err(msg: str, **kwargs: Any) -> None:
976             err_lines.append(msg)
977
978         with patch("black.out", out), patch("black.err", err):
979             with self.assertRaises(AssertionError):
980                 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
981
982         out_str = "".join(out_lines)
983         self.assertTrue("Expected tree:" in out_str)
984         self.assertTrue("Actual tree:" in out_str)
985         self.assertEqual("".join(err_lines), "")
986
987     def test_cache_broken_file(self) -> None:
988         mode = black.FileMode()
989         with cache_dir() as workspace:
990             cache_file = black.get_cache_file(mode)
991             with cache_file.open("w") as fobj:
992                 fobj.write("this is not a pickle")
993             self.assertEqual(black.read_cache(mode), {})
994             src = (workspace / "test.py").resolve()
995             with src.open("w") as fobj:
996                 fobj.write("print('hello')")
997             self.invokeBlack([str(src)])
998             cache = black.read_cache(mode)
999             self.assertIn(src, cache)
1000
1001     def test_cache_single_file_already_cached(self) -> None:
1002         mode = black.FileMode()
1003         with cache_dir() as workspace:
1004             src = (workspace / "test.py").resolve()
1005             with src.open("w") as fobj:
1006                 fobj.write("print('hello')")
1007             black.write_cache({}, [src], mode)
1008             self.invokeBlack([str(src)])
1009             with src.open("r") as fobj:
1010                 self.assertEqual(fobj.read(), "print('hello')")
1011
1012     @event_loop(close=False)
1013     def test_cache_multiple_files(self) -> None:
1014         mode = black.FileMode()
1015         with cache_dir() as workspace, patch(
1016             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1017         ):
1018             one = (workspace / "one.py").resolve()
1019             with one.open("w") as fobj:
1020                 fobj.write("print('hello')")
1021             two = (workspace / "two.py").resolve()
1022             with two.open("w") as fobj:
1023                 fobj.write("print('hello')")
1024             black.write_cache({}, [one], mode)
1025             self.invokeBlack([str(workspace)])
1026             with one.open("r") as fobj:
1027                 self.assertEqual(fobj.read(), "print('hello')")
1028             with two.open("r") as fobj:
1029                 self.assertEqual(fobj.read(), 'print("hello")\n')
1030             cache = black.read_cache(mode)
1031             self.assertIn(one, cache)
1032             self.assertIn(two, cache)
1033
1034     def test_no_cache_when_writeback_diff(self) -> None:
1035         mode = black.FileMode()
1036         with cache_dir() as workspace:
1037             src = (workspace / "test.py").resolve()
1038             with src.open("w") as fobj:
1039                 fobj.write("print('hello')")
1040             self.invokeBlack([str(src), "--diff"])
1041             cache_file = black.get_cache_file(mode)
1042             self.assertFalse(cache_file.exists())
1043
1044     def test_no_cache_when_stdin(self) -> None:
1045         mode = black.FileMode()
1046         with cache_dir():
1047             result = CliRunner().invoke(
1048                 black.main, ["-"], input=BytesIO(b"print('hello')")
1049             )
1050             self.assertEqual(result.exit_code, 0)
1051             cache_file = black.get_cache_file(mode)
1052             self.assertFalse(cache_file.exists())
1053
1054     def test_read_cache_no_cachefile(self) -> None:
1055         mode = black.FileMode()
1056         with cache_dir():
1057             self.assertEqual(black.read_cache(mode), {})
1058
1059     def test_write_cache_read_cache(self) -> None:
1060         mode = black.FileMode()
1061         with cache_dir() as workspace:
1062             src = (workspace / "test.py").resolve()
1063             src.touch()
1064             black.write_cache({}, [src], mode)
1065             cache = black.read_cache(mode)
1066             self.assertIn(src, cache)
1067             self.assertEqual(cache[src], black.get_cache_info(src))
1068
1069     def test_filter_cached(self) -> None:
1070         with TemporaryDirectory() as workspace:
1071             path = Path(workspace)
1072             uncached = (path / "uncached").resolve()
1073             cached = (path / "cached").resolve()
1074             cached_but_changed = (path / "changed").resolve()
1075             uncached.touch()
1076             cached.touch()
1077             cached_but_changed.touch()
1078             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
1079             todo, done = black.filter_cached(
1080                 cache, {uncached, cached, cached_but_changed}
1081             )
1082             self.assertEqual(todo, {uncached, cached_but_changed})
1083             self.assertEqual(done, {cached})
1084
1085     def test_write_cache_creates_directory_if_needed(self) -> None:
1086         mode = black.FileMode()
1087         with cache_dir(exists=False) as workspace:
1088             self.assertFalse(workspace.exists())
1089             black.write_cache({}, [], mode)
1090             self.assertTrue(workspace.exists())
1091
1092     @event_loop(close=False)
1093     def test_failed_formatting_does_not_get_cached(self) -> None:
1094         mode = black.FileMode()
1095         with cache_dir() as workspace, patch(
1096             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1097         ):
1098             failing = (workspace / "failing.py").resolve()
1099             with failing.open("w") as fobj:
1100                 fobj.write("not actually python")
1101             clean = (workspace / "clean.py").resolve()
1102             with clean.open("w") as fobj:
1103                 fobj.write('print("hello")\n')
1104             self.invokeBlack([str(workspace)], exit_code=123)
1105             cache = black.read_cache(mode)
1106             self.assertNotIn(failing, cache)
1107             self.assertIn(clean, cache)
1108
1109     def test_write_cache_write_fail(self) -> None:
1110         mode = black.FileMode()
1111         with cache_dir(), patch.object(Path, "open") as mock:
1112             mock.side_effect = OSError
1113             black.write_cache({}, [], mode)
1114
1115     @event_loop(close=False)
1116     def test_check_diff_use_together(self) -> None:
1117         with cache_dir():
1118             # Files which will be reformatted.
1119             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1120             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1121             # Files which will not be reformatted.
1122             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1123             self.invokeBlack([str(src2), "--diff", "--check"])
1124             # Multi file command.
1125             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1126
1127     def test_no_files(self) -> None:
1128         with cache_dir():
1129             # Without an argument, black exits with error code 0.
1130             self.invokeBlack([])
1131
1132     def test_broken_symlink(self) -> None:
1133         with cache_dir() as workspace:
1134             symlink = workspace / "broken_link.py"
1135             try:
1136                 symlink.symlink_to("nonexistent.py")
1137             except OSError as e:
1138                 self.skipTest(f"Can't create symlinks: {e}")
1139             self.invokeBlack([str(workspace.resolve())])
1140
1141     def test_read_cache_line_lengths(self) -> None:
1142         mode = black.FileMode()
1143         short_mode = black.FileMode(line_length=1)
1144         with cache_dir() as workspace:
1145             path = (workspace / "file.py").resolve()
1146             path.touch()
1147             black.write_cache({}, [path], mode)
1148             one = black.read_cache(mode)
1149             self.assertIn(path, one)
1150             two = black.read_cache(short_mode)
1151             self.assertNotIn(path, two)
1152
1153     def test_single_file_force_pyi(self) -> None:
1154         reg_mode = black.FileMode()
1155         pyi_mode = black.FileMode(is_pyi=True)
1156         contents, expected = read_data("force_pyi")
1157         with cache_dir() as workspace:
1158             path = (workspace / "file.py").resolve()
1159             with open(path, "w") as fh:
1160                 fh.write(contents)
1161             self.invokeBlack([str(path), "--pyi"])
1162             with open(path, "r") as fh:
1163                 actual = fh.read()
1164             # verify cache with --pyi is separate
1165             pyi_cache = black.read_cache(pyi_mode)
1166             self.assertIn(path, pyi_cache)
1167             normal_cache = black.read_cache(reg_mode)
1168             self.assertNotIn(path, normal_cache)
1169         self.assertEqual(actual, expected)
1170
1171     @event_loop(close=False)
1172     def test_multi_file_force_pyi(self) -> None:
1173         reg_mode = black.FileMode()
1174         pyi_mode = black.FileMode(is_pyi=True)
1175         contents, expected = read_data("force_pyi")
1176         with cache_dir() as workspace:
1177             paths = [
1178                 (workspace / "file1.py").resolve(),
1179                 (workspace / "file2.py").resolve(),
1180             ]
1181             for path in paths:
1182                 with open(path, "w") as fh:
1183                     fh.write(contents)
1184             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1185             for path in paths:
1186                 with open(path, "r") as fh:
1187                     actual = fh.read()
1188                 self.assertEqual(actual, expected)
1189             # verify cache with --pyi is separate
1190             pyi_cache = black.read_cache(pyi_mode)
1191             normal_cache = black.read_cache(reg_mode)
1192             for path in paths:
1193                 self.assertIn(path, pyi_cache)
1194                 self.assertNotIn(path, normal_cache)
1195
1196     def test_pipe_force_pyi(self) -> None:
1197         source, expected = read_data("force_pyi")
1198         result = CliRunner().invoke(
1199             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1200         )
1201         self.assertEqual(result.exit_code, 0)
1202         actual = result.output
1203         self.assertFormatEqual(actual, expected)
1204
1205     def test_single_file_force_py36(self) -> None:
1206         reg_mode = black.FileMode()
1207         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1208         source, expected = read_data("force_py36")
1209         with cache_dir() as workspace:
1210             path = (workspace / "file.py").resolve()
1211             with open(path, "w") as fh:
1212                 fh.write(source)
1213             self.invokeBlack([str(path), *PY36_ARGS])
1214             with open(path, "r") as fh:
1215                 actual = fh.read()
1216             # verify cache with --target-version is separate
1217             py36_cache = black.read_cache(py36_mode)
1218             self.assertIn(path, py36_cache)
1219             normal_cache = black.read_cache(reg_mode)
1220             self.assertNotIn(path, normal_cache)
1221         self.assertEqual(actual, expected)
1222
1223     @event_loop(close=False)
1224     def test_multi_file_force_py36(self) -> None:
1225         reg_mode = black.FileMode()
1226         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1227         source, expected = read_data("force_py36")
1228         with cache_dir() as workspace:
1229             paths = [
1230                 (workspace / "file1.py").resolve(),
1231                 (workspace / "file2.py").resolve(),
1232             ]
1233             for path in paths:
1234                 with open(path, "w") as fh:
1235                     fh.write(source)
1236             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1237             for path in paths:
1238                 with open(path, "r") as fh:
1239                     actual = fh.read()
1240                 self.assertEqual(actual, expected)
1241             # verify cache with --target-version is separate
1242             pyi_cache = black.read_cache(py36_mode)
1243             normal_cache = black.read_cache(reg_mode)
1244             for path in paths:
1245                 self.assertIn(path, pyi_cache)
1246                 self.assertNotIn(path, normal_cache)
1247
1248     def test_pipe_force_py36(self) -> None:
1249         source, expected = read_data("force_py36")
1250         result = CliRunner().invoke(
1251             black.main,
1252             ["-", "-q", "--target-version=py36"],
1253             input=BytesIO(source.encode("utf8")),
1254         )
1255         self.assertEqual(result.exit_code, 0)
1256         actual = result.output
1257         self.assertFormatEqual(actual, expected)
1258
1259     def test_include_exclude(self) -> None:
1260         path = THIS_DIR / "data" / "include_exclude_tests"
1261         include = re.compile(r"\.pyi?$")
1262         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1263         report = black.Report()
1264         sources: List[Path] = []
1265         expected = [
1266             Path(path / "b/dont_exclude/a.py"),
1267             Path(path / "b/dont_exclude/a.pyi"),
1268         ]
1269         this_abs = THIS_DIR.resolve()
1270         sources.extend(
1271             black.gen_python_files_in_dir(path, this_abs, include, exclude, report)
1272         )
1273         self.assertEqual(sorted(expected), sorted(sources))
1274
1275     def test_empty_include(self) -> None:
1276         path = THIS_DIR / "data" / "include_exclude_tests"
1277         report = black.Report()
1278         empty = re.compile(r"")
1279         sources: List[Path] = []
1280         expected = [
1281             Path(path / "b/exclude/a.pie"),
1282             Path(path / "b/exclude/a.py"),
1283             Path(path / "b/exclude/a.pyi"),
1284             Path(path / "b/dont_exclude/a.pie"),
1285             Path(path / "b/dont_exclude/a.py"),
1286             Path(path / "b/dont_exclude/a.pyi"),
1287             Path(path / "b/.definitely_exclude/a.pie"),
1288             Path(path / "b/.definitely_exclude/a.py"),
1289             Path(path / "b/.definitely_exclude/a.pyi"),
1290         ]
1291         this_abs = THIS_DIR.resolve()
1292         sources.extend(
1293             black.gen_python_files_in_dir(
1294                 path, this_abs, empty, re.compile(black.DEFAULT_EXCLUDES), report
1295             )
1296         )
1297         self.assertEqual(sorted(expected), sorted(sources))
1298
1299     def test_empty_exclude(self) -> None:
1300         path = THIS_DIR / "data" / "include_exclude_tests"
1301         report = black.Report()
1302         empty = re.compile(r"")
1303         sources: List[Path] = []
1304         expected = [
1305             Path(path / "b/dont_exclude/a.py"),
1306             Path(path / "b/dont_exclude/a.pyi"),
1307             Path(path / "b/exclude/a.py"),
1308             Path(path / "b/exclude/a.pyi"),
1309             Path(path / "b/.definitely_exclude/a.py"),
1310             Path(path / "b/.definitely_exclude/a.pyi"),
1311         ]
1312         this_abs = THIS_DIR.resolve()
1313         sources.extend(
1314             black.gen_python_files_in_dir(
1315                 path, this_abs, re.compile(black.DEFAULT_INCLUDES), empty, report
1316             )
1317         )
1318         self.assertEqual(sorted(expected), sorted(sources))
1319
1320     def test_invalid_include_exclude(self) -> None:
1321         for option in ["--include", "--exclude"]:
1322             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1323
1324     def test_preserves_line_endings(self) -> None:
1325         with TemporaryDirectory() as workspace:
1326             test_file = Path(workspace) / "test.py"
1327             for nl in ["\n", "\r\n"]:
1328                 contents = nl.join(["def f(  ):", "    pass"])
1329                 test_file.write_bytes(contents.encode())
1330                 ff(test_file, write_back=black.WriteBack.YES)
1331                 updated_contents: bytes = test_file.read_bytes()
1332                 self.assertIn(nl.encode(), updated_contents)
1333                 if nl == "\n":
1334                     self.assertNotIn(b"\r\n", updated_contents)
1335
1336     def test_preserves_line_endings_via_stdin(self) -> None:
1337         for nl in ["\n", "\r\n"]:
1338             contents = nl.join(["def f(  ):", "    pass"])
1339             runner = BlackRunner()
1340             result = runner.invoke(
1341                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1342             )
1343             self.assertEqual(result.exit_code, 0)
1344             output = runner.stdout_bytes
1345             self.assertIn(nl.encode("utf8"), output)
1346             if nl == "\n":
1347                 self.assertNotIn(b"\r\n", output)
1348
1349     def test_assert_equivalent_different_asts(self) -> None:
1350         with self.assertRaises(AssertionError):
1351             black.assert_equivalent("{}", "None")
1352
1353     def test_symlink_out_of_root_directory(self) -> None:
1354         path = MagicMock()
1355         root = THIS_DIR
1356         child = MagicMock()
1357         include = re.compile(black.DEFAULT_INCLUDES)
1358         exclude = re.compile(black.DEFAULT_EXCLUDES)
1359         report = black.Report()
1360         # `child` should behave like a symlink which resolved path is clearly
1361         # outside of the `root` directory.
1362         path.iterdir.return_value = [child]
1363         child.resolve.return_value = Path("/a/b/c")
1364         child.is_symlink.return_value = True
1365         try:
1366             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1367         except ValueError as ve:
1368             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1369         path.iterdir.assert_called_once()
1370         child.resolve.assert_called_once()
1371         child.is_symlink.assert_called_once()
1372         # `child` should behave like a strange file which resolved path is clearly
1373         # outside of the `root` directory.
1374         child.is_symlink.return_value = False
1375         with self.assertRaises(ValueError):
1376             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1377         path.iterdir.assert_called()
1378         self.assertEqual(path.iterdir.call_count, 2)
1379         child.resolve.assert_called()
1380         self.assertEqual(child.resolve.call_count, 2)
1381         child.is_symlink.assert_called()
1382         self.assertEqual(child.is_symlink.call_count, 2)
1383
1384     def test_shhh_click(self) -> None:
1385         try:
1386             from click import _unicodefun  # type: ignore
1387         except ModuleNotFoundError:
1388             self.skipTest("Incompatible Click version")
1389         if not hasattr(_unicodefun, "_verify_python3_env"):
1390             self.skipTest("Incompatible Click version")
1391         # First, let's see if Click is crashing with a preferred ASCII charset.
1392         with patch("locale.getpreferredencoding") as gpe:
1393             gpe.return_value = "ASCII"
1394             with self.assertRaises(RuntimeError):
1395                 _unicodefun._verify_python3_env()
1396         # Now, let's silence Click...
1397         black.patch_click()
1398         # ...and confirm it's silent.
1399         with patch("locale.getpreferredencoding") as gpe:
1400             gpe.return_value = "ASCII"
1401             try:
1402                 _unicodefun._verify_python3_env()
1403             except RuntimeError as re:
1404                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1405
1406     def test_root_logger_not_used_directly(self) -> None:
1407         def fail(*args: Any, **kwargs: Any) -> None:
1408             self.fail("Record created with root logger")
1409
1410         with patch.multiple(
1411             logging.root,
1412             debug=fail,
1413             info=fail,
1414             warning=fail,
1415             error=fail,
1416             critical=fail,
1417             log=fail,
1418         ):
1419             ff(THIS_FILE)
1420
1421     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1422     @async_test
1423     async def test_blackd_request_needs_formatting(self) -> None:
1424         app = blackd.make_app()
1425         async with TestClient(TestServer(app)) as client:
1426             response = await client.post("/", data=b"print('hello world')")
1427             self.assertEqual(response.status, 200)
1428             self.assertEqual(response.charset, "utf8")
1429             self.assertEqual(await response.read(), b'print("hello world")\n')
1430
1431     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1432     @async_test
1433     async def test_blackd_request_no_change(self) -> None:
1434         app = blackd.make_app()
1435         async with TestClient(TestServer(app)) as client:
1436             response = await client.post("/", data=b'print("hello world")\n')
1437             self.assertEqual(response.status, 204)
1438             self.assertEqual(await response.read(), b"")
1439
1440     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1441     @async_test
1442     async def test_blackd_request_syntax_error(self) -> None:
1443         app = blackd.make_app()
1444         async with TestClient(TestServer(app)) as client:
1445             response = await client.post("/", data=b"what even ( is")
1446             self.assertEqual(response.status, 400)
1447             content = await response.text()
1448             self.assertTrue(
1449                 content.startswith("Cannot parse"),
1450                 msg=f"Expected error to start with 'Cannot parse', got {repr(content)}",
1451             )
1452
1453     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1454     @async_test
1455     async def test_blackd_unsupported_version(self) -> None:
1456         app = blackd.make_app()
1457         async with TestClient(TestServer(app)) as client:
1458             response = await client.post(
1459                 "/", data=b"what", headers={blackd.VERSION_HEADER: "2"}
1460             )
1461             self.assertEqual(response.status, 501)
1462
1463     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1464     @async_test
1465     async def test_blackd_supported_version(self) -> None:
1466         app = blackd.make_app()
1467         async with TestClient(TestServer(app)) as client:
1468             response = await client.post(
1469                 "/", data=b"what", headers={blackd.VERSION_HEADER: "1"}
1470             )
1471             self.assertEqual(response.status, 200)
1472
1473     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1474     @async_test
1475     async def test_blackd_invalid_python_variant(self) -> None:
1476         app = blackd.make_app()
1477         async with TestClient(TestServer(app)) as client:
1478
1479             async def check(header_value: str, expected_status: int = 400) -> None:
1480                 response = await client.post(
1481                     "/",
1482                     data=b"what",
1483                     headers={blackd.PYTHON_VARIANT_HEADER: header_value},
1484                 )
1485                 self.assertEqual(response.status, expected_status)
1486
1487             await check("lol")
1488             await check("ruby3.5")
1489             await check("pyi3.6")
1490             await check("py1.5")
1491             await check("2.8")
1492             await check("py2.8")
1493             await check("3.0")
1494             await check("pypy3.0")
1495             await check("jython3.4")
1496
1497     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1498     @async_test
1499     async def test_blackd_pyi(self) -> None:
1500         app = blackd.make_app()
1501         async with TestClient(TestServer(app)) as client:
1502             source, expected = read_data("stub.pyi")
1503             response = await client.post(
1504                 "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
1505             )
1506             self.assertEqual(response.status, 200)
1507             self.assertEqual(await response.text(), expected)
1508
1509     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1510     @async_test
1511     async def test_blackd_python_variant(self) -> None:
1512         app = blackd.make_app()
1513         code = (
1514             "def f(\n"
1515             "    and_has_a_bunch_of,\n"
1516             "    very_long_arguments_too,\n"
1517             "    and_lots_of_them_as_well_lol,\n"
1518             "    **and_very_long_keyword_arguments\n"
1519             "):\n"
1520             "    pass\n"
1521         )
1522         async with TestClient(TestServer(app)) as client:
1523
1524             async def check(header_value: str, expected_status: int) -> None:
1525                 response = await client.post(
1526                     "/", data=code, headers={blackd.PYTHON_VARIANT_HEADER: header_value}
1527                 )
1528                 self.assertEqual(response.status, expected_status)
1529
1530             await check("3.6", 200)
1531             await check("py3.6", 200)
1532             await check("3.6,3.7", 200)
1533             await check("3.6,py3.7", 200)
1534
1535             await check("2", 204)
1536             await check("2.7", 204)
1537             await check("py2.7", 204)
1538             await check("3.4", 204)
1539             await check("py3.4", 204)
1540
1541     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1542     @async_test
1543     async def test_blackd_fast(self) -> None:
1544         with open(os.devnull, "w") as dn, redirect_stderr(dn):
1545             app = blackd.make_app()
1546             async with TestClient(TestServer(app)) as client:
1547                 response = await client.post("/", data=b"ur'hello'")
1548                 self.assertEqual(response.status, 500)
1549                 self.assertIn("failed to parse source file", await response.text())
1550                 response = await client.post(
1551                     "/", data=b"ur'hello'", headers={blackd.FAST_OR_SAFE_HEADER: "fast"}
1552                 )
1553                 self.assertEqual(response.status, 200)
1554
1555     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1556     @async_test
1557     async def test_blackd_line_length(self) -> None:
1558         app = blackd.make_app()
1559         async with TestClient(TestServer(app)) as client:
1560             response = await client.post(
1561                 "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "7"}
1562             )
1563             self.assertEqual(response.status, 200)
1564
1565     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1566     @async_test
1567     async def test_blackd_invalid_line_length(self) -> None:
1568         app = blackd.make_app()
1569         async with TestClient(TestServer(app)) as client:
1570             response = await client.post(
1571                 "/",
1572                 data=b'print("hello")\n',
1573                 headers={blackd.LINE_LENGTH_HEADER: "NaN"},
1574             )
1575             self.assertEqual(response.status, 400)
1576
1577     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1578     def test_blackd_main(self) -> None:
1579         with patch("blackd.web.run_app"):
1580             result = CliRunner().invoke(blackd.main, [])
1581             if result.exception is not None:
1582                 raise result.exception
1583             self.assertEqual(result.exit_code, 0)
1584
1585
1586 if __name__ == "__main__":
1587     unittest.main(module="test_black")