]> git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

add to changelog
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 import logging
4 from concurrent.futures import ThreadPoolExecutor
5 from contextlib import contextmanager, redirect_stderr
6 from functools import partial, wraps
7 from io import BytesIO, TextIOWrapper
8 import os
9 from pathlib import Path
10 import re
11 import sys
12 from tempfile import TemporaryDirectory
13 from typing import (
14     Any,
15     BinaryIO,
16     Callable,
17     Coroutine,
18     Generator,
19     List,
20     Tuple,
21     Iterator,
22     TypeVar,
23 )
24 import unittest
25 from unittest.mock import patch, MagicMock
26
27 from click import unstyle
28 from click.testing import CliRunner
29
30 import black
31 from black import Feature, TargetVersion
32
33 try:
34     import blackd
35     from aiohttp.test_utils import TestClient, TestServer
36 except ImportError:
37     has_blackd_deps = False
38 else:
39     has_blackd_deps = True
40
# Shorthands for black's two formatting entry points, both with the default
# FileMode: ``ff`` formats a file in place (with fast=True), ``fs`` a string.
ff = partial(black.format_file_in_place, mode=black.FileMode(), fast=True)
fs = partial(black.format_str, mode=black.FileMode())
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# Marker that read_data() deletes from every line it reads.  It is assembled
# from two pieces so the complete marker text never appears verbatim in this
# module, which test_self reads back through read_data().
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
# One ``--target-version=...`` CLI flag per entry in black.PY36_VERSIONS.
PY36_ARGS = ["--target-version=" + v.name.lower() for v in black.PY36_VERSIONS]
T = TypeVar("T")
R = TypeVar("R")
51
52
def dump_to_stderr(*output: str) -> str:
    """Return *output* joined by newlines, with one extra newline on each side.

    Tests patch this in for ``black.dump_to_file`` so that, instead of a temp
    file being written, the text itself is handed back to the caller.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
55
56
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    Lines whose text (ignoring trailing whitespace) is exactly the output
    marker are dropped; everything before the first one is the source and
    everything after it is the expected result.  Without a marker the whole
    file is treated as already formatted.  EMPTY_LINE is removed from every
    line, and both halves are stripped and given a single trailing newline.

    *name* gets a ".py" suffix unless it already ends in .py/.pyi/.out/.diff.
    With *data* true the file is looked up in the "data" directory next to
    this module, otherwise directly in this module's directory.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    folder = THIS_DIR / "data" if data else THIS_DIR
    with open(folder / name, "r", encoding="utf8") as handle:
        raw_lines = handle.readlines()
    source_lines: List[str] = []
    output_lines: List[str] = []
    target = source_lines
    for raw in raw_lines:
        line = raw.replace(EMPTY_LINE, "")
        if line.rstrip() == "# output":
            target = output_lines
        else:
            target.append(line)
    if source_lines and not output_lines:
        # No marker: the whole file doubles as its own expected output.
        output_lines = list(source_lines)
    source = "".join(source_lines).strip() + "\n"
    expected = "".join(output_lines).strip() + "\n"
    return source, expected
78
79
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point ``black.CACHE_DIR`` at a throwaway directory inside the block.

    Yields the patched path.  With *exists* false the yielded path is a
    "new" subdirectory of the temporary directory that is not created here.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
88
89
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Run the block with a brand-new event loop installed as current.

    The previously current loop is reinstated on exit (also on error); the
    new loop is closed afterwards only when *close* is true.
    """
    policy = asyncio.get_event_loop_policy()
    previous = policy.get_event_loop()
    fresh = policy.new_event_loop()
    asyncio.set_event_loop(fresh)
    try:
        yield
    finally:
        # Restore first, then optionally dispose of the temporary loop.
        policy.set_event_loop(previous)
        if close:
            fresh.close()
103
104
def async_test(f: Callable[..., Coroutine[Any, None, R]]) -> Callable[..., None]:
    """Turn coroutine function *f* into a synchronous callable.

    Each call runs ``f(*args, **kwargs)`` to completion on a fresh event loop
    (see event_loop) that is closed afterwards; the coroutine's result is
    discarded.
    """

    @wraps(f)
    def run_sync(*args: Any, **kwargs: Any) -> None:
        with event_loop(close=True):
            loop = asyncio.get_event_loop()
            loop.run_until_complete(f(*args, **kwargs))

    return run_sync
112
113
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    This is a hack that can be removed once we depend on Click 7.x

    After each run, ``stdout_bytes`` and ``stderr_bytes`` hold the raw bytes
    written to the two streams during that run.
    """

    def __init__(self) -> None:
        # Sink for everything written to sys.stderr while isolated.
        self.stderrbuf = BytesIO()
        # Not used by this class; kept so existing attribute access still works.
        self.stdoutbuf = BytesIO()
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Wrap CliRunner.isolation, redirecting sys.stderr to ``stderrbuf``.

        On exit the captured bytes are copied to ``stdout_bytes`` and
        ``stderr_bytes``, and the previous sys.stderr is restored.
        """
        with super().isolation(*args, **kwargs) as output:
            hold_stderr = sys.stderr
            sys.stderr = TextIOWrapper(self.stderrbuf, encoding=self.charset)
            try:
                yield output
            finally:
                # TextIOWrapper buffers text internally; flush it so writes
                # made without an explicit flush reach the byte buffer.
                sys.stderr.flush()
                self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                self.stderr_bytes = sys.stderr.buffer.getvalue()  # type: ignore
                sys.stderr = hold_stderr
