git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Fix PyCharm instructions in README (#701)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 import logging
4 from concurrent.futures import ThreadPoolExecutor
5 from contextlib import contextmanager, redirect_stderr
6 from functools import partial, wraps
7 from io import BytesIO, TextIOWrapper
8 import os
9 from pathlib import Path
10 import re
11 import sys
12 from tempfile import TemporaryDirectory
13 from typing import (
14     Any,
15     BinaryIO,
16     Callable,
17     Coroutine,
18     Generator,
19     List,
20     Tuple,
21     Iterator,
22     TypeVar,
23 )
24 import unittest
25 from unittest.mock import patch, MagicMock
26
27 from click import unstyle
28 from click.testing import CliRunner
29
30 import black
31 from black import Feature
32
33 try:
34     import blackd
35     from aiohttp.test_utils import TestClient, TestServer
36 except ImportError:
37     has_blackd_deps = False
38 else:
39     has_blackd_deps = True
40
# Shortcut: black.format_file_in_place with the default FileMode and fast=True.
ff = partial(black.format_file_in_place, mode=black.FileMode(), fast=True)
# Shortcut: black.format_str with the default FileMode.
fs = partial(black.format_str, mode=black.FileMode())
# Path of this test module and of the directory that contains it.
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# Marker text that read_data() removes from every fixture line it reads.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
# One `--target-version=<name>` command-line flag per entry of black.PY36_VERSIONS.
PY36_ARGS = [
    f"--target-version={version.name.lower()}" for version in black.PY36_VERSIONS
]
# Generic type variables used in annotations.
T = TypeVar("T")
R = TypeVar("R")
51
52
def dump_to_stderr(*output: str) -> str:
    """Return `output` joined by newlines, wrapped in a leading and trailing newline.

    Despite the name, nothing is written anywhere; the text is only returned.
    """
    joined = "\n".join(output)
    return f"\n{joined}\n"
55
56
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """Load a fixture and return its `(input, output)` halves.

    `name` gets a ".py" suffix unless it already ends with ".py", ".pyi", ".out"
    or ".diff". The file is looked up in `THIS_DIR / "data"` when `data` is true,
    otherwise directly in `THIS_DIR`.

    Lines before a line consisting of `# output` form the input, lines after it
    form the output. Without such a marker the output is a copy of the input.
    Both halves are stripped and terminated with a single newline.
    """
    known_suffixes = (".py", ".pyi", ".out", ".diff")
    file_name = name if name.endswith(known_suffixes) else name + ".py"
    folder = THIS_DIR / "data" if data else THIS_DIR
    inp_lines: List[str] = []
    out_lines: List[str] = []
    current = inp_lines
    with open(folder / file_name, "r", encoding="utf8") as handle:
        for raw in handle:
            text = raw.replace(EMPTY_LINE, "")
            if text.rstrip() == "# output":
                # Everything after the marker belongs to the expected output.
                current = out_lines
            else:
                current.append(text)
    if inp_lines and not out_lines:
        # No output marker: the whole file is treated as already formatted.
        out_lines = inp_lines[:]
    return "".join(inp_lines).strip() + "\n", "".join(out_lines).strip() + "\n"
78
79
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch `black.CACHE_DIR` to a location inside a fresh temporary directory.

    With `exists=True` the temporary directory itself is used; otherwise a
    not-yet-created "new" child of it is used. The chosen path is yielded.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
88
89
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Install a new asyncio event loop for the duration of the `with` block.

    The loop that was current on entry is made current again on exit. The
    temporary loop is closed afterwards only when `close` is true.
    """
    policy = asyncio.get_event_loop_policy()
    previous = policy.get_event_loop()
    temporary = policy.new_event_loop()
    asyncio.set_event_loop(temporary)
    try:
        yield

    finally:
        policy.set_event_loop(previous)
        if close:
            temporary.close()
103
104
def async_test(f: Callable[..., Coroutine[Any, None, R]]) -> Callable[..., None]:
    """Turn coroutine function `f` into a plain function that runs it to completion.

    Each call happens inside `event_loop(close=True)`, and the coroutine is run
    on whatever loop `asyncio.get_event_loop()` returns at that point.
    """

    @wraps(f)
    def run_sync(*args: Any, **kwargs: Any) -> None:
        loop = asyncio.get_event_loop()
        loop.run_until_complete(f(*args, **kwargs))

    return event_loop(close=True)(run_sync)
112
113
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    While `isolation()` is active, `sys.stderr` is replaced by a text wrapper
    around `self.stderrbuf`. When isolation ends, the raw bytes found on the
    stdout and stderr buffers are stored in `stdout_bytes` and `stderr_bytes`.

    This is a hack that can be removed once we depend on Click 7.x"""

    def __init__(self) -> None:
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Wrap `CliRunner.isolation()` and capture stderr separately.

        All arguments are passed through to the parent implementation and its
        output object is yielded unchanged.
        """
        with super().isolation(*args, **kwargs) as output:
            # Bind the saved stream before entering `try`, so the `finally` clause
            # never runs without it.
            hold_stderr = sys.stderr
            sys.stderr = TextIOWrapper(self.stderrbuf, encoding=self.charset)
            try:
                yield output
            finally:
                # TextIOWrapper buffers text; without a flush, anything written to
                # stderr that was not flushed by the writer would be missing below.
                sys.stderr.flush()
                self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                self.stderr_bytes = sys.stderr.buffer.getvalue()  # type: ignore
                sys.stderr = hold_stderr
137
138
139 class BlackTestCase(unittest.TestCase):
140     maxDiff = None
141
142     def assertFormatEqual(self, expected: str, actual: str) -> None:
143         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
144             bdv: black.DebugVisitor[Any]
145             black.out("Expected tree:", fg="green")
146             try:
147                 exp_node = black.lib2to3_parse(expected)
148                 bdv = black.DebugVisitor()
149                 list(bdv.visit(exp_node))
150             except Exception as ve:
151                 black.err(str(ve))
152             black.out("Actual tree:", fg="red")
153             try:
154                 exp_node = black.lib2to3_parse(actual)
155                 bdv = black.DebugVisitor()
156                 list(bdv.visit(exp_node))
157             except Exception as ve:
158                 black.err(str(ve))
159         self.assertEqual(expected, actual)
160
161     def invokeBlack(
162         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
163     ) -> None:
164         runner = BlackRunner()
165         if ignore_config:
166             args = ["--config", str(THIS_DIR / "empty.toml"), *args]
167         result = runner.invoke(black.main, args)
168         self.assertEqual(result.exit_code, exit_code, msg=runner.stderr_bytes.decode())
169
170     @patch("black.dump_to_file", dump_to_stderr)
171     def test_empty(self) -> None:
172         source = expected = ""
173         actual = fs(source)
174         self.assertFormatEqual(expected, actual)
175         black.assert_equivalent(source, actual)
176         black.assert_stable(source, actual, black.FileMode())
177
178     def test_empty_ff(self) -> None:
179         expected = ""
180         tmp_file = Path(black.dump_to_file())
181         try:
182             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
183             with open(tmp_file, encoding="utf8") as f:
184                 actual = f.read()
185         finally:
186             os.unlink(tmp_file)
187         self.assertFormatEqual(expected, actual)
188
189     @patch("black.dump_to_file", dump_to_stderr)
190     def test_self(self) -> None:
191         source, expected = read_data("test_black", data=False)
192         actual = fs(source)
193         self.assertFormatEqual(expected, actual)
194         black.assert_equivalent(source, actual)
195         black.assert_stable(source, actual, black.FileMode())
196         self.assertFalse(ff(THIS_FILE))
197
198     @patch("black.dump_to_file", dump_to_stderr)
199     def test_black(self) -> None:
200         source, expected = read_data("../black", data=False)
201         actual = fs(source)
202         self.assertFormatEqual(expected, actual)
203         black.assert_equivalent(source, actual)
204         black.assert_stable(source, actual, black.FileMode())
205         self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
206
207     def test_piping(self) -> None:
208         source, expected = read_data("../black", data=False)
209         result = BlackRunner().invoke(
210             black.main,
211             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
212             input=BytesIO(source.encode("utf8")),
213         )
214         self.assertEqual(result.exit_code, 0)
215         self.assertFormatEqual(expected, result.output)
216         black.assert_equivalent(source, result.output)
217         black.assert_stable(source, result.output, black.FileMode())
218
219     def test_piping_diff(self) -> None:
220         diff_header = re.compile(
221             rf"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d "
222             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
223         )
224         source, _ = read_data("expression.py")
225         expected, _ = read_data("expression.diff")
226         config = THIS_DIR / "data" / "empty_pyproject.toml"
227         args = [
228             "-",
229             "--fast",
230             f"--line-length={black.DEFAULT_LINE_LENGTH}",
231             "--diff",
232             f"--config={config}",
233         ]
234         result = BlackRunner().invoke(
235             black.main, args, input=BytesIO(source.encode("utf8"))
236         )
237         self.assertEqual(result.exit_code, 0)
238         actual = diff_header.sub("[Deterministic header]", result.output)
239         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
240         self.assertEqual(expected, actual)
241
242     @patch("black.dump_to_file", dump_to_stderr)
243     def test_setup(self) -> None:
244         source, expected = read_data("../setup", data=False)
245         actual = fs(source)
246         self.assertFormatEqual(expected, actual)
247         black.assert_equivalent(source, actual)
248         black.assert_stable(source, actual, black.FileMode())
249         self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
250
251     @patch("black.dump_to_file", dump_to_stderr)
252     def test_function(self) -> None:
253         source, expected = read_data("function")
254         actual = fs(source)
255         self.assertFormatEqual(expected, actual)
256         black.assert_equivalent(source, actual)
257         black.assert_stable(source, actual, black.FileMode())
258
259     @patch("black.dump_to_file", dump_to_stderr)
260     def test_function2(self) -> None:
261         source, expected = read_data("function2")
262         actual = fs(source)
263         self.assertFormatEqual(expected, actual)
264         black.assert_equivalent(source, actual)
265         black.assert_stable(source, actual, black.FileMode())
266
267     @patch("black.dump_to_file", dump_to_stderr)
268     def test_expression(self) -> None:
269         source, expected = read_data("expression")
270         actual = fs(source)
271         self.assertFormatEqual(expected, actual)
272         black.assert_equivalent(source, actual)
273         black.assert_stable(source, actual, black.FileMode())
274
275     def test_expression_ff(self) -> None:
276         source, expected = read_data("expression")
277         tmp_file = Path(black.dump_to_file(source))
278         try:
279             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
280             with open(tmp_file, encoding="utf8") as f:
281                 actual = f.read()
282         finally:
283             os.unlink(tmp_file)
284         self.assertFormatEqual(expected, actual)
285         with patch("black.dump_to_file", dump_to_stderr):
286             black.assert_equivalent(source, actual)
287             black.assert_stable(source, actual, black.FileMode())
288
289     def test_expression_diff(self) -> None:
290         source, _ = read_data("expression.py")
291         expected, _ = read_data("expression.diff")
292         tmp_file = Path(black.dump_to_file(source))
293         diff_header = re.compile(
294             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
295             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
296         )
297         try:
298             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
299             self.assertEqual(result.exit_code, 0)
300         finally:
301             os.unlink(tmp_file)
302         actual = result.output
303         actual = diff_header.sub("[Deterministic header]", actual)
304         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
305         if expected != actual:
306             dump = black.dump_to_file(actual)
307             msg = (
308                 f"Expected diff isn't equal to the actual. If you made changes "
309                 f"to expression.py and this is an anticipated difference, "
310                 f"overwrite tests/data/expression.diff with {dump}"
311             )
312             self.assertEqual(expected, actual, msg)
313
314     @patch("black.dump_to_file", dump_to_stderr)
315     def test_fstring(self) -> None:
316         source, expected = read_data("fstring")
317         actual = fs(source)
318         self.assertFormatEqual(expected, actual)
319         black.assert_equivalent(source, actual)
320         black.assert_stable(source, actual, black.FileMode())
321
322     @patch("black.dump_to_file", dump_to_stderr)
323     def test_string_quotes(self) -> None:
324         source, expected = read_data("string_quotes")
325         actual = fs(source)
326         self.assertFormatEqual(expected, actual)
327         black.assert_equivalent(source, actual)
328         black.assert_stable(source, actual, black.FileMode())
329         mode = black.FileMode(string_normalization=False)
330         not_normalized = fs(source, mode=mode)
331         self.assertFormatEqual(source, not_normalized)
332         black.assert_equivalent(source, not_normalized)
333         black.assert_stable(source, not_normalized, mode=mode)
334
335     @patch("black.dump_to_file", dump_to_stderr)
336     def test_slices(self) -> None:
337         source, expected = read_data("slices")
338         actual = fs(source)
339         self.assertFormatEqual(expected, actual)
340         black.assert_equivalent(source, actual)
341         black.assert_stable(source, actual, black.FileMode())
342
343     @patch("black.dump_to_file", dump_to_stderr)
344     def test_comments(self) -> None:
345         source, expected = read_data("comments")
346         actual = fs(source)
347         self.assertFormatEqual(expected, actual)
348         black.assert_equivalent(source, actual)
349         black.assert_stable(source, actual, black.FileMode())
350
351     @patch("black.dump_to_file", dump_to_stderr)
352     def test_comments2(self) -> None:
353         source, expected = read_data("comments2")
354         actual = fs(source)
355         self.assertFormatEqual(expected, actual)
356         black.assert_equivalent(source, actual)
357         black.assert_stable(source, actual, black.FileMode())
358
359     @patch("black.dump_to_file", dump_to_stderr)
360     def test_comments3(self) -> None:
361         source, expected = read_data("comments3")
362         actual = fs(source)
363         self.assertFormatEqual(expected, actual)
364         black.assert_equivalent(source, actual)
365         black.assert_stable(source, actual, black.FileMode())
366
367     @patch("black.dump_to_file", dump_to_stderr)
368     def test_comments4(self) -> None:
369         source, expected = read_data("comments4")
370         actual = fs(source)
371         self.assertFormatEqual(expected, actual)
372         black.assert_equivalent(source, actual)
373         black.assert_stable(source, actual, black.FileMode())
374
375     @patch("black.dump_to_file", dump_to_stderr)
376     def test_comments5(self) -> None:
377         source, expected = read_data("comments5")
378         actual = fs(source)
379         self.assertFormatEqual(expected, actual)
380         black.assert_equivalent(source, actual)
381         black.assert_stable(source, actual, black.FileMode())
382
383     @patch("black.dump_to_file", dump_to_stderr)
384     def test_comments6(self) -> None:
385         source, expected = read_data("comments6")
386         actual = fs(source)
387         self.assertFormatEqual(expected, actual)
388         black.assert_equivalent(source, actual)
389         black.assert_stable(source, actual, black.FileMode())
390
391     @patch("black.dump_to_file", dump_to_stderr)
392     def test_cantfit(self) -> None:
393         source, expected = read_data("cantfit")
394         actual = fs(source)
395         self.assertFormatEqual(expected, actual)
396         black.assert_equivalent(source, actual)
397         black.assert_stable(source, actual, black.FileMode())
398
399     @patch("black.dump_to_file", dump_to_stderr)
400     def test_import_spacing(self) -> None:
401         source, expected = read_data("import_spacing")
402         actual = fs(source)
403         self.assertFormatEqual(expected, actual)
404         black.assert_equivalent(source, actual)
405         black.assert_stable(source, actual, black.FileMode())
406
407     @patch("black.dump_to_file", dump_to_stderr)
408     def test_composition(self) -> None:
409         source, expected = read_data("composition")
410         actual = fs(source)
411         self.assertFormatEqual(expected, actual)
412         black.assert_equivalent(source, actual)
413         black.assert_stable(source, actual, black.FileMode())
414
415     @patch("black.dump_to_file", dump_to_stderr)
416     def test_empty_lines(self) -> None:
417         source, expected = read_data("empty_lines")
418         actual = fs(source)
419         self.assertFormatEqual(expected, actual)
420         black.assert_equivalent(source, actual)
421         black.assert_stable(source, actual, black.FileMode())
422
423     @patch("black.dump_to_file", dump_to_stderr)
424     def test_string_prefixes(self) -> None:
425         source, expected = read_data("string_prefixes")
426         actual = fs(source)
427         self.assertFormatEqual(expected, actual)
428         black.assert_equivalent(source, actual)
429         black.assert_stable(source, actual, black.FileMode())
430
431     @patch("black.dump_to_file", dump_to_stderr)
432     def test_numeric_literals(self) -> None:
433         source, expected = read_data("numeric_literals")
434         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
435         actual = fs(source, mode=mode)
436         self.assertFormatEqual(expected, actual)
437         black.assert_equivalent(source, actual)
438         black.assert_stable(source, actual, mode)
439
440     @patch("black.dump_to_file", dump_to_stderr)
441     def test_numeric_literals_ignoring_underscores(self) -> None:
442         source, expected = read_data("numeric_literals_skip_underscores")
443         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
444         actual = fs(source, mode=mode)
445         self.assertFormatEqual(expected, actual)
446         black.assert_equivalent(source, actual)
447         black.assert_stable(source, actual, mode)
448
449     @patch("black.dump_to_file", dump_to_stderr)
450     def test_numeric_literals_py2(self) -> None:
451         source, expected = read_data("numeric_literals_py2")
452         actual = fs(source)
453         self.assertFormatEqual(expected, actual)
454         black.assert_stable(source, actual, black.FileMode())
455
456     @patch("black.dump_to_file", dump_to_stderr)
457     def test_python2(self) -> None:
458         source, expected = read_data("python2")
459         actual = fs(source)
460         self.assertFormatEqual(expected, actual)
461         # black.assert_equivalent(source, actual)
462         black.assert_stable(source, actual, black.FileMode())
463
464     @patch("black.dump_to_file", dump_to_stderr)
465     def test_python2_unicode_literals(self) -> None:
466         source, expected = read_data("python2_unicode_literals")
467         actual = fs(source)
468         self.assertFormatEqual(expected, actual)
469         black.assert_stable(source, actual, black.FileMode())
470
471     @patch("black.dump_to_file", dump_to_stderr)
472     def test_stub(self) -> None:
473         mode = black.FileMode(is_pyi=True)
474         source, expected = read_data("stub.pyi")
475         actual = fs(source, mode=mode)
476         self.assertFormatEqual(expected, actual)
477         black.assert_stable(source, actual, mode)
478
479     @patch("black.dump_to_file", dump_to_stderr)
480     def test_python37(self) -> None:
481         source, expected = read_data("python37")
482         actual = fs(source)
483         self.assertFormatEqual(expected, actual)
484         major, minor = sys.version_info[:2]
485         if major > 3 or (major == 3 and minor >= 7):
486             black.assert_equivalent(source, actual)
487         black.assert_stable(source, actual, black.FileMode())
488
489     @patch("black.dump_to_file", dump_to_stderr)
490     def test_fmtonoff(self) -> None:
491         source, expected = read_data("fmtonoff")
492         actual = fs(source)
493         self.assertFormatEqual(expected, actual)
494         black.assert_equivalent(source, actual)
495         black.assert_stable(source, actual, black.FileMode())
496
497     @patch("black.dump_to_file", dump_to_stderr)
498     def test_fmtonoff2(self) -> None:
499         source, expected = read_data("fmtonoff2")
500         actual = fs(source)
501         self.assertFormatEqual(expected, actual)
502         black.assert_equivalent(source, actual)
503         black.assert_stable(source, actual, black.FileMode())
504
505     @patch("black.dump_to_file", dump_to_stderr)
506     def test_remove_empty_parentheses_after_class(self) -> None:
507         source, expected = read_data("class_blank_parentheses")
508         actual = fs(source)
509         self.assertFormatEqual(expected, actual)
510         black.assert_equivalent(source, actual)
511         black.assert_stable(source, actual, black.FileMode())
512
513     @patch("black.dump_to_file", dump_to_stderr)
514     def test_new_line_between_class_and_code(self) -> None:
515         source, expected = read_data("class_methods_new_line")
516         actual = fs(source)
517         self.assertFormatEqual(expected, actual)
518         black.assert_equivalent(source, actual)
519         black.assert_stable(source, actual, black.FileMode())
520
521     @patch("black.dump_to_file", dump_to_stderr)
522     def test_bracket_match(self) -> None:
523         source, expected = read_data("bracketmatch")
524         actual = fs(source)
525         self.assertFormatEqual(expected, actual)
526         black.assert_equivalent(source, actual)
527         black.assert_stable(source, actual, black.FileMode())
528
529     def test_comment_indentation(self) -> None:
530         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
531         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
532
533         self.assertFormatEqual(fs(contents_spc), contents_spc)
534         self.assertFormatEqual(fs(contents_tab), contents_spc)
535
536         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
537         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
538
539         self.assertFormatEqual(fs(contents_tab), contents_spc)
540         self.assertFormatEqual(fs(contents_spc), contents_spc)
541
542     def test_report_verbose(self) -> None:
543         report = black.Report(verbose=True)
544         out_lines = []
545         err_lines = []
546
547         def out(msg: str, **kwargs: Any) -> None:
548             out_lines.append(msg)
549
550         def err(msg: str, **kwargs: Any) -> None:
551             err_lines.append(msg)
552
553         with patch("black.out", out), patch("black.err", err):
554             report.done(Path("f1"), black.Changed.NO)
555             self.assertEqual(len(out_lines), 1)
556             self.assertEqual(len(err_lines), 0)
557             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
558             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
559             self.assertEqual(report.return_code, 0)
560             report.done(Path("f2"), black.Changed.YES)
561             self.assertEqual(len(out_lines), 2)
562             self.assertEqual(len(err_lines), 0)
563             self.assertEqual(out_lines[-1], "reformatted f2")
564             self.assertEqual(
565                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
566             )
567             report.done(Path("f3"), black.Changed.CACHED)
568             self.assertEqual(len(out_lines), 3)
569             self.assertEqual(len(err_lines), 0)
570             self.assertEqual(
571                 out_lines[-1], "f3 wasn't modified on disk since last run."
572             )
573             self.assertEqual(
574                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
575             )
576             self.assertEqual(report.return_code, 0)
577             report.check = True
578             self.assertEqual(report.return_code, 1)
579             report.check = False
580             report.failed(Path("e1"), "boom")
581             self.assertEqual(len(out_lines), 3)
582             self.assertEqual(len(err_lines), 1)
583             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
584             self.assertEqual(
585                 unstyle(str(report)),
586                 "1 file reformatted, 2 files left unchanged, "
587                 "1 file failed to reformat.",
588             )
589             self.assertEqual(report.return_code, 123)
590             report.done(Path("f3"), black.Changed.YES)
591             self.assertEqual(len(out_lines), 4)
592             self.assertEqual(len(err_lines), 1)
593             self.assertEqual(out_lines[-1], "reformatted f3")
594             self.assertEqual(
595                 unstyle(str(report)),
596                 "2 files reformatted, 2 files left unchanged, "
597                 "1 file failed to reformat.",
598             )
599             self.assertEqual(report.return_code, 123)
600             report.failed(Path("e2"), "boom")
601             self.assertEqual(len(out_lines), 4)
602             self.assertEqual(len(err_lines), 2)
603             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
604             self.assertEqual(
605                 unstyle(str(report)),
606                 "2 files reformatted, 2 files left unchanged, "
607                 "2 files failed to reformat.",
608             )
609             self.assertEqual(report.return_code, 123)
610             report.path_ignored(Path("wat"), "no match")
611             self.assertEqual(len(out_lines), 5)
612             self.assertEqual(len(err_lines), 2)
613             self.assertEqual(out_lines[-1], "wat ignored: no match")
614             self.assertEqual(
615                 unstyle(str(report)),
616                 "2 files reformatted, 2 files left unchanged, "
617                 "2 files failed to reformat.",
618             )
619             self.assertEqual(report.return_code, 123)
620             report.done(Path("f4"), black.Changed.NO)
621             self.assertEqual(len(out_lines), 6)
622             self.assertEqual(len(err_lines), 2)
623             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
624             self.assertEqual(
625                 unstyle(str(report)),
626                 "2 files reformatted, 3 files left unchanged, "
627                 "2 files failed to reformat.",
628             )
629             self.assertEqual(report.return_code, 123)
630             report.check = True
631             self.assertEqual(
632                 unstyle(str(report)),
633                 "2 files would be reformatted, 3 files would be left unchanged, "
634                 "2 files would fail to reformat.",
635             )
636
637     def test_report_quiet(self) -> None:
638         report = black.Report(quiet=True)
639         out_lines = []
640         err_lines = []
641
642         def out(msg: str, **kwargs: Any) -> None:
643             out_lines.append(msg)
644
645         def err(msg: str, **kwargs: Any) -> None:
646             err_lines.append(msg)
647
648         with patch("black.out", out), patch("black.err", err):
649             report.done(Path("f1"), black.Changed.NO)
650             self.assertEqual(len(out_lines), 0)
651             self.assertEqual(len(err_lines), 0)
652             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
653             self.assertEqual(report.return_code, 0)
654             report.done(Path("f2"), black.Changed.YES)
655             self.assertEqual(len(out_lines), 0)
656             self.assertEqual(len(err_lines), 0)
657             self.assertEqual(
658                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
659             )
660             report.done(Path("f3"), black.Changed.CACHED)
661             self.assertEqual(len(out_lines), 0)
662             self.assertEqual(len(err_lines), 0)
663             self.assertEqual(
664                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
665             )
666             self.assertEqual(report.return_code, 0)
667             report.check = True
668             self.assertEqual(report.return_code, 1)
669             report.check = False
670             report.failed(Path("e1"), "boom")
671             self.assertEqual(len(out_lines), 0)
672             self.assertEqual(len(err_lines), 1)
673             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
674             self.assertEqual(
675                 unstyle(str(report)),
676                 "1 file reformatted, 2 files left unchanged, "
677                 "1 file failed to reformat.",
678             )
679             self.assertEqual(report.return_code, 123)
680             report.done(Path("f3"), black.Changed.YES)
681             self.assertEqual(len(out_lines), 0)
682             self.assertEqual(len(err_lines), 1)
683             self.assertEqual(
684                 unstyle(str(report)),
685                 "2 files reformatted, 2 files left unchanged, "
686                 "1 file failed to reformat.",
687             )
688             self.assertEqual(report.return_code, 123)
689             report.failed(Path("e2"), "boom")
690             self.assertEqual(len(out_lines), 0)
691             self.assertEqual(len(err_lines), 2)
692             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
693             self.assertEqual(
694                 unstyle(str(report)),
695                 "2 files reformatted, 2 files left unchanged, "
696                 "2 files failed to reformat.",
697             )
698             self.assertEqual(report.return_code, 123)
699             report.path_ignored(Path("wat"), "no match")
700             self.assertEqual(len(out_lines), 0)
701             self.assertEqual(len(err_lines), 2)
702             self.assertEqual(
703                 unstyle(str(report)),
704                 "2 files reformatted, 2 files left unchanged, "
705                 "2 files failed to reformat.",
706             )
707             self.assertEqual(report.return_code, 123)
708             report.done(Path("f4"), black.Changed.NO)
709             self.assertEqual(len(out_lines), 0)
710             self.assertEqual(len(err_lines), 2)
711             self.assertEqual(
712                 unstyle(str(report)),
713                 "2 files reformatted, 3 files left unchanged, "
714                 "2 files failed to reformat.",
715             )
716             self.assertEqual(report.return_code, 123)
717             report.check = True
718             self.assertEqual(
719                 unstyle(str(report)),
720                 "2 files would be reformatted, 3 files would be left unchanged, "
721                 "2 files would fail to reformat.",
722             )
723
724     def test_report_normal(self) -> None:
725         report = black.Report()
726         out_lines = []
727         err_lines = []
728
729         def out(msg: str, **kwargs: Any) -> None:
730             out_lines.append(msg)
731
732         def err(msg: str, **kwargs: Any) -> None:
733             err_lines.append(msg)
734
735         with patch("black.out", out), patch("black.err", err):
736             report.done(Path("f1"), black.Changed.NO)
737             self.assertEqual(len(out_lines), 0)
738             self.assertEqual(len(err_lines), 0)
739             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
740             self.assertEqual(report.return_code, 0)
741             report.done(Path("f2"), black.Changed.YES)
742             self.assertEqual(len(out_lines), 1)
743             self.assertEqual(len(err_lines), 0)
744             self.assertEqual(out_lines[-1], "reformatted f2")
745             self.assertEqual(
746                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
747             )
748             report.done(Path("f3"), black.Changed.CACHED)
749             self.assertEqual(len(out_lines), 1)
750             self.assertEqual(len(err_lines), 0)
751             self.assertEqual(out_lines[-1], "reformatted f2")
752             self.assertEqual(
753                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
754             )
755             self.assertEqual(report.return_code, 0)
756             report.check = True
757             self.assertEqual(report.return_code, 1)
758             report.check = False
759             report.failed(Path("e1"), "boom")
760             self.assertEqual(len(out_lines), 1)
761             self.assertEqual(len(err_lines), 1)
762             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
763             self.assertEqual(
764                 unstyle(str(report)),
765                 "1 file reformatted, 2 files left unchanged, "
766                 "1 file failed to reformat.",
767             )
768             self.assertEqual(report.return_code, 123)
769             report.done(Path("f3"), black.Changed.YES)
770             self.assertEqual(len(out_lines), 2)
771             self.assertEqual(len(err_lines), 1)
772             self.assertEqual(out_lines[-1], "reformatted f3")
773             self.assertEqual(
774                 unstyle(str(report)),
775                 "2 files reformatted, 2 files left unchanged, "
776                 "1 file failed to reformat.",
777             )
778             self.assertEqual(report.return_code, 123)
779             report.failed(Path("e2"), "boom")
780             self.assertEqual(len(out_lines), 2)
781             self.assertEqual(len(err_lines), 2)
782             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
783             self.assertEqual(
784                 unstyle(str(report)),
785                 "2 files reformatted, 2 files left unchanged, "
786                 "2 files failed to reformat.",
787             )
788             self.assertEqual(report.return_code, 123)
789             report.path_ignored(Path("wat"), "no match")
790             self.assertEqual(len(out_lines), 2)
791             self.assertEqual(len(err_lines), 2)
792             self.assertEqual(
793                 unstyle(str(report)),
794                 "2 files reformatted, 2 files left unchanged, "
795                 "2 files failed to reformat.",
796             )
797             self.assertEqual(report.return_code, 123)
798             report.done(Path("f4"), black.Changed.NO)
799             self.assertEqual(len(out_lines), 2)
800             self.assertEqual(len(err_lines), 2)
801             self.assertEqual(
802                 unstyle(str(report)),
803                 "2 files reformatted, 3 files left unchanged, "
804                 "2 files failed to reformat.",
805             )
806             self.assertEqual(report.return_code, 123)
807             report.check = True
808             self.assertEqual(
809                 unstyle(str(report)),
810                 "2 files would be reformatted, 3 files would be left unchanged, "
811                 "2 files would fail to reformat.",
812             )
813
814     def test_get_features_used(self) -> None:
815         node = black.lib2to3_parse("def f(*, arg): ...\n")
816         self.assertEqual(black.get_features_used(node), set())
817         node = black.lib2to3_parse("def f(*, arg,): ...\n")
818         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA})
819         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
820         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
821         node = black.lib2to3_parse("123_456\n")
822         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
823         node = black.lib2to3_parse("123456\n")
824         self.assertEqual(black.get_features_used(node), set())
825         source, expected = read_data("function")
826         node = black.lib2to3_parse(source)
827         self.assertEqual(
828             black.get_features_used(node), {Feature.TRAILING_COMMA, Feature.F_STRINGS}
829         )
830         node = black.lib2to3_parse(expected)
831         self.assertEqual(
832             black.get_features_used(node), {Feature.TRAILING_COMMA, Feature.F_STRINGS}
833         )
834         source, expected = read_data("expression")
835         node = black.lib2to3_parse(source)
836         self.assertEqual(black.get_features_used(node), set())
837         node = black.lib2to3_parse(expected)
838         self.assertEqual(black.get_features_used(node), set())
839
840     def test_get_future_imports(self) -> None:
841         node = black.lib2to3_parse("\n")
842         self.assertEqual(set(), black.get_future_imports(node))
843         node = black.lib2to3_parse("from __future__ import black\n")
844         self.assertEqual({"black"}, black.get_future_imports(node))
845         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
846         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
847         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
848         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
849         node = black.lib2to3_parse(
850             "from __future__ import multiple\nfrom __future__ import imports\n"
851         )
852         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
853         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
854         self.assertEqual({"black"}, black.get_future_imports(node))
855         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
856         self.assertEqual({"black"}, black.get_future_imports(node))
857         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
858         self.assertEqual(set(), black.get_future_imports(node))
859         node = black.lib2to3_parse("from some.module import black\n")
860         self.assertEqual(set(), black.get_future_imports(node))
861         node = black.lib2to3_parse(
862             "from __future__ import unicode_literals as _unicode_literals"
863         )
864         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
865         node = black.lib2to3_parse(
866             "from __future__ import unicode_literals as _lol, print"
867         )
868         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
869
870     def test_debug_visitor(self) -> None:
871         source, _ = read_data("debug_visitor.py")
872         expected, _ = read_data("debug_visitor.out")
873         out_lines = []
874         err_lines = []
875
876         def out(msg: str, **kwargs: Any) -> None:
877             out_lines.append(msg)
878
879         def err(msg: str, **kwargs: Any) -> None:
880             err_lines.append(msg)
881
882         with patch("black.out", out), patch("black.err", err):
883             black.DebugVisitor.show(source)
884         actual = "\n".join(out_lines) + "\n"
885         log_name = ""
886         if expected != actual:
887             log_name = black.dump_to_file(*out_lines)
888         self.assertEqual(
889             expected,
890             actual,
891             f"AST print out is different. Actual version dumped to {log_name}",
892         )
893
894     def test_format_file_contents(self) -> None:
895         empty = ""
896         mode = black.FileMode()
897         with self.assertRaises(black.NothingChanged):
898             black.format_file_contents(empty, mode=mode, fast=False)
899         just_nl = "\n"
900         with self.assertRaises(black.NothingChanged):
901             black.format_file_contents(just_nl, mode=mode, fast=False)
902         same = "l = [1, 2, 3]\n"
903         with self.assertRaises(black.NothingChanged):
904             black.format_file_contents(same, mode=mode, fast=False)
905         different = "l = [1,2,3]"
906         expected = same
907         actual = black.format_file_contents(different, mode=mode, fast=False)
908         self.assertEqual(expected, actual)
909         invalid = "return if you can"
910         with self.assertRaises(black.InvalidInput) as e:
911             black.format_file_contents(invalid, mode=mode, fast=False)
912         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
913
914     def test_endmarker(self) -> None:
915         n = black.lib2to3_parse("\n")
916         self.assertEqual(n.type, black.syms.file_input)
917         self.assertEqual(len(n.children), 1)
918         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
919
920     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
921     def test_assertFormatEqual(self) -> None:
922         out_lines = []
923         err_lines = []
924
925         def out(msg: str, **kwargs: Any) -> None:
926             out_lines.append(msg)
927
928         def err(msg: str, **kwargs: Any) -> None:
929             err_lines.append(msg)
930
931         with patch("black.out", out), patch("black.err", err):
932             with self.assertRaises(AssertionError):
933                 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
934
935         out_str = "".join(out_lines)
936         self.assertTrue("Expected tree:" in out_str)
937         self.assertTrue("Actual tree:" in out_str)
938         self.assertEqual("".join(err_lines), "")
939
940     def test_cache_broken_file(self) -> None:
941         mode = black.FileMode()
942         with cache_dir() as workspace:
943             cache_file = black.get_cache_file(mode)
944             with cache_file.open("w") as fobj:
945                 fobj.write("this is not a pickle")
946             self.assertEqual(black.read_cache(mode), {})
947             src = (workspace / "test.py").resolve()
948             with src.open("w") as fobj:
949                 fobj.write("print('hello')")
950             self.invokeBlack([str(src)])
951             cache = black.read_cache(mode)
952             self.assertIn(src, cache)
953
954     def test_cache_single_file_already_cached(self) -> None:
955         mode = black.FileMode()
956         with cache_dir() as workspace:
957             src = (workspace / "test.py").resolve()
958             with src.open("w") as fobj:
959                 fobj.write("print('hello')")
960             black.write_cache({}, [src], mode)
961             self.invokeBlack([str(src)])
962             with src.open("r") as fobj:
963                 self.assertEqual(fobj.read(), "print('hello')")
964
965     @event_loop(close=False)
966     def test_cache_multiple_files(self) -> None:
967         mode = black.FileMode()
968         with cache_dir() as workspace, patch(
969             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
970         ):
971             one = (workspace / "one.py").resolve()
972             with one.open("w") as fobj:
973                 fobj.write("print('hello')")
974             two = (workspace / "two.py").resolve()
975             with two.open("w") as fobj:
976                 fobj.write("print('hello')")
977             black.write_cache({}, [one], mode)
978             self.invokeBlack([str(workspace)])
979             with one.open("r") as fobj:
980                 self.assertEqual(fobj.read(), "print('hello')")
981             with two.open("r") as fobj:
982                 self.assertEqual(fobj.read(), 'print("hello")\n')
983             cache = black.read_cache(mode)
984             self.assertIn(one, cache)
985             self.assertIn(two, cache)
986
987     def test_no_cache_when_writeback_diff(self) -> None:
988         mode = black.FileMode()
989         with cache_dir() as workspace:
990             src = (workspace / "test.py").resolve()
991             with src.open("w") as fobj:
992                 fobj.write("print('hello')")
993             self.invokeBlack([str(src), "--diff"])
994             cache_file = black.get_cache_file(mode)
995             self.assertFalse(cache_file.exists())
996
997     def test_no_cache_when_stdin(self) -> None:
998         mode = black.FileMode()
999         with cache_dir():
1000             result = CliRunner().invoke(
1001                 black.main, ["-"], input=BytesIO(b"print('hello')")
1002             )
1003             self.assertEqual(result.exit_code, 0)
1004             cache_file = black.get_cache_file(mode)
1005             self.assertFalse(cache_file.exists())
1006
1007     def test_read_cache_no_cachefile(self) -> None:
1008         mode = black.FileMode()
1009         with cache_dir():
1010             self.assertEqual(black.read_cache(mode), {})
1011
1012     def test_write_cache_read_cache(self) -> None:
1013         mode = black.FileMode()
1014         with cache_dir() as workspace:
1015             src = (workspace / "test.py").resolve()
1016             src.touch()
1017             black.write_cache({}, [src], mode)
1018             cache = black.read_cache(mode)
1019             self.assertIn(src, cache)
1020             self.assertEqual(cache[src], black.get_cache_info(src))
1021
1022     def test_filter_cached(self) -> None:
1023         with TemporaryDirectory() as workspace:
1024             path = Path(workspace)
1025             uncached = (path / "uncached").resolve()
1026             cached = (path / "cached").resolve()
1027             cached_but_changed = (path / "changed").resolve()
1028             uncached.touch()
1029             cached.touch()
1030             cached_but_changed.touch()
1031             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
1032             todo, done = black.filter_cached(
1033                 cache, {uncached, cached, cached_but_changed}
1034             )
1035             self.assertEqual(todo, {uncached, cached_but_changed})
1036             self.assertEqual(done, {cached})
1037
1038     def test_write_cache_creates_directory_if_needed(self) -> None:
1039         mode = black.FileMode()
1040         with cache_dir(exists=False) as workspace:
1041             self.assertFalse(workspace.exists())
1042             black.write_cache({}, [], mode)
1043             self.assertTrue(workspace.exists())
1044
1045     @event_loop(close=False)
1046     def test_failed_formatting_does_not_get_cached(self) -> None:
1047         mode = black.FileMode()
1048         with cache_dir() as workspace, patch(
1049             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1050         ):
1051             failing = (workspace / "failing.py").resolve()
1052             with failing.open("w") as fobj:
1053                 fobj.write("not actually python")
1054             clean = (workspace / "clean.py").resolve()
1055             with clean.open("w") as fobj:
1056                 fobj.write('print("hello")\n')
1057             self.invokeBlack([str(workspace)], exit_code=123)
1058             cache = black.read_cache(mode)
1059             self.assertNotIn(failing, cache)
1060             self.assertIn(clean, cache)
1061
1062     def test_write_cache_write_fail(self) -> None:
1063         mode = black.FileMode()
1064         with cache_dir(), patch.object(Path, "open") as mock:
1065             mock.side_effect = OSError
1066             black.write_cache({}, [], mode)
1067
1068     @event_loop(close=False)
1069     def test_check_diff_use_together(self) -> None:
1070         with cache_dir():
1071             # Files which will be reformatted.
1072             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1073             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1074             # Files which will not be reformatted.
1075             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1076             self.invokeBlack([str(src2), "--diff", "--check"])
1077             # Multi file command.
1078             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1079
1080     def test_no_files(self) -> None:
1081         with cache_dir():
1082             # Without an argument, black exits with error code 0.
1083             self.invokeBlack([])
1084
1085     def test_broken_symlink(self) -> None:
1086         with cache_dir() as workspace:
1087             symlink = workspace / "broken_link.py"
1088             try:
1089                 symlink.symlink_to("nonexistent.py")
1090             except OSError as e:
1091                 self.skipTest(f"Can't create symlinks: {e}")
1092             self.invokeBlack([str(workspace.resolve())])
1093
1094     def test_read_cache_line_lengths(self) -> None:
1095         mode = black.FileMode()
1096         short_mode = black.FileMode(line_length=1)
1097         with cache_dir() as workspace:
1098             path = (workspace / "file.py").resolve()
1099             path.touch()
1100             black.write_cache({}, [path], mode)
1101             one = black.read_cache(mode)
1102             self.assertIn(path, one)
1103             two = black.read_cache(short_mode)
1104             self.assertNotIn(path, two)
1105
1106     def test_single_file_force_pyi(self) -> None:
1107         reg_mode = black.FileMode()
1108         pyi_mode = black.FileMode(is_pyi=True)
1109         contents, expected = read_data("force_pyi")
1110         with cache_dir() as workspace:
1111             path = (workspace / "file.py").resolve()
1112             with open(path, "w") as fh:
1113                 fh.write(contents)
1114             self.invokeBlack([str(path), "--pyi"])
1115             with open(path, "r") as fh:
1116                 actual = fh.read()
1117             # verify cache with --pyi is separate
1118             pyi_cache = black.read_cache(pyi_mode)
1119             self.assertIn(path, pyi_cache)
1120             normal_cache = black.read_cache(reg_mode)
1121             self.assertNotIn(path, normal_cache)
1122         self.assertEqual(actual, expected)
1123
1124     @event_loop(close=False)
1125     def test_multi_file_force_pyi(self) -> None:
1126         reg_mode = black.FileMode()
1127         pyi_mode = black.FileMode(is_pyi=True)
1128         contents, expected = read_data("force_pyi")
1129         with cache_dir() as workspace:
1130             paths = [
1131                 (workspace / "file1.py").resolve(),
1132                 (workspace / "file2.py").resolve(),
1133             ]
1134             for path in paths:
1135                 with open(path, "w") as fh:
1136                     fh.write(contents)
1137             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1138             for path in paths:
1139                 with open(path, "r") as fh:
1140                     actual = fh.read()
1141                 self.assertEqual(actual, expected)
1142             # verify cache with --pyi is separate
1143             pyi_cache = black.read_cache(pyi_mode)
1144             normal_cache = black.read_cache(reg_mode)
1145             for path in paths:
1146                 self.assertIn(path, pyi_cache)
1147                 self.assertNotIn(path, normal_cache)
1148
1149     def test_pipe_force_pyi(self) -> None:
1150         source, expected = read_data("force_pyi")
1151         result = CliRunner().invoke(
1152             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1153         )
1154         self.assertEqual(result.exit_code, 0)
1155         actual = result.output
1156         self.assertFormatEqual(actual, expected)
1157
1158     def test_single_file_force_py36(self) -> None:
1159         reg_mode = black.FileMode()
1160         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1161         source, expected = read_data("force_py36")
1162         with cache_dir() as workspace:
1163             path = (workspace / "file.py").resolve()
1164             with open(path, "w") as fh:
1165                 fh.write(source)
1166             self.invokeBlack([str(path), *PY36_ARGS])
1167             with open(path, "r") as fh:
1168                 actual = fh.read()
1169             # verify cache with --target-version is separate
1170             py36_cache = black.read_cache(py36_mode)
1171             self.assertIn(path, py36_cache)
1172             normal_cache = black.read_cache(reg_mode)
1173             self.assertNotIn(path, normal_cache)
1174         self.assertEqual(actual, expected)
1175
1176     @event_loop(close=False)
1177     def test_multi_file_force_py36(self) -> None:
1178         reg_mode = black.FileMode()
1179         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1180         source, expected = read_data("force_py36")
1181         with cache_dir() as workspace:
1182             paths = [
1183                 (workspace / "file1.py").resolve(),
1184                 (workspace / "file2.py").resolve(),
1185             ]
1186             for path in paths:
1187                 with open(path, "w") as fh:
1188                     fh.write(source)
1189             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1190             for path in paths:
1191                 with open(path, "r") as fh:
1192                     actual = fh.read()
1193                 self.assertEqual(actual, expected)
1194             # verify cache with --target-version is separate
1195             pyi_cache = black.read_cache(py36_mode)
1196             normal_cache = black.read_cache(reg_mode)
1197             for path in paths:
1198                 self.assertIn(path, pyi_cache)
1199                 self.assertNotIn(path, normal_cache)
1200
1201     def test_pipe_force_py36(self) -> None:
1202         source, expected = read_data("force_py36")
1203         result = CliRunner().invoke(
1204             black.main,
1205             ["-", "-q", "--target-version=py36"],
1206             input=BytesIO(source.encode("utf8")),
1207         )
1208         self.assertEqual(result.exit_code, 0)
1209         actual = result.output
1210         self.assertFormatEqual(actual, expected)
1211
1212     def test_include_exclude(self) -> None:
1213         path = THIS_DIR / "data" / "include_exclude_tests"
1214         include = re.compile(r"\.pyi?$")
1215         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1216         report = black.Report()
1217         sources: List[Path] = []
1218         expected = [
1219             Path(path / "b/dont_exclude/a.py"),
1220             Path(path / "b/dont_exclude/a.pyi"),
1221         ]
1222         this_abs = THIS_DIR.resolve()
1223         sources.extend(
1224             black.gen_python_files_in_dir(path, this_abs, include, exclude, report)
1225         )
1226         self.assertEqual(sorted(expected), sorted(sources))
1227
1228     def test_empty_include(self) -> None:
1229         path = THIS_DIR / "data" / "include_exclude_tests"
1230         report = black.Report()
1231         empty = re.compile(r"")
1232         sources: List[Path] = []
1233         expected = [
1234             Path(path / "b/exclude/a.pie"),
1235             Path(path / "b/exclude/a.py"),
1236             Path(path / "b/exclude/a.pyi"),
1237             Path(path / "b/dont_exclude/a.pie"),
1238             Path(path / "b/dont_exclude/a.py"),
1239             Path(path / "b/dont_exclude/a.pyi"),
1240             Path(path / "b/.definitely_exclude/a.pie"),
1241             Path(path / "b/.definitely_exclude/a.py"),
1242             Path(path / "b/.definitely_exclude/a.pyi"),
1243         ]
1244         this_abs = THIS_DIR.resolve()
1245         sources.extend(
1246             black.gen_python_files_in_dir(
1247                 path, this_abs, empty, re.compile(black.DEFAULT_EXCLUDES), report
1248             )
1249         )
1250         self.assertEqual(sorted(expected), sorted(sources))
1251
1252     def test_empty_exclude(self) -> None:
1253         path = THIS_DIR / "data" / "include_exclude_tests"
1254         report = black.Report()
1255         empty = re.compile(r"")
1256         sources: List[Path] = []
1257         expected = [
1258             Path(path / "b/dont_exclude/a.py"),
1259             Path(path / "b/dont_exclude/a.pyi"),
1260             Path(path / "b/exclude/a.py"),
1261             Path(path / "b/exclude/a.pyi"),
1262             Path(path / "b/.definitely_exclude/a.py"),
1263             Path(path / "b/.definitely_exclude/a.pyi"),
1264         ]
1265         this_abs = THIS_DIR.resolve()
1266         sources.extend(
1267             black.gen_python_files_in_dir(
1268                 path, this_abs, re.compile(black.DEFAULT_INCLUDES), empty, report
1269             )
1270         )
1271         self.assertEqual(sorted(expected), sorted(sources))
1272
1273     def test_invalid_include_exclude(self) -> None:
1274         for option in ["--include", "--exclude"]:
1275             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1276
1277     def test_preserves_line_endings(self) -> None:
1278         with TemporaryDirectory() as workspace:
1279             test_file = Path(workspace) / "test.py"
1280             for nl in ["\n", "\r\n"]:
1281                 contents = nl.join(["def f(  ):", "    pass"])
1282                 test_file.write_bytes(contents.encode())
1283                 ff(test_file, write_back=black.WriteBack.YES)
1284                 updated_contents: bytes = test_file.read_bytes()
1285                 self.assertIn(nl.encode(), updated_contents)
1286                 if nl == "\n":
1287                     self.assertNotIn(b"\r\n", updated_contents)
1288
1289     def test_preserves_line_endings_via_stdin(self) -> None:
1290         for nl in ["\n", "\r\n"]:
1291             contents = nl.join(["def f(  ):", "    pass"])
1292             runner = BlackRunner()
1293             result = runner.invoke(
1294                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1295             )
1296             self.assertEqual(result.exit_code, 0)
1297             output = runner.stdout_bytes
1298             self.assertIn(nl.encode("utf8"), output)
1299             if nl == "\n":
1300                 self.assertNotIn(b"\r\n", output)
1301
1302     def test_assert_equivalent_different_asts(self) -> None:
1303         with self.assertRaises(AssertionError):
1304             black.assert_equivalent("{}", "None")
1305
1306     def test_symlink_out_of_root_directory(self) -> None:
1307         path = MagicMock()
1308         root = THIS_DIR
1309         child = MagicMock()
1310         include = re.compile(black.DEFAULT_INCLUDES)
1311         exclude = re.compile(black.DEFAULT_EXCLUDES)
1312         report = black.Report()
1313         # `child` should behave like a symlink which resolved path is clearly
1314         # outside of the `root` directory.
1315         path.iterdir.return_value = [child]
1316         child.resolve.return_value = Path("/a/b/c")
1317         child.is_symlink.return_value = True
1318         try:
1319             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1320         except ValueError as ve:
1321             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1322         path.iterdir.assert_called_once()
1323         child.resolve.assert_called_once()
1324         child.is_symlink.assert_called_once()
1325         # `child` should behave like a strange file which resolved path is clearly
1326         # outside of the `root` directory.
1327         child.is_symlink.return_value = False
1328         with self.assertRaises(ValueError):
1329             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1330         path.iterdir.assert_called()
1331         self.assertEqual(path.iterdir.call_count, 2)
1332         child.resolve.assert_called()
1333         self.assertEqual(child.resolve.call_count, 2)
1334         child.is_symlink.assert_called()
1335         self.assertEqual(child.is_symlink.call_count, 2)
1336
1337     def test_shhh_click(self) -> None:
1338         try:
1339             from click import _unicodefun  # type: ignore
1340         except ModuleNotFoundError:
1341             self.skipTest("Incompatible Click version")
1342         if not hasattr(_unicodefun, "_verify_python3_env"):
1343             self.skipTest("Incompatible Click version")
1344         # First, let's see if Click is crashing with a preferred ASCII charset.
1345         with patch("locale.getpreferredencoding") as gpe:
1346             gpe.return_value = "ASCII"
1347             with self.assertRaises(RuntimeError):
1348                 _unicodefun._verify_python3_env()
1349         # Now, let's silence Click...
1350         black.patch_click()
1351         # ...and confirm it's silent.
1352         with patch("locale.getpreferredencoding") as gpe:
1353             gpe.return_value = "ASCII"
1354             try:
1355                 _unicodefun._verify_python3_env()
1356             except RuntimeError as re:
1357                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1358
1359     def test_root_logger_not_used_directly(self) -> None:
1360         def fail(*args: Any, **kwargs: Any) -> None:
1361             self.fail("Record created with root logger")
1362
1363         with patch.multiple(
1364             logging.root,
1365             debug=fail,
1366             info=fail,
1367             warning=fail,
1368             error=fail,
1369             critical=fail,
1370             log=fail,
1371         ):
1372             ff(THIS_FILE)
1373
1374     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1375     @async_test
1376     async def test_blackd_request_needs_formatting(self) -> None:
1377         app = blackd.make_app()
1378         async with TestClient(TestServer(app)) as client:
1379             response = await client.post("/", data=b"print('hello world')")
1380             self.assertEqual(response.status, 200)
1381             self.assertEqual(response.charset, "utf8")
1382             self.assertEqual(await response.read(), b'print("hello world")\n')
1383
1384     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1385     @async_test
1386     async def test_blackd_request_no_change(self) -> None:
1387         app = blackd.make_app()
1388         async with TestClient(TestServer(app)) as client:
1389             response = await client.post("/", data=b'print("hello world")\n')
1390             self.assertEqual(response.status, 204)
1391             self.assertEqual(await response.read(), b"")
1392
1393     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1394     @async_test
1395     async def test_blackd_request_syntax_error(self) -> None:
1396         app = blackd.make_app()
1397         async with TestClient(TestServer(app)) as client:
1398             response = await client.post("/", data=b"what even ( is")
1399             self.assertEqual(response.status, 400)
1400             content = await response.text()
1401             self.assertTrue(
1402                 content.startswith("Cannot parse"),
1403                 msg=f"Expected error to start with 'Cannot parse', got {repr(content)}",
1404             )
1405
1406     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1407     @async_test
1408     async def test_blackd_unsupported_version(self) -> None:
1409         app = blackd.make_app()
1410         async with TestClient(TestServer(app)) as client:
1411             response = await client.post(
1412                 "/", data=b"what", headers={blackd.VERSION_HEADER: "2"}
1413             )
1414             self.assertEqual(response.status, 501)
1415
1416     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1417     @async_test
1418     async def test_blackd_supported_version(self) -> None:
1419         app = blackd.make_app()
1420         async with TestClient(TestServer(app)) as client:
1421             response = await client.post(
1422                 "/", data=b"what", headers={blackd.VERSION_HEADER: "1"}
1423             )
1424             self.assertEqual(response.status, 200)
1425
1426     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1427     @async_test
1428     async def test_blackd_invalid_python_variant(self) -> None:
1429         app = blackd.make_app()
1430         async with TestClient(TestServer(app)) as client:
1431
1432             async def check(header_value: str, expected_status: int = 400) -> None:
1433                 response = await client.post(
1434                     "/",
1435                     data=b"what",
1436                     headers={blackd.PYTHON_VARIANT_HEADER: header_value},
1437                 )
1438                 self.assertEqual(response.status, expected_status)
1439
1440             await check("lol")
1441             await check("ruby3.5")
1442             await check("pyi3.6")
1443             await check("py1.5")
1444             await check("2.8")
1445             await check("py2.8")
1446             await check("3.0")
1447             await check("pypy3.0")
1448             await check("jython3.4")
1449
1450     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1451     @async_test
1452     async def test_blackd_pyi(self) -> None:
1453         app = blackd.make_app()
1454         async with TestClient(TestServer(app)) as client:
1455             source, expected = read_data("stub.pyi")
1456             response = await client.post(
1457                 "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
1458             )
1459             self.assertEqual(response.status, 200)
1460             self.assertEqual(await response.text(), expected)
1461
1462     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1463     @async_test
1464     async def test_blackd_python_variant(self) -> None:
1465         app = blackd.make_app()
1466         code = (
1467             "def f(\n"
1468             "    and_has_a_bunch_of,\n"
1469             "    very_long_arguments_too,\n"
1470             "    and_lots_of_them_as_well_lol,\n"
1471             "    **and_very_long_keyword_arguments\n"
1472             "):\n"
1473             "    pass\n"
1474         )
1475         async with TestClient(TestServer(app)) as client:
1476
1477             async def check(header_value: str, expected_status: int) -> None:
1478                 response = await client.post(
1479                     "/", data=code, headers={blackd.PYTHON_VARIANT_HEADER: header_value}
1480                 )
1481                 self.assertEqual(response.status, expected_status)
1482
1483             await check("3.6", 200)
1484             await check("py3.6", 200)
1485             await check("3.5,3.7", 200)
1486             await check("3.5,py3.7", 200)
1487
1488             await check("2", 204)
1489             await check("2.7", 204)
1490             await check("py2.7", 204)
1491             await check("3.4", 204)
1492             await check("py3.4", 204)
1493
1494     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1495     @async_test
1496     async def test_blackd_fast(self) -> None:
1497         with open(os.devnull, "w") as dn, redirect_stderr(dn):
1498             app = blackd.make_app()
1499             async with TestClient(TestServer(app)) as client:
1500                 response = await client.post("/", data=b"ur'hello'")
1501                 self.assertEqual(response.status, 500)
1502                 self.assertIn("failed to parse source file", await response.text())
1503                 response = await client.post(
1504                     "/", data=b"ur'hello'", headers={blackd.FAST_OR_SAFE_HEADER: "fast"}
1505                 )
1506                 self.assertEqual(response.status, 200)
1507
1508     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1509     @async_test
1510     async def test_blackd_line_length(self) -> None:
1511         app = blackd.make_app()
1512         async with TestClient(TestServer(app)) as client:
1513             response = await client.post(
1514                 "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "7"}
1515             )
1516             self.assertEqual(response.status, 200)
1517
1518     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1519     @async_test
1520     async def test_blackd_invalid_line_length(self) -> None:
1521         app = blackd.make_app()
1522         async with TestClient(TestServer(app)) as client:
1523             response = await client.post(
1524                 "/",
1525                 data=b'print("hello")\n',
1526                 headers={blackd.LINE_LENGTH_HEADER: "NaN"},
1527             )
1528             self.assertEqual(response.status, 400)
1529
1530     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1531     def test_blackd_main(self) -> None:
1532         with patch("blackd.web.run_app"):
1533             result = CliRunner().invoke(blackd.main, [])
1534             if result.exception is not None:
1535                 raise result.exception
1536             self.assertEqual(result.exit_code, 0)
1537
1538
if __name__ == "__main__":
    # When executed as a script, run the tests found in the module named
    # "test_black" (passed explicitly rather than using unittest.main()'s
    # default of "__main__").
    unittest.main(module="test_black")