137
138
139 class BlackTestCase(unittest.TestCase):
140     maxDiff = None
141
142     def assertFormatEqual(self, expected: str, actual: str) -> None:
143         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
144             bdv: black.DebugVisitor[Any]
145             black.out("Expected tree:", fg="green")
146             try:
147                 exp_node = black.lib2to3_parse(expected)
148                 bdv = black.DebugVisitor()
149                 list(bdv.visit(exp_node))
150             except Exception as ve:
151                 black.err(str(ve))
152             black.out("Actual tree:", fg="red")
153             try:
154                 exp_node = black.lib2to3_parse(actual)
155                 bdv = black.DebugVisitor()
156                 list(bdv.visit(exp_node))
157             except Exception as ve:
158                 black.err(str(ve))
159         self.assertEqual(expected, actual)
160
161     def invokeBlack(
162         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
163     ) -> None:
164         runner = BlackRunner()
165         if ignore_config:
166             args = ["--config", str(THIS_DIR / "empty.toml"), *args]
167         result = runner.invoke(black.main, args)
168         self.assertEqual(result.exit_code, exit_code, msg=runner.stderr_bytes.decode())
169
170     @patch("black.dump_to_file", dump_to_stderr)
171     def test_empty(self) -> None:
172         source = expected = ""
173         actual = fs(source)
174         self.assertFormatEqual(expected, actual)
175         black.assert_equivalent(source, actual)
176         black.assert_stable(source, actual, black.FileMode())
177
178     def test_empty_ff(self) -> None:
179         expected = ""
180         tmp_file = Path(black.dump_to_file())
181         try:
182             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
183             with open(tmp_file, encoding="utf8") as f:
184                 actual = f.read()
185         finally:
186             os.unlink(tmp_file)
187         self.assertFormatEqual(expected, actual)
188
189     @patch("black.dump_to_file", dump_to_stderr)
190     def test_self(self) -> None:
191         source, expected = read_data("test_black", data=False)
192         actual = fs(source)
193         self.assertFormatEqual(expected, actual)
194         black.assert_equivalent(source, actual)
195         black.assert_stable(source, actual, black.FileMode())
196         self.assertFalse(ff(THIS_FILE))
197
198     @patch("black.dump_to_file", dump_to_stderr)
199     def test_black(self) -> None:
200         source, expected = read_data("../black", data=False)
201         actual = fs(source)
202         self.assertFormatEqual(expected, actual)
203         black.assert_equivalent(source, actual)
204         black.assert_stable(source, actual, black.FileMode())
205         self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
206
207     def test_piping(self) -> None:
208         source, expected = read_data("../black", data=False)
209         result = BlackRunner().invoke(
210             black.main,
211             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
212             input=BytesIO(source.encode("utf8")),
213         )
214         self.assertEqual(result.exit_code, 0)
215         self.assertFormatEqual(expected, result.output)
216         black.assert_equivalent(source, result.output)
217         black.assert_stable(source, result.output, black.FileMode())
218
219     def test_piping_diff(self) -> None:
220         diff_header = re.compile(
221             rf"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d "
222             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
223         )
224         source, _ = read_data("expression.py")
225         expected, _ = read_data("expression.diff")
226         config = THIS_DIR / "data" / "empty_pyproject.toml"
227         args = [
228             "-",
229             "--fast",
230             f"--line-length={black.DEFAULT_LINE_LENGTH}",
231             "--diff",
232             f"--config={config}",
233         ]
234         result = BlackRunner().invoke(
235             black.main, args, input=BytesIO(source.encode("utf8"))
236         )
237         self.assertEqual(result.exit_code, 0)
238         actual = diff_header.sub("[Deterministic header]", result.output)
239         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
240         self.assertEqual(expected, actual)
241
242     @patch("black.dump_to_file", dump_to_stderr)
243     def test_setup(self) -> None:
244         source, expected = read_data("../setup", data=False)
245         actual = fs(source)
246         self.assertFormatEqual(expected, actual)
247         black.assert_equivalent(source, actual)
248         black.assert_stable(source, actual, black.FileMode())
249         self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
250
251     @patch("black.dump_to_file", dump_to_stderr)
252     def test_function(self) -> None:
253         source, expected = read_data("function")
254         actual = fs(source)
255         self.assertFormatEqual(expected, actual)
256         black.assert_equivalent(source, actual)
257         black.assert_stable(source, actual, black.FileMode())
258
259     @patch("black.dump_to_file", dump_to_stderr)
260     def test_function2(self) -> None:
261         source, expected = read_data("function2")
262         actual = fs(source)
263         self.assertFormatEqual(expected, actual)
264         black.assert_equivalent(source, actual)
265         black.assert_stable(source, actual, black.FileMode())
266
267     @patch("black.dump_to_file", dump_to_stderr)
268     def test_expression(self) -> None:
269         source, expected = read_data("expression")
270         actual = fs(source)
271         self.assertFormatEqual(expected, actual)
272         black.assert_equivalent(source, actual)
273         black.assert_stable(source, actual, black.FileMode())
274
275     def test_expression_ff(self) -> None:
276         source, expected = read_data("expression")
277         tmp_file = Path(black.dump_to_file(source))
278         try:
279             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
280             with open(tmp_file, encoding="utf8") as f:
281                 actual = f.read()
282         finally:
283             os.unlink(tmp_file)
284         self.assertFormatEqual(expected, actual)
285         with patch("black.dump_to_file", dump_to_stderr):
286             black.assert_equivalent(source, actual)
287             black.assert_stable(source, actual, black.FileMode())
288
289     def test_expression_diff(self) -> None:
290         source, _ = read_data("expression.py")
291         expected, _ = read_data("expression.diff")
292         tmp_file = Path(black.dump_to_file(source))
293         diff_header = re.compile(
294             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
295             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
296         )
297         try:
298             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
299             self.assertEqual(result.exit_code, 0)
300         finally:
301             os.unlink(tmp_file)
302         actual = result.output
303         actual = diff_header.sub("[Deterministic header]", actual)
304         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
305         if expected != actual:
306             dump = black.dump_to_file(actual)
307             msg = (
308                 f"Expected diff isn't equal to the actual. If you made changes "
309                 f"to expression.py and this is an anticipated difference, "
310                 f"overwrite tests/data/expression.diff with {dump}"
311             )
312             self.assertEqual(expected, actual, msg)
313
314     @patch("black.dump_to_file", dump_to_stderr)
315     def test_fstring(self) -> None:
316         source, expected = read_data("fstring")
317         actual = fs(source)
318         self.assertFormatEqual(expected, actual)
319         black.assert_equivalent(source, actual)
320         black.assert_stable(source, actual, black.FileMode())
321
322     @patch("black.dump_to_file", dump_to_stderr)
323     def test_string_quotes(self) -> None:
324         source, expected = read_data("string_quotes")
325         actual = fs(source)
326         self.assertFormatEqual(expected, actual)
327         black.assert_equivalent(source, actual)
328         black.assert_stable(source, actual, black.FileMode())
329         mode = black.FileMode(string_normalization=False)
330         not_normalized = fs(source, mode=mode)
331         self.assertFormatEqual(source, not_normalized)
332         black.assert_equivalent(source, not_normalized)
333         black.assert_stable(source, not_normalized, mode=mode)
334
335     @patch("black.dump_to_file", dump_to_stderr)
336     def test_slices(self) -> None:
337         source, expected = read_data("slices")
338         actual = fs(source)
339         self.assertFormatEqual(expected, actual)
340         black.assert_equivalent(source, actual)
341         black.assert_stable(source, actual, black.FileMode())
342
343     @patch("black.dump_to_file", dump_to_stderr)
344     def test_comments(self) -> None:
345         source, expected = read_data("comments")
346         actual = fs(source)
347         self.assertFormatEqual(expected, actual)
348         black.assert_equivalent(source, actual)
349         black.assert_stable(source, actual, black.FileMode())
350
351     @patch("black.dump_to_file", dump_to_stderr)
352     def test_comments2(self) -> None:
353         source, expected = read_data("comments2")
354         actual = fs(source)
355         self.assertFormatEqual(expected, actual)
356         black.assert_equivalent(source, actual)
357         black.assert_stable(source, actual, black.FileMode())
358
359     @patch("black.dump_to_file", dump_to_stderr)
360     def test_comments3(self) -> None:
361         source, expected = read_data("comments3")
362         actual = fs(source)
363         self.assertFormatEqual(expected, actual)
364         black.assert_equivalent(source, actual)
365         black.assert_stable(source, actual, black.FileMode())
366
367     @patch("black.dump_to_file", dump_to_stderr)
368     def test_comments4(self) -> None:
369         source, expected = read_data("comments4")
370         actual = fs(source)
371         self.assertFormatEqual(expected, actual)
372         black.assert_equivalent(source, actual)
373         black.assert_stable(source, actual, black.FileMode())
374
375     @patch("black.dump_to_file", dump_to_stderr)
376     def test_comments5(self) -> None:
377         source, expected = read_data("comments5")
378         actual = fs(source)
379         self.assertFormatEqual(expected, actual)
380         black.assert_equivalent(source, actual)
381         black.assert_stable(source, actual, black.FileMode())
382
383     @patch("black.dump_to_file", dump_to_stderr)
384     def test_comments6(self) -> None:
385         source, expected = read_data("comments6")
386         actual = fs(source)
387         self.assertFormatEqual(expected, actual)
388         black.assert_equivalent(source, actual)
389         black.assert_stable(source, actual, black.FileMode())
390
391     @patch("black.dump_to_file", dump_to_stderr)
392     def test_comments7(self) -> None:
393         source, expected = read_data("comments7")
394         actual = fs(source)
395         self.assertFormatEqual(expected, actual)
396         black.assert_equivalent(source, actual)
397         black.assert_stable(source, actual, black.FileMode())
398
399     @patch("black.dump_to_file", dump_to_stderr)
400     def test_cantfit(self) -> None:
401         source, expected = read_data("cantfit")
402         actual = fs(source)
403         self.assertFormatEqual(expected, actual)
404         black.assert_equivalent(source, actual)
405         black.assert_stable(source, actual, black.FileMode())
406
407     @patch("black.dump_to_file", dump_to_stderr)
408     def test_import_spacing(self) -> None:
409         source, expected = read_data("import_spacing")
410         actual = fs(source)
411         self.assertFormatEqual(expected, actual)
412         black.assert_equivalent(source, actual)
413         black.assert_stable(source, actual, black.FileMode())
414
415     @patch("black.dump_to_file", dump_to_stderr)
416     def test_composition(self) -> None:
417         source, expected = read_data("composition")
418         actual = fs(source)
419         self.assertFormatEqual(expected, actual)
420         black.assert_equivalent(source, actual)
421         black.assert_stable(source, actual, black.FileMode())
422
423     @patch("black.dump_to_file", dump_to_stderr)
424     def test_empty_lines(self) -> None:
425         source, expected = read_data("empty_lines")
426         actual = fs(source)
427         self.assertFormatEqual(expected, actual)
428         black.assert_equivalent(source, actual)
429         black.assert_stable(source, actual, black.FileMode())
430
431     @patch("black.dump_to_file", dump_to_stderr)
432     def test_string_prefixes(self) -> None:
433         source, expected = read_data("string_prefixes")
434         actual = fs(source)
435         self.assertFormatEqual(expected, actual)
436         black.assert_equivalent(source, actual)
437         black.assert_stable(source, actual, black.FileMode())
438
439     @patch("black.dump_to_file", dump_to_stderr)
440     def test_numeric_literals(self) -> None:
441         source, expected = read_data("numeric_literals")
442         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
443         actual = fs(source, mode=mode)
444         self.assertFormatEqual(expected, actual)
445         black.assert_equivalent(source, actual)
446         black.assert_stable(source, actual, mode)
447
448     @patch("black.dump_to_file", dump_to_stderr)
449     def test_numeric_literals_ignoring_underscores(self) -> None:
450         source, expected = read_data("numeric_literals_skip_underscores")
451         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
452         actual = fs(source, mode=mode)
453         self.assertFormatEqual(expected, actual)
454         black.assert_equivalent(source, actual)
455         black.assert_stable(source, actual, mode)
456
457     @patch("black.dump_to_file", dump_to_stderr)
458     def test_numeric_literals_py2(self) -> None:
459         source, expected = read_data("numeric_literals_py2")
460         actual = fs(source)
461         self.assertFormatEqual(expected, actual)
462         black.assert_stable(source, actual, black.FileMode())
463
464     @patch("black.dump_to_file", dump_to_stderr)
465     def test_python2(self) -> None:
466         source, expected = read_data("python2")
467         actual = fs(source)
468         self.assertFormatEqual(expected, actual)
469         # black.assert_equivalent(source, actual)
470         black.assert_stable(source, actual, black.FileMode())
471
472     @patch("black.dump_to_file", dump_to_stderr)
473     def test_python2_print_function(self) -> None:
474         source, expected = read_data("python2_print_function")
475         mode = black.FileMode(target_versions={TargetVersion.PY27})
476         actual = fs(source, mode=mode)
477         self.assertFormatEqual(expected, actual)
478         black.assert_stable(source, actual, mode)
479
480     @patch("black.dump_to_file", dump_to_stderr)
481     def test_python2_unicode_literals(self) -> None:
482         source, expected = read_data("python2_unicode_literals")
483         actual = fs(source)
484         self.assertFormatEqual(expected, actual)
485         black.assert_stable(source, actual, black.FileMode())
486
487     @patch("black.dump_to_file", dump_to_stderr)
488     def test_stub(self) -> None:
489         mode = black.FileMode(is_pyi=True)
490         source, expected = read_data("stub.pyi")
491         actual = fs(source, mode=mode)
492         self.assertFormatEqual(expected, actual)
493         black.assert_stable(source, actual, mode)
494
495     @patch("black.dump_to_file", dump_to_stderr)
496     def test_python37(self) -> None:
497         source, expected = read_data("python37")
498         actual = fs(source)
499         self.assertFormatEqual(expected, actual)
500         major, minor = sys.version_info[:2]
501         if major > 3 or (major == 3 and minor >= 7):
502             black.assert_equivalent(source, actual)
503         black.assert_stable(source, actual, black.FileMode())
504
505     @patch("black.dump_to_file", dump_to_stderr)
506     def test_fmtonoff(self) -> None:
507         source, expected = read_data("fmtonoff")
508         actual = fs(source)
509         self.assertFormatEqual(expected, actual)
510         black.assert_equivalent(source, actual)
511         black.assert_stable(source, actual, black.FileMode())
512
513     @patch("black.dump_to_file", dump_to_stderr)
514     def test_fmtonoff2(self) -> None:
515         source, expected = read_data("fmtonoff2")
516         actual = fs(source)
517         self.assertFormatEqual(expected, actual)
518         black.assert_equivalent(source, actual)
519         black.assert_stable(source, actual, black.FileMode())
520
521     @patch("black.dump_to_file", dump_to_stderr)
522     def test_remove_empty_parentheses_after_class(self) -> None:
523         source, expected = read_data("class_blank_parentheses")
524         actual = fs(source)
525         self.assertFormatEqual(expected, actual)
526         black.assert_equivalent(source, actual)
527         black.assert_stable(source, actual, black.FileMode())
528
529     @patch("black.dump_to_file", dump_to_stderr)
530     def test_new_line_between_class_and_code(self) -> None:
531         source, expected = read_data("class_methods_new_line")
532         actual = fs(source)
533         self.assertFormatEqual(expected, actual)
534         black.assert_equivalent(source, actual)
535         black.assert_stable(source, actual, black.FileMode())
536
537     @patch("black.dump_to_file", dump_to_stderr)
538     def test_bracket_match(self) -> None:
539         source, expected = read_data("bracketmatch")
540         actual = fs(source)
541         self.assertFormatEqual(expected, actual)
542         black.assert_equivalent(source, actual)
543         black.assert_stable(source, actual, black.FileMode())
544
545     @patch("black.dump_to_file", dump_to_stderr)
546     def test_tuple_assign(self) -> None:
547         source, expected = read_data("tupleassign")
548         actual = fs(source)
549         self.assertFormatEqual(expected, actual)
550         black.assert_equivalent(source, actual)
551         black.assert_stable(source, actual, black.FileMode())
552
553     def test_tab_comment_indentation(self) -> None:
554         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
555         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
556         self.assertFormatEqual(contents_spc, fs(contents_spc))
557         self.assertFormatEqual(contents_spc, fs(contents_tab))
558
559         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
560         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
561         self.assertFormatEqual(contents_spc, fs(contents_spc))
562         self.assertFormatEqual(contents_spc, fs(contents_tab))
563
564         # mixed tabs and spaces (valid Python 2 code)
565         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
566         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
567         self.assertFormatEqual(contents_spc, fs(contents_spc))
568         self.assertFormatEqual(contents_spc, fs(contents_tab))
569
570         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
571         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
572         self.assertFormatEqual(contents_spc, fs(contents_spc))
573         self.assertFormatEqual(contents_spc, fs(contents_tab))
574
575     def test_report_verbose(self) -> None:
576         report = black.Report(verbose=True)
577         out_lines = []
578         err_lines = []
579
580         def out(msg: str, **kwargs: Any) -> None:
581             out_lines.append(msg)
582
583         def err(msg: str, **kwargs: Any) -> None:
584             err_lines.append(msg)
585
586         with patch("black.out", out), patch("black.err", err):
587             report.done(Path("f1"), black.Changed.NO)
588             self.assertEqual(len(out_lines), 1)
589             self.assertEqual(len(err_lines), 0)
590             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
591             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
592             self.assertEqual(report.return_code, 0)
593             report.done(Path("f2"), black.Changed.YES)
594             self.assertEqual(len(out_lines), 2)
595             self.assertEqual(len(err_lines), 0)
596             self.assertEqual(out_lines[-1], "reformatted f2")
597             self.assertEqual(
598                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
599             )
600             report.done(Path("f3"), black.Changed.CACHED)
601             self.assertEqual(len(out_lines), 3)
602             self.assertEqual(len(err_lines), 0)
603             self.assertEqual(
604                 out_lines[-1], "f3 wasn't modified on disk since last run."
605             )
606             self.assertEqual(
607                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
608             )
609             self.assertEqual(report.return_code, 0)
610             report.check = True
611             self.assertEqual(report.return_code, 1)
612             report.check = False
613             report.failed(Path("e1"), "boom")
614             self.assertEqual(len(out_lines), 3)
615             self.assertEqual(len(err_lines), 1)
616             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
617             self.assertEqual(
618                 unstyle(str(report)),
619                 "1 file reformatted, 2 files left unchanged, "
620                 "1 file failed to reformat.",
621             )
622             self.assertEqual(report.return_code, 123)
623             report.done(Path("f3"), black.Changed.YES)
624             self.assertEqual(len(out_lines), 4)
625             self.assertEqual(len(err_lines), 1)
626             self.assertEqual(out_lines[-1], "reformatted f3")
627             self.assertEqual(
628                 unstyle(str(report)),
629                 "2 files reformatted, 2 files left unchanged, "
630                 "1 file failed to reformat.",
631             )
632             self.assertEqual(report.return_code, 123)
633             report.failed(Path("e2"), "boom")
634             self.assertEqual(len(out_lines), 4)
635             self.assertEqual(len(err_lines), 2)
636             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
637             self.assertEqual(
638                 unstyle(str(report)),
639                 "2 files reformatted, 2 files left unchanged, "
640                 "2 files failed to reformat.",
641             )
642             self.assertEqual(report.return_code, 123)
643             report.path_ignored(Path("wat"), "no match")
644             self.assertEqual(len(out_lines), 5)
645             self.assertEqual(len(err_lines), 2)
646             self.assertEqual(out_lines[-1], "wat ignored: no match")
647             self.assertEqual(
648                 unstyle(str(report)),
649                 "2 files reformatted, 2 files left unchanged, "
650                 "2 files failed to reformat.",
651             )
652             self.assertEqual(report.return_code, 123)
653             report.done(Path("f4"), black.Changed.NO)
654             self.assertEqual(len(out_lines), 6)
655             self.assertEqual(len(err_lines), 2)
656             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
657             self.assertEqual(
658                 unstyle(str(report)),
659                 "2 files reformatted, 3 files left unchanged, "
660                 "2 files failed to reformat.",
661             )
662             self.assertEqual(report.return_code, 123)
663             report.check = True
664             self.assertEqual(
665                 unstyle(str(report)),
666                 "2 files would be reformatted, 3 files would be left unchanged, "
667                 "2 files would fail to reformat.",
668             )
669
670     def test_report_quiet(self) -> None:
671         report = black.Report(quiet=True)
672         out_lines = []
673         err_lines = []
674
675         def out(msg: str, **kwargs: Any) -> None:
676             out_lines.append(msg)
677
678         def err(msg: str, **kwargs: Any) -> None:
679             err_lines.append(msg)
680
681         with patch("black.out", out), patch("black.err", err):
682             report.done(Path("f1"), black.Changed.NO)
683             self.assertEqual(len(out_lines), 0)
684             self.assertEqual(len(err_lines), 0)
685             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
686             self.assertEqual(report.return_code, 0)
687             report.done(Path("f2"), black.Changed.YES)
688             self.assertEqual(len(out_lines), 0)
689             self.assertEqual(len(err_lines), 0)
690             self.assertEqual(
691                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
692             )
693             report.done(Path("f3"), black.Changed.CACHED)
694             self.assertEqual(len(out_lines), 0)
695             self.assertEqual(len(err_lines), 0)
696             self.assertEqual(
697                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
698             )
699             self.assertEqual(report.return_code, 0)
700             report.check = True
701             self.assertEqual(report.return_code, 1)
702             report.check = False
703             report.failed(Path("e1"), "boom")
704             self.assertEqual(len(out_lines), 0)
705             self.assertEqual(len(err_lines), 1)
706             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
707             self.assertEqual(
708                 unstyle(str(report)),
709                 "1 file reformatted, 2 files left unchanged, "
710                 "1 file failed to reformat.",
711             )
712             self.assertEqual(report.return_code, 123)
713             report.done(Path("f3"), black.Changed.YES)
714             self.assertEqual(len(out_lines), 0)
715             self.assertEqual(len(err_lines), 1)
716             self.assertEqual(
717                 unstyle(str(report)),
718                 "2 files reformatted, 2 files left unchanged, "
719                 "1 file failed to reformat.",
720             )
721             self.assertEqual(report.return_code, 123)
722             report.failed(Path("e2"), "boom")
723             self.assertEqual(len(out_lines), 0)
724             self.assertEqual(len(err_lines), 2)
725             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
726             self.assertEqual(
727                 unstyle(str(report)),
728                 "2 files reformatted, 2 files left unchanged, "
729                 "2 files failed to reformat.",
730             )
731             self.assertEqual(report.return_code, 123)
732             report.path_ignored(Path("wat"), "no match")
733             self.assertEqual(len(out_lines), 0)
734             self.assertEqual(len(err_lines), 2)
735             self.assertEqual(
736                 unstyle(str(report)),
737                 "2 files reformatted, 2 files left unchanged, "
738                 "2 files failed to reformat.",
739             )
740             self.assertEqual(report.return_code, 123)
741             report.done(Path("f4"), black.Changed.NO)
742             self.assertEqual(len(out_lines), 0)
743             self.assertEqual(len(err_lines), 2)
744             self.assertEqual(
745                 unstyle(str(report)),
746                 "2 files reformatted, 3 files left unchanged, "
747                 "2 files failed to reformat.",
748             )
749             self.assertEqual(report.return_code, 123)
750             report.check = True
751             self.assertEqual(
752                 unstyle(str(report)),
753                 "2 files would be reformatted, 3 files would be left unchanged, "
754                 "2 files would fail to reformat.",
755             )
756
757     def test_report_normal(self) -> None:
758         report = black.Report()
759         out_lines = []
760         err_lines = []
761
762         def out(msg: str, **kwargs: Any) -> None:
763             out_lines.append(msg)
764
765         def err(msg: str, **kwargs: Any) -> None:
766             err_lines.append(msg)
767
768         with patch("black.out", out), patch("black.err", err):
769             report.done(Path("f1"), black.Changed.NO)
770             self.assertEqual(len(out_lines), 0)
771             self.assertEqual(len(err_lines), 0)
772             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
773             self.assertEqual(report.return_code, 0)
774             report.done(Path("f2"), black.Changed.YES)
775             self.assertEqual(len(out_lines), 1)
776             self.assertEqual(len(err_lines), 0)
777             self.assertEqual(out_lines[-1], "reformatted f2")
778             self.assertEqual(
779                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
780             )
781             report.done(Path("f3"), black.Changed.CACHED)
782             self.assertEqual(len(out_lines), 1)
783             self.assertEqual(len(err_lines), 0)
784             self.assertEqual(out_lines[-1], "reformatted f2")
785             self.assertEqual(
786                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
787             )
788             self.assertEqual(report.return_code, 0)
789             report.check = True
790             self.assertEqual(report.return_code, 1)
791             report.check = False
792             report.failed(Path("e1"), "boom")
793             self.assertEqual(len(out_lines), 1)
794             self.assertEqual(len(err_lines), 1)
795             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
796             self.assertEqual(
797                 unstyle(str(report)),
798                 "1 file reformatted, 2 files left unchanged, "
799                 "1 file failed to reformat.",
800             )
801             self.assertEqual(report.return_code, 123)
802             report.done(Path("f3"), black.Changed.YES)
803             self.assertEqual(len(out_lines), 2)
804             self.assertEqual(len(err_lines), 1)
805             self.assertEqual(out_lines[-1], "reformatted f3")
806             self.assertEqual(
807                 unstyle(str(report)),
808                 "2 files reformatted, 2 files left unchanged, "
809                 "1 file failed to reformat.",
810             )
811             self.assertEqual(report.return_code, 123)
812             report.failed(Path("e2"), "boom")
813             self.assertEqual(len(out_lines), 2)
814             self.assertEqual(len(err_lines), 2)
815             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
816             self.assertEqual(
817                 unstyle(str(report)),
818                 "2 files reformatted, 2 files left unchanged, "
819                 "2 files failed to reformat.",
820             )
821             self.assertEqual(report.return_code, 123)
822             report.path_ignored(Path("wat"), "no match")
823             self.assertEqual(len(out_lines), 2)
824             self.assertEqual(len(err_lines), 2)
825             self.assertEqual(
826                 unstyle(str(report)),
827                 "2 files reformatted, 2 files left unchanged, "
828                 "2 files failed to reformat.",
829             )
830             self.assertEqual(report.return_code, 123)
831             report.done(Path("f4"), black.Changed.NO)
832             self.assertEqual(len(out_lines), 2)
833             self.assertEqual(len(err_lines), 2)
834             self.assertEqual(
835                 unstyle(str(report)),
836                 "2 files reformatted, 3 files left unchanged, "
837                 "2 files failed to reformat.",
838             )
839             self.assertEqual(report.return_code, 123)
840             report.check = True
841             self.assertEqual(
842                 unstyle(str(report)),
843                 "2 files would be reformatted, 3 files would be left unchanged, "
844                 "2 files would fail to reformat.",
845             )
846
847     def test_lib2to3_parse(self) -> None:
848         with self.assertRaises(black.InvalidInput):
849             black.lib2to3_parse("invalid syntax")
850
851         straddling = "x + y"
852         black.lib2to3_parse(straddling)
853         black.lib2to3_parse(straddling, {TargetVersion.PY27})
854         black.lib2to3_parse(straddling, {TargetVersion.PY36})
855         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
856
857         py2_only = "print x"
858         black.lib2to3_parse(py2_only)
859         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
860         with self.assertRaises(black.InvalidInput):
861             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
862         with self.assertRaises(black.InvalidInput):
863             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
864
865         py3_only = "exec(x, end=y)"
866         black.lib2to3_parse(py3_only)
867         with self.assertRaises(black.InvalidInput):
868             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
869         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
870         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
871
872     def test_get_features_used(self) -> None:
873         node = black.lib2to3_parse("def f(*, arg): ...\n")
874         self.assertEqual(black.get_features_used(node), set())
875         node = black.lib2to3_parse("def f(*, arg,): ...\n")
876         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
877         node = black.lib2to3_parse("f(*arg,)\n")
878         self.assertEqual(
879             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
880         )
881         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
882         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
883         node = black.lib2to3_parse("123_456\n")
884         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
885         node = black.lib2to3_parse("123456\n")
886         self.assertEqual(black.get_features_used(node), set())
887         source, expected = read_data("function")
888         node = black.lib2to3_parse(source)
889         expected_features = {
890             Feature.TRAILING_COMMA_IN_CALL,
891             Feature.TRAILING_COMMA_IN_DEF,
892             Feature.F_STRINGS,
893         }
894         self.assertEqual(black.get_features_used(node), expected_features)
895         node = black.lib2to3_parse(expected)
896         self.assertEqual(black.get_features_used(node), expected_features)
897         source, expected = read_data("expression")
898         node = black.lib2to3_parse(source)
899         self.assertEqual(black.get_features_used(node), set())
900         node = black.lib2to3_parse(expected)
901         self.assertEqual(black.get_features_used(node), set())
902
903     def test_get_future_imports(self) -> None:
904         node = black.lib2to3_parse("\n")
905         self.assertEqual(set(), black.get_future_imports(node))
906         node = black.lib2to3_parse("from __future__ import black\n")
907         self.assertEqual({"black"}, black.get_future_imports(node))
908         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
909         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
910         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
911         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
912         node = black.lib2to3_parse(
913             "from __future__ import multiple\nfrom __future__ import imports\n"
914         )
915         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
916         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
917         self.assertEqual({"black"}, black.get_future_imports(node))
918         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
919         self.assertEqual({"black"}, black.get_future_imports(node))
920         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
921         self.assertEqual(set(), black.get_future_imports(node))
922         node = black.lib2to3_parse("from some.module import black\n")
923         self.assertEqual(set(), black.get_future_imports(node))
924         node = black.lib2to3_parse(
925             "from __future__ import unicode_literals as _unicode_literals"
926         )
927         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
928         node = black.lib2to3_parse(
929             "from __future__ import unicode_literals as _lol, print"
930         )
931         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
932
933     def test_debug_visitor(self) -> None:
934         source, _ = read_data("debug_visitor.py")
935         expected, _ = read_data("debug_visitor.out")
936         out_lines = []
937         err_lines = []
938
939         def out(msg: str, **kwargs: Any) -> None:
940             out_lines.append(msg)
941
942         def err(msg: str, **kwargs: Any) -> None:
943             err_lines.append(msg)
944
945         with patch("black.out", out), patch("black.err", err):
946             black.DebugVisitor.show(source)
947         actual = "\n".join(out_lines) + "\n"
948         log_name = ""
949         if expected != actual:
950             log_name = black.dump_to_file(*out_lines)
951         self.assertEqual(
952             expected,
953             actual,
954             f"AST print out is different. Actual version dumped to {log_name}",
955         )
956
957     def test_format_file_contents(self) -> None:
958         empty = ""
959         mode = black.FileMode()
960         with self.assertRaises(black.NothingChanged):
961             black.format_file_contents(empty, mode=mode, fast=False)
962         just_nl = "\n"
963         with self.assertRaises(black.NothingChanged):
964             black.format_file_contents(just_nl, mode=mode, fast=False)
965         same = "l = [1, 2, 3]\n"
966         with self.assertRaises(black.NothingChanged):
967             black.format_file_contents(same, mode=mode, fast=False)
968         different = "l = [1,2,3]"
969         expected = same
970         actual = black.format_file_contents(different, mode=mode, fast=False)
971         self.assertEqual(expected, actual)
972         invalid = "return if you can"
973         with self.assertRaises(black.InvalidInput) as e:
974             black.format_file_contents(invalid, mode=mode, fast=False)
975         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
976
977     def test_endmarker(self) -> None:
978         n = black.lib2to3_parse("\n")
979         self.assertEqual(n.type, black.syms.file_input)
980         self.assertEqual(len(n.children), 1)
981         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
982
983     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
984     def test_assertFormatEqual(self) -> None:
985         out_lines = []
986         err_lines = []
987
988         def out(msg: str, **kwargs: Any) -> None:
989             out_lines.append(msg)
990
991         def err(msg: str, **kwargs: Any) -> None:
992             err_lines.append(msg)
993
994         with patch("black.out", out), patch("black.err", err):
995             with self.assertRaises(AssertionError):
996                 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
997
998         out_str = "".join(out_lines)
999         self.assertTrue("Expected tree:" in out_str)
1000         self.assertTrue("Actual tree:" in out_str)
1001         self.assertEqual("".join(err_lines), "")
1002
1003     def test_cache_broken_file(self) -> None:
1004         mode = black.FileMode()
1005         with cache_dir() as workspace:
1006             cache_file = black.get_cache_file(mode)
1007             with cache_file.open("w") as fobj:
1008                 fobj.write("this is not a pickle")
1009             self.assertEqual(black.read_cache(mode), {})
1010             src = (workspace / "test.py").resolve()
1011             with src.open("w") as fobj:
1012                 fobj.write("print('hello')")
1013             self.invokeBlack([str(src)])
1014             cache = black.read_cache(mode)
1015             self.assertIn(src, cache)
1016
1017     def test_cache_single_file_already_cached(self) -> None:
1018         mode = black.FileMode()
1019         with cache_dir() as workspace:
1020             src = (workspace / "test.py").resolve()
1021             with src.open("w") as fobj:
1022                 fobj.write("print('hello')")
1023             black.write_cache({}, [src], mode)
1024             self.invokeBlack([str(src)])
1025             with src.open("r") as fobj:
1026                 self.assertEqual(fobj.read(), "print('hello')")
1027
1028     @event_loop(close=False)
1029     def test_cache_multiple_files(self) -> None:
1030         mode = black.FileMode()
1031         with cache_dir() as workspace, patch(
1032             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1033         ):
1034             one = (workspace / "one.py").resolve()
1035             with one.open("w") as fobj:
1036                 fobj.write("print('hello')")
1037             two = (workspace / "two.py").resolve()
1038             with two.open("w") as fobj:
1039                 fobj.write("print('hello')")
1040             black.write_cache({}, [one], mode)
1041             self.invokeBlack([str(workspace)])
1042             with one.open("r") as fobj:
1043                 self.assertEqual(fobj.read(), "print('hello')")
1044             with two.open("r") as fobj:
1045                 self.assertEqual(fobj.read(), 'print("hello")\n')
1046             cache = black.read_cache(mode)
1047             self.assertIn(one, cache)
1048             self.assertIn(two, cache)
1049
1050     def test_no_cache_when_writeback_diff(self) -> None:
1051         mode = black.FileMode()
1052         with cache_dir() as workspace:
1053             src = (workspace / "test.py").resolve()
1054             with src.open("w") as fobj:
1055                 fobj.write("print('hello')")
1056             self.invokeBlack([str(src), "--diff"])
1057             cache_file = black.get_cache_file(mode)
1058             self.assertFalse(cache_file.exists())
1059
1060     def test_no_cache_when_stdin(self) -> None:
1061         mode = black.FileMode()
1062         with cache_dir():
1063             result = CliRunner().invoke(
1064                 black.main, ["-"], input=BytesIO(b"print('hello')")
1065             )
1066             self.assertEqual(result.exit_code, 0)
1067             cache_file = black.get_cache_file(mode)
1068             self.assertFalse(cache_file.exists())
1069
1070     def test_read_cache_no_cachefile(self) -> None:
1071         mode = black.FileMode()
1072         with cache_dir():
1073             self.assertEqual(black.read_cache(mode), {})
1074
1075     def test_write_cache_read_cache(self) -> None:
1076         mode = black.FileMode()
1077         with cache_dir() as workspace:
1078             src = (workspace / "test.py").resolve()
1079             src.touch()
1080             black.write_cache({}, [src], mode)
1081             cache = black.read_cache(mode)
1082             self.assertIn(src, cache)
1083             self.assertEqual(cache[src], black.get_cache_info(src))
1084
1085     def test_filter_cached(self) -> None:
1086         with TemporaryDirectory() as workspace:
1087             path = Path(workspace)
1088             uncached = (path / "uncached").resolve()
1089             cached = (path / "cached").resolve()
1090             cached_but_changed = (path / "changed").resolve()
1091             uncached.touch()
1092             cached.touch()
1093             cached_but_changed.touch()
1094             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
1095             todo, done = black.filter_cached(
1096                 cache, {uncached, cached, cached_but_changed}
1097             )
1098             self.assertEqual(todo, {uncached, cached_but_changed})
1099             self.assertEqual(done, {cached})
1100
1101     def test_write_cache_creates_directory_if_needed(self) -> None:
1102         mode = black.FileMode()
1103         with cache_dir(exists=False) as workspace:
1104             self.assertFalse(workspace.exists())
1105             black.write_cache({}, [], mode)
1106             self.assertTrue(workspace.exists())
1107
1108     @event_loop(close=False)
1109     def test_failed_formatting_does_not_get_cached(self) -> None:
1110         mode = black.FileMode()
1111         with cache_dir() as workspace, patch(
1112             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1113         ):
1114             failing = (workspace / "failing.py").resolve()
1115             with failing.open("w") as fobj:
1116                 fobj.write("not actually python")
1117             clean = (workspace / "clean.py").resolve()
1118             with clean.open("w") as fobj:
1119                 fobj.write('print("hello")\n')
1120             self.invokeBlack([str(workspace)], exit_code=123)
1121             cache = black.read_cache(mode)
1122             self.assertNotIn(failing, cache)
1123             self.assertIn(clean, cache)
1124
1125     def test_write_cache_write_fail(self) -> None:
1126         mode = black.FileMode()
1127         with cache_dir(), patch.object(Path, "open") as mock:
1128             mock.side_effect = OSError
1129             black.write_cache({}, [], mode)
1130
1131     @event_loop(close=False)
1132     def test_check_diff_use_together(self) -> None:
1133         with cache_dir():
1134             # Files which will be reformatted.
1135             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1136             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1137             # Files which will not be reformatted.
1138             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1139             self.invokeBlack([str(src2), "--diff", "--check"])
1140             # Multi file command.
1141             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1142
1143     def test_no_files(self) -> None:
1144         with cache_dir():
1145             # Without an argument, black exits with error code 0.
1146             self.invokeBlack([])
1147
1148     def test_broken_symlink(self) -> None:
1149         with cache_dir() as workspace:
1150             symlink = workspace / "broken_link.py"
1151             try:
1152                 symlink.symlink_to("nonexistent.py")
1153             except OSError as e:
1154                 self.skipTest(f"Can't create symlinks: {e}")
1155             self.invokeBlack([str(workspace.resolve())])
1156
1157     def test_read_cache_line_lengths(self) -> None:
1158         mode = black.FileMode()
1159         short_mode = black.FileMode(line_length=1)
1160         with cache_dir() as workspace:
1161             path = (workspace / "file.py").resolve()
1162             path.touch()
1163             black.write_cache({}, [path], mode)
1164             one = black.read_cache(mode)
1165             self.assertIn(path, one)
1166             two = black.read_cache(short_mode)
1167             self.assertNotIn(path, two)
1168
1169     def test_single_file_force_pyi(self) -> None:
1170         reg_mode = black.FileMode()
1171         pyi_mode = black.FileMode(is_pyi=True)
1172         contents, expected = read_data("force_pyi")
1173         with cache_dir() as workspace:
1174             path = (workspace / "file.py").resolve()
1175             with open(path, "w") as fh:
1176                 fh.write(contents)
1177             self.invokeBlack([str(path), "--pyi"])
1178             with open(path, "r") as fh:
1179                 actual = fh.read()
1180             # verify cache with --pyi is separate
1181             pyi_cache = black.read_cache(pyi_mode)
1182             self.assertIn(path, pyi_cache)
1183             normal_cache = black.read_cache(reg_mode)
1184             self.assertNotIn(path, normal_cache)
1185         self.assertEqual(actual, expected)
1186
1187     @event_loop(close=False)
1188     def test_multi_file_force_pyi(self) -> None:
1189         reg_mode = black.FileMode()
1190         pyi_mode = black.FileMode(is_pyi=True)
1191         contents, expected = read_data("force_pyi")
1192         with cache_dir() as workspace:
1193             paths = [
1194                 (workspace / "file1.py").resolve(),
1195                 (workspace / "file2.py").resolve(),
1196             ]
1197             for path in paths:
1198                 with open(path, "w") as fh:
1199                     fh.write(contents)
1200             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1201             for path in paths:
1202                 with open(path, "r") as fh:
1203                     actual = fh.read()
1204                 self.assertEqual(actual, expected)
1205             # verify cache with --pyi is separate
1206             pyi_cache = black.read_cache(pyi_mode)
1207             normal_cache = black.read_cache(reg_mode)
1208             for path in paths:
1209                 self.assertIn(path, pyi_cache)
1210                 self.assertNotIn(path, normal_cache)
1211
1212     def test_pipe_force_pyi(self) -> None:
1213         source, expected = read_data("force_pyi")
1214         result = CliRunner().invoke(
1215             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1216         )
1217         self.assertEqual(result.exit_code, 0)
1218         actual = result.output
1219         self.assertFormatEqual(actual, expected)
1220
1221     def test_single_file_force_py36(self) -> None:
1222         reg_mode = black.FileMode()
1223         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1224         source, expected = read_data("force_py36")
1225         with cache_dir() as workspace:
1226             path = (workspace / "file.py").resolve()
1227             with open(path, "w") as fh:
1228                 fh.write(source)
1229             self.invokeBlack([str(path), *PY36_ARGS])
1230             with open(path, "r") as fh:
1231                 actual = fh.read()
1232             # verify cache with --target-version is separate
1233             py36_cache = black.read_cache(py36_mode)
1234             self.assertIn(path, py36_cache)
1235             normal_cache = black.read_cache(reg_mode)
1236             self.assertNotIn(path, normal_cache)
1237         self.assertEqual(actual, expected)
1238
1239     @event_loop(close=False)
1240     def test_multi_file_force_py36(self) -> None:
1241         reg_mode = black.FileMode()
1242         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1243         source, expected = read_data("force_py36")
1244         with cache_dir() as workspace:
1245             paths = [
1246                 (workspace / "file1.py").resolve(),
1247                 (workspace / "file2.py").resolve(),
1248             ]
1249             for path in paths:
1250                 with open(path, "w") as fh:
1251                     fh.write(source)
1252             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1253             for path in paths:
1254                 with open(path, "r") as fh:
1255                     actual = fh.read()
1256                 self.assertEqual(actual, expected)
1257             # verify cache with --target-version is separate
1258             pyi_cache = black.read_cache(py36_mode)
1259             normal_cache = black.read_cache(reg_mode)
1260             for path in paths:
1261                 self.assertIn(path, pyi_cache)
1262                 self.assertNotIn(path, normal_cache)
1263
1264     def test_pipe_force_py36(self) -> None:
1265         source, expected = read_data("force_py36")
1266         result = CliRunner().invoke(
1267             black.main,
1268             ["-", "-q", "--target-version=py36"],
1269             input=BytesIO(source.encode("utf8")),
1270         )
1271         self.assertEqual(result.exit_code, 0)
1272         actual = result.output
1273         self.assertFormatEqual(actual, expected)
1274
1275     def test_include_exclude(self) -> None:
1276         path = THIS_DIR / "data" / "include_exclude_tests"
1277         include = re.compile(r"\.pyi?$")
1278         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1279         report = black.Report()
1280         sources: List[Path] = []
1281         expected = [
1282             Path(path / "b/dont_exclude/a.py"),
1283             Path(path / "b/dont_exclude/a.pyi"),
1284         ]
1285         this_abs = THIS_DIR.resolve()
1286         sources.extend(
1287             black.gen_python_files_in_dir(path, this_abs, include, exclude, report)
1288         )
1289         self.assertEqual(sorted(expected), sorted(sources))
1290
1291     def test_empty_include(self) -> None:
1292         path = THIS_DIR / "data" / "include_exclude_tests"
1293         report = black.Report()
1294         empty = re.compile(r"")
1295         sources: List[Path] = []
1296         expected = [
1297             Path(path / "b/exclude/a.pie"),
1298             Path(path / "b/exclude/a.py"),
1299             Path(path / "b/exclude/a.pyi"),
1300             Path(path / "b/dont_exclude/a.pie"),
1301             Path(path / "b/dont_exclude/a.py"),
1302             Path(path / "b/dont_exclude/a.pyi"),
1303             Path(path / "b/.definitely_exclude/a.pie"),
1304             Path(path / "b/.definitely_exclude/a.py"),
1305             Path(path / "b/.definitely_exclude/a.pyi"),
1306         ]
1307         this_abs = THIS_DIR.resolve()
1308         sources.extend(
1309             black.gen_python_files_in_dir(
1310                 path, this_abs, empty, re.compile(black.DEFAULT_EXCLUDES), report
1311             )
1312         )
1313         self.assertEqual(sorted(expected), sorted(sources))
1314
1315     def test_empty_exclude(self) -> None:
1316         path = THIS_DIR / "data" / "include_exclude_tests"
1317         report = black.Report()
1318         empty = re.compile(r"")
1319         sources: List[Path] = []
1320         expected = [
1321             Path(path / "b/dont_exclude/a.py"),
1322             Path(path / "b/dont_exclude/a.pyi"),
1323             Path(path / "b/exclude/a.py"),
1324             Path(path / "b/exclude/a.pyi"),
1325             Path(path / "b/.definitely_exclude/a.py"),
1326             Path(path / "b/.definitely_exclude/a.pyi"),
1327         ]
1328         this_abs = THIS_DIR.resolve()
1329         sources.extend(
1330             black.gen_python_files_in_dir(
1331                 path, this_abs, re.compile(black.DEFAULT_INCLUDES), empty, report
1332             )
1333         )
1334         self.assertEqual(sorted(expected), sorted(sources))
1335
1336     def test_invalid_include_exclude(self) -> None:
1337         for option in ["--include", "--exclude"]:
1338             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1339
1340     def test_preserves_line_endings(self) -> None:
1341         with TemporaryDirectory() as workspace:
1342             test_file = Path(workspace) / "test.py"
1343             for nl in ["\n", "\r\n"]:
1344                 contents = nl.join(["def f(  ):", "    pass"])
1345                 test_file.write_bytes(contents.encode())
1346                 ff(test_file, write_back=black.WriteBack.YES)
1347                 updated_contents: bytes = test_file.read_bytes()
1348                 self.assertIn(nl.encode(), updated_contents)
1349                 if nl == "\n":
1350                     self.assertNotIn(b"\r\n", updated_contents)
1351
1352     def test_preserves_line_endings_via_stdin(self) -> None:
1353         for nl in ["\n", "\r\n"]:
1354             contents = nl.join(["def f(  ):", "    pass"])
1355             runner = BlackRunner()
1356             result = runner.invoke(
1357                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1358             )
1359             self.assertEqual(result.exit_code, 0)
1360             output = runner.stdout_bytes
1361             self.assertIn(nl.encode("utf8"), output)
1362             if nl == "\n":
1363                 self.assertNotIn(b"\r\n", output)
1364
1365     def test_assert_equivalent_different_asts(self) -> None:
1366         with self.assertRaises(AssertionError):
1367             black.assert_equivalent("{}", "None")
1368
1369     def test_symlink_out_of_root_directory(self) -> None:
1370         path = MagicMock()
1371         root = THIS_DIR
1372         child = MagicMock()
1373         include = re.compile(black.DEFAULT_INCLUDES)
1374         exclude = re.compile(black.DEFAULT_EXCLUDES)
1375         report = black.Report()
1376         # `child` should behave like a symlink which resolved path is clearly
1377         # outside of the `root` directory.
1378         path.iterdir.return_value = [child]
1379         child.resolve.return_value = Path("/a/b/c")
1380         child.is_symlink.return_value = True
1381         try:
1382             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1383         except ValueError as ve:
1384             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1385         path.iterdir.assert_called_once()
1386         child.resolve.assert_called_once()
1387         child.is_symlink.assert_called_once()
1388         # `child` should behave like a strange file which resolved path is clearly
1389         # outside of the `root` directory.
1390         child.is_symlink.return_value = False
1391         with self.assertRaises(ValueError):
1392             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1393         path.iterdir.assert_called()
1394         self.assertEqual(path.iterdir.call_count, 2)
1395         child.resolve.assert_called()
1396         self.assertEqual(child.resolve.call_count, 2)
1397         child.is_symlink.assert_called()
1398         self.assertEqual(child.is_symlink.call_count, 2)
1399
1400     def test_shhh_click(self) -> None:
1401         try:
1402             from click import _unicodefun  # type: ignore
1403         except ModuleNotFoundError:
1404             self.skipTest("Incompatible Click version")
1405         if not hasattr(_unicodefun, "_verify_python3_env"):
1406             self.skipTest("Incompatible Click version")
1407         # First, let's see if Click is crashing with a preferred ASCII charset.
1408         with patch("locale.getpreferredencoding") as gpe:
1409             gpe.return_value = "ASCII"
1410             with self.assertRaises(RuntimeError):
1411                 _unicodefun._verify_python3_env()
1412         # Now, let's silence Click...
1413         black.patch_click()
1414         # ...and confirm it's silent.
1415         with patch("locale.getpreferredencoding") as gpe:
1416             gpe.return_value = "ASCII"
1417             try:
1418                 _unicodefun._verify_python3_env()
1419             except RuntimeError as re:
1420                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1421
1422     def test_root_logger_not_used_directly(self) -> None:
1423         def fail(*args: Any, **kwargs: Any) -> None:
1424             self.fail("Record created with root logger")
1425
1426         with patch.multiple(
1427             logging.root,
1428             debug=fail,
1429             info=fail,
1430             warning=fail,
1431             error=fail,
1432             critical=fail,
1433             log=fail,
1434         ):
1435             ff(THIS_FILE)
1436
1437     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1438     @async_test
1439     async def test_blackd_request_needs_formatting(self) -> None:
1440         app = blackd.make_app()
1441         async with TestClient(TestServer(app)) as client:
1442             response = await client.post("/", data=b"print('hello world')")
1443             self.assertEqual(response.status, 200)
1444             self.assertEqual(response.charset, "utf8")
1445             self.assertEqual(await response.read(), b'print("hello world")\n')
1446
1447     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1448     @async_test
1449     async def test_blackd_request_no_change(self) -> None:
1450         app = blackd.make_app()
1451         async with TestClient(TestServer(app)) as client:
1452             response = await client.post("/", data=b'print("hello world")\n')
1453             self.assertEqual(response.status, 204)
1454             self.assertEqual(await response.read(), b"")
1455
1456     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1457     @async_test
1458     async def test_blackd_request_syntax_error(self) -> None:
1459         app = blackd.make_app()
1460         async with TestClient(TestServer(app)) as client:
1461             response = await client.post("/", data=b"what even ( is")
1462             self.assertEqual(response.status, 400)
1463             content = await response.text()
1464             self.assertTrue(
1465                 content.startswith("Cannot parse"),
1466                 msg=f"Expected error to start with 'Cannot parse', got {repr(content)}",
1467             )
1468
1469     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1470     @async_test
1471     async def test_blackd_unsupported_version(self) -> None:
1472         app = blackd.make_app()
1473         async with TestClient(TestServer(app)) as client:
1474             response = await client.post(
1475                 "/", data=b"what", headers={blackd.VERSION_HEADER: "2"}
1476             )
1477             self.assertEqual(response.status, 501)
1478
1479     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1480     @async_test
1481     async def test_blackd_supported_version(self) -> None:
1482         app = blackd.make_app()
1483         async with TestClient(TestServer(app)) as client:
1484             response = await client.post(
1485                 "/", data=b"what", headers={blackd.VERSION_HEADER: "1"}
1486             )
1487             self.assertEqual(response.status, 200)
1488
1489     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1490     @async_test
1491     async def test_blackd_invalid_python_variant(self) -> None:
1492         app = blackd.make_app()
1493         async with TestClient(TestServer(app)) as client:
1494
1495             async def check(header_value: str, expected_status: int = 400) -> None:
1496                 response = await client.post(
1497                     "/",
1498                     data=b"what",
1499                     headers={blackd.PYTHON_VARIANT_HEADER: header_value},
1500                 )
1501                 self.assertEqual(response.status, expected_status)
1502
1503             await check("lol")
1504             await check("ruby3.5")
1505             await check("pyi3.6")
1506             await check("py1.5")
1507             await check("2.8")
1508             await check("py2.8")
1509             await check("3.0")
1510             await check("pypy3.0")
1511             await check("jython3.4")
1512
1513     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1514     @async_test
1515     async def test_blackd_pyi(self) -> None:
1516         app = blackd.make_app()
1517         async with TestClient(TestServer(app)) as client:
1518             source, expected = read_data("stub.pyi")
1519             response = await client.post(
1520                 "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
1521             )
1522             self.assertEqual(response.status, 200)
1523             self.assertEqual(await response.text(), expected)
1524
1525     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1526     @async_test
1527     async def test_blackd_python_variant(self) -> None:
1528         app = blackd.make_app()
1529         code = (
1530             "def f(\n"
1531             "    and_has_a_bunch_of,\n"
1532             "    very_long_arguments_too,\n"
1533             "    and_lots_of_them_as_well_lol,\n"
1534             "    **and_very_long_keyword_arguments\n"
1535             "):\n"
1536             "    pass\n"
1537         )
1538         async with TestClient(TestServer(app)) as client:
1539
1540             async def check(header_value: str, expected_status: int) -> None:
1541                 response = await client.post(
1542                     "/", data=code, headers={blackd.PYTHON_VARIANT_HEADER: header_value}
1543                 )
1544                 self.assertEqual(response.status, expected_status)
1545
1546             await check("3.6", 200)
1547             await check("py3.6", 200)
1548             await check("3.6,3.7", 200)
1549             await check("3.6,py3.7", 200)
1550
1551             await check("2", 204)
1552             await check("2.7", 204)
1553             await check("py2.7", 204)
1554             await check("3.4", 204)
1555             await check("py3.4", 204)
1556
1557     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1558     @async_test
1559     async def test_blackd_fast(self) -> None:
1560         with open(os.devnull, "w") as dn, redirect_stderr(dn):
1561             app = blackd.make_app()
1562             async with TestClient(TestServer(app)) as client:
1563                 response = await client.post("/", data=b"ur'hello'")
1564                 self.assertEqual(response.status, 500)
1565                 self.assertIn("failed to parse source file", await response.text())
1566                 response = await client.post(
1567                     "/", data=b"ur'hello'", headers={blackd.FAST_OR_SAFE_HEADER: "fast"}
1568                 )
1569                 self.assertEqual(response.status, 200)
1570
1571     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1572     @async_test
1573     async def test_blackd_line_length(self) -> None:
1574         app = blackd.make_app()
1575         async with TestClient(TestServer(app)) as client:
1576             response = await client.post(
1577                 "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "7"}
1578             )
1579             self.assertEqual(response.status, 200)
1580
1581     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1582     @async_test
1583     async def test_blackd_invalid_line_length(self) -> None:
1584         app = blackd.make_app()
1585         async with TestClient(TestServer(app)) as client:
1586             response = await client.post(
1587                 "/",
1588                 data=b'print("hello")\n',
1589                 headers={blackd.LINE_LENGTH_HEADER: "NaN"},
1590             )
1591             self.assertEqual(response.status, 400)
1592
1593     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1594     def test_blackd_main(self) -> None:
1595         with patch("blackd.web.run_app"):
1596             result = CliRunner().invoke(blackd.main, [])
1597             if result.exception is not None:
1598                 raise result.exception
1599             self.assertEqual(result.exit_code, 0)
1600
1601
if __name__ == "__main__":
    # Running this file directly hands control to unittest's command-line
    # runner for the `test_black` module.
    # NOTE(review): the module is given by name, so `test_black` presumably has
    # to be importable from the current working directory — confirm.
    unittest.main(module="test_black